Unverified Commit 092e01ec authored by Muyang Li, committed by GitHub

fix: close the NVFP4 performance gap between the Python backend and C backend


Co-authored-by: Kung Talon <31659820+kungtalon@users.noreply.github.com>
parent 7fcce6f3
@@ -10,7 +10,12 @@ from utils import get_pipeline
 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        "-m", "--model", type=str, default="schnell", choices=["schnell", "dev"], help="Which FLUX.1 model to use"
+        "-m",
+        "--model",
+        type=str,
+        default="schnell",
+        choices=["schnell", "schnell_v2", "dev"],
+        help="Which FLUX.1 model to use",
     )
     parser.add_argument(
         "-p", "--precision", type=str, default="int4", choices=["int4", "fp4", "bf16"], help="Which precision to use"
......
@@ -30,19 +30,20 @@ def get_pipeline(
         assert torch.device(device).type == "cuda", "int4 only supported on CUDA devices"
         if precision == "int4":
             transformer = NunchakuFluxTransformer2dModel.from_pretrained(
-                "mit-han-lab/nunchaku-flux.1-schnell/svdq-int4_r32-flux.1-schnell.safetensors"
+                "nunchaku-tech/nunchaku-flux.1-schnell/svdq-int4_r32-flux.1-schnell.safetensors"
             )
         else:
             assert precision == "fp4"
             transformer = NunchakuFluxTransformer2dModel.from_pretrained(
-                "mit-han-lab/nunchaku-flux.1-schnell/svdq-fp4_r32-flux.1-schnell.safetensors", precision="fp4"
+                "nunchaku-tech/nunchaku-flux.1-schnell/svdq-fp4_r32-flux.1-schnell.safetensors", precision="fp4"
             )
+            transformer.set_attention_impl("nunchaku-fp16")
         pipeline_init_kwargs["transformer"] = transformer
         if use_qencoder:
             from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel
 
             text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
-                "mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
+                "nunchaku-tech/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
             )
             pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
     else:
@@ -52,7 +53,7 @@ def get_pipeline(
         )
     elif model_name == "schnell_v2":
         transformer = NunchakuFluxTransformer2DModelV2.from_pretrained(
-            f"mit-han-lab/nunchaku-flux.1-schnell/svdq-{precision}_r32-flux.1-schnell.safetensors"
+            f"nunchaku-tech/nunchaku-flux.1-schnell/svdq-{precision}_r32-flux.1-schnell.safetensors"
         )
         pipeline = FluxPipeline.from_pretrained(
             "black-forest-labs/FLUX.1-schnell",
@@ -63,7 +64,7 @@ def get_pipeline(
     elif model_name == "dev":
         if precision == "int4":
             transformer = NunchakuFluxTransformer2dModel.from_pretrained(
-                "mit-han-lab/nunchaku-flux.1-dev/svdq-int4_r32-flux.1-dev.safetensors"
+                "nunchaku-tech/nunchaku-flux.1-dev/svdq-int4_r32-flux.1-dev.safetensors"
             )
             if lora_name not in ["All", "None"]:
                 transformer.update_lora_params(SVDQ_LORA_PATHS[lora_name])
@@ -73,7 +74,7 @@ def get_pipeline(
             from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel
 
             text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
-                "mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
+                "nunchaku-tech/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
            )
             pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
         pipeline = FluxPipeline.from_pretrained(
......
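
Taken together, the schnell FP4 path now loads weights from the nunchaku-tech org and pins attention to the fused "nunchaku-fp16" kernel. Below is a minimal end-to-end sketch of that configuration; the imports and pipeline wiring follow the nunchaku README rather than this diff, so treat them as assumptions:

import torch
from diffusers import FluxPipeline
from nunchaku import NunchakuFluxTransformer2dModel

# FP4 transformer plus the fused FP16 attention kernel, as in the hunk above.
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    "nunchaku-tech/nunchaku-flux.1-schnell/svdq-fp4_r32-flux.1-schnell.safetensors",
    precision="fp4",
)
transformer.set_attention_impl("nunchaku-fp16")

pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
image = pipeline("a cat on a windowsill", num_inference_steps=4, guidance_scale=0.0).images[0]
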
@@ -64,13 +64,14 @@ class SVDQW4A4Linear(nn.Module):
         self.proj_down = nn.Parameter(torch.empty(in_features, rank, dtype=torch_dtype, device=device))
         self.proj_up = nn.Parameter(torch.empty(out_features, rank, dtype=torch_dtype, device=device))
 
-        self.wtscale = None
-        self.wcscales = None
         if precision == "nvfp4":
-            self.wtscale = nn.Parameter(torch.ones(1, dtype=torch_dtype, device=device), requires_grad=False)
             self.wcscales = nn.Parameter(
                 torch.ones(out_features, dtype=torch_dtype, device=device), requires_grad=False
             )
+            self.wtscale = 1.0
+        else:
+            self.wtscale = None
+            self.wcscales = None
 
         self.act_unsigned = act_unsigned
......
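
This hunk carries the substance of the fix: wtscale changes from a one-element CUDA nn.Parameter to a plain Python float (1.0 until a checkpoint overwrites it). If the quantized kernel consumes the per-tensor weight scale as a host-side scalar, which the loader comment further down ("it is a float on CPU") suggests, then keeping the scale on the GPU forces a device-to-host read, and with it a stream synchronization, on every quantized linear layer. A toy illustration of that stall; the names are hypothetical and this is not repo code:

import torch

x = torch.randn(8192, 8192, device="cuda")

scale_gpu = torch.ones(1, device="cuda")  # old: one-element Parameter on the GPU
scale_host = 1.0                          # new: plain Python float

# Old path: turning the GPU scalar into a host float blocks until the
# stream drains, once per quantized linear layer.
y = x * scale_gpu.item()

# New path: the scalar already lives on the host; the launch queue never stalls.
y = x * scale_host
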
@@ -26,7 +26,7 @@ from .utils import NunchakuModelLoaderMixin, pad_tensor
 class NunchakuFluxAttention(NunchakuBaseAttention):
-    def __init__(self, other: FluxAttention, processor: str = "flashattn2", **kwargs):
+    def __init__(self, other: FluxAttention, processor: str = "nunchaku-fp16", **kwargs):
         super(NunchakuFluxAttention, self).__init__(processor)
         self.head_dim = other.head_dim
@@ -263,11 +263,17 @@ class NunchakuFluxTransformer2DModelV2(FluxTransformer2DModel, NunchakuModelLoaderMixin):
         for k in state_dict.keys():
             if k not in converted_state_dict:
-                assert ".wtscale" in k or ".wcscales" in k
+                assert ".wcscales" in k
                 converted_state_dict[k] = torch.ones_like(state_dict[k])
             else:
                 assert state_dict[k].dtype == converted_state_dict[k].dtype
 
+        # load the wtscale from the converted state dict
+        for n, m in transformer.named_modules():
+            if isinstance(m, SVDQW4A4Linear):
+                if m.wtscale is not None:
+                    m.wtscale = converted_state_dict.pop(f"{n}.wtscale", 1.0)
+
         transformer.load_state_dict(converted_state_dict)
         return transformer
......
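
Because wtscale is no longer a registered parameter, it does not appear in transformer.state_dict(), and a strict load_state_dict() would reject the leftover .wtscale keys from the checkpoint; the pop-and-assign pass above clears them out first. A self-contained sketch of that pattern, with hypothetical module names:

import torch
import torch.nn as nn

class ScaledLinear(nn.Module):  # hypothetical stand-in for SVDQW4A4Linear
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.wtscale = 1.0  # plain float attribute: excluded from state_dict()

model = nn.Sequential(ScaledLinear())
ckpt = {
    "0.linear.weight": torch.zeros(4, 4),
    "0.linear.bias": torch.zeros(4),
    "0.wtscale": torch.tensor(0.5),  # scalar saved by the exporter
}

# Pop the plain-attribute entries first, then load the rest strictly.
for name, module in model.named_modules():
    if isinstance(module, ScaledLinear):
        module.wtscale = float(ckpt.pop(f"{name}.wtscale", 1.0))

model.load_state_dict(ckpt)  # would raise "Unexpected key(s)" without the pop
assert model[0].wtscale == 0.5
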
@@ -248,10 +248,16 @@ class NunchakuQwenImageTransformer2DModel(QwenImageTransformer2DModel, NunchakuModelLoaderMixin):
         state_dict = transformer.state_dict()
         for k in state_dict.keys():
             if k not in model_state_dict:
-                assert ".wtscale" in k or ".wcscales" in k
+                assert ".wcscales" in k
                 model_state_dict[k] = torch.ones_like(state_dict[k])
             else:
                 assert state_dict[k].dtype == model_state_dict[k].dtype
 
+        # load the wtscale from the state dict, as it is a float on CPU
+        for n, m in transformer.named_modules():
+            if isinstance(m, SVDQW4A4Linear):
+                if m.wtscale is not None:
+                    m.wtscale = model_state_dict.pop(f"{n}.wtscale", 1.0)
+
         transformer.load_state_dict(model_state_dict)
         return transformer