Commit 3233a41d authored by muyangli

[minor] fix sana

parent e0fadc93
@@ -23,6 +23,5 @@ image = pipe(
     guidance_scale=5.0,
     pag_scale=2.0,
     num_inference_steps=20,
-    generator=torch.Generator().manual_seed(42),
 ).images[0]
 image.save("sana_1600m_pag.png")
__version__ = "0.1.1" __version__ = "0.1.2"

@@ -124,9 +124,13 @@ class NunchakuSanaTransformer2DModel(SanaTransformer2DModel, NunchakuModelLoader
     def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs):
         device = kwargs.get("device", "cuda")
         pag_layers = kwargs.get("pag_layers", [])
+        precision = kwargs.get("precision", "int4")
+        assert precision in ["int4", "fp4"]
         transformer, transformer_block_path = cls._build_model(pretrained_model_name_or_path, **kwargs)
         transformer.config["num_layers"] = transformer.original_num_layers
-        m = load_quantized_module(transformer, transformer_block_path, device=device, pag_layers=pag_layers)
+        m = load_quantized_module(
+            transformer, transformer_block_path, device=device, pag_layers=pag_layers, use_fp4=precision == "fp4"
+        )
         transformer.inject_quantized_module(m, device)
         return transformer
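
A hedged usage sketch of the new kwarg: the checkpoint path is a placeholder and the import path is an assumption; only `precision` and its allowed values ("int4", the default, or "fp4") come from this commit:

from nunchaku.models.transformer_sana import NunchakuSanaTransformer2DModel  # assumed import path

transformer = NunchakuSanaTransformer2DModel.from_pretrained(
    "path/to/quantized-sana-checkpoint",  # placeholder path
    precision="fp4",                      # new kwarg; must be "int4" or "fp4"
    device="cuda",
    pag_layers=[8],                       # hypothetical PAG layer choice
)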

@@ -140,6 +144,7 @@ def load_quantized_module(
     path: str,
     device: str | torch.device = "cuda",
     pag_layers: int | list[int] | None = None,
+    use_fp4: bool = False,
 ) -> QuantizedSanaModel:
     if pag_layers is None:
         pag_layers = []

@@ -150,7 +155,7 @@ def load_quantized_module(
     m = QuantizedSanaModel()
     cutils.disable_memory_auto_release()
-    m.init(net.config, pag_layers, net.dtype == torch.bfloat16, 0 if device.index is None else device.index)
+    m.init(net.config, pag_layers, use_fp4, net.dtype == torch.bfloat16, 0 if device.index is None else device.index)
     m.load(path)
     return m
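
Note that `use_fp4` is threaded into `QuantizedSanaModel.init` as a new positional argument, so the C++ binding must accept it in that position. Since it defaults to False, existing INT4 call sites keep working unchanged. A sketch of calling the loader directly, with the import path and the other arguments assumed:

from nunchaku.models.transformer_sana import load_quantized_module  # assumed import path

m = load_quantized_module(
    transformer,             # a NunchakuSanaTransformer2DModel instance (assumed)
    transformer_block_path,  # path to the quantized block weights (assumed)
    device="cuda",
    pag_layers=[8],          # hypothetical PAG layer selection
    use_fp4=True,            # new in this commit; False keeps the INT4 path
)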

@@ -111,10 +111,12 @@ if __name__ == "__main__":
         "--threads=3",
         "--expt-relaxed-constexpr",
         "--expt-extended-lambda",
-        "--generate-line-info",
         "--ptxas-options=--allow-expensive-optimizations=true",
     ]
+    if os.getenv("NUNCHAKU_BUILD_WHEELS", "0") == "0":
+        NVCC_FLAGS.append("--generate-line-info")
     sm_targets = get_sm_targets()
     print(f"Detected SM targets: {sm_targets}", file=sys.stderr)