Commit 3233a41d authored by muyangli's avatar muyangli
Browse files

[minor] fix sana

parent e0fadc93
......@@ -23,6 +23,5 @@ image = pipe(
guidance_scale=5.0,
pag_scale=2.0,
num_inference_steps=20,
generator=torch.Generator().manual_seed(42),
).images[0]
image.save("sana_1600m_pag.png")
__version__ = "0.1.1"
__version__ = "0.1.2"
......@@ -124,9 +124,13 @@ class NunchakuSanaTransformer2DModel(SanaTransformer2DModel, NunchakuModelLoader
def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs):
device = kwargs.get("device", "cuda")
pag_layers = kwargs.get("pag_layers", [])
precision = kwargs.get("precision", "int4")
assert precision in ["int4", "fp4"]
transformer, transformer_block_path = cls._build_model(pretrained_model_name_or_path, **kwargs)
transformer.config["num_layers"] = transformer.original_num_layers
m = load_quantized_module(transformer, transformer_block_path, device=device, pag_layers=pag_layers)
m = load_quantized_module(
transformer, transformer_block_path, device=device, pag_layers=pag_layers, use_fp4=precision == "fp4"
)
transformer.inject_quantized_module(m, device)
return transformer
......@@ -140,6 +144,7 @@ def load_quantized_module(
path: str,
device: str | torch.device = "cuda",
pag_layers: int | list[int] | None = None,
use_fp4: bool = False,
) -> QuantizedSanaModel:
if pag_layers is None:
pag_layers = []
......@@ -150,7 +155,7 @@ def load_quantized_module(
m = QuantizedSanaModel()
cutils.disable_memory_auto_release()
m.init(net.config, pag_layers, net.dtype == torch.bfloat16, 0 if device.index is None else device.index)
m.init(net.config, pag_layers, use_fp4, net.dtype == torch.bfloat16, 0 if device.index is None else device.index)
m.load(path)
return m
......
......@@ -111,10 +111,12 @@ if __name__ == "__main__":
"--threads=3",
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
"--generate-line-info",
"--ptxas-options=--allow-expensive-optimizations=true",
]
if os.getenv("NUNCHAKU_BUILD_WHEELS", "0") == "0":
NVCC_FLAGS.append("--generate-line-info")
sm_targets = get_sm_targets()
print(f"Detected SM targets: {sm_targets}", file=sys.stderr)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment