Unverified commit 772fdb14, authored by gushiqiao, committed by GitHub

update config (#379)
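The configs and loaders are updated together: the bundled JSON configs now name the quantization backend explicitly ("fp8" becomes "fp8-sgl", or "fp8-q8f" for the audio-adapter config), the T5 and CLIP quant-scheme ladders accept both the legacy and the backend-qualified names and now fail loudly on unknown schemes, and WanAudioModel resolves its default adapter checkpoint for any fp8-* or int8-* scheme.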

parent 4a9f0df5
@@ -14,11 +14,11 @@
     "cpu_offload": false,
     "use_31_block": false,
     "clip_quantized": true,
-    "clip_quant_scheme": "fp8",
+    "clip_quant_scheme": "fp8-sgl",
     "dit_quantized": true,
     "dit_quant_scheme": "fp8-sgl",
     "adapter_quantized": true,
-    "adapter_quant_scheme": "fp8",
+    "adapter_quant_scheme": "fp8-sgl",
     "t5_quantized": true,
-    "t5_quant_scheme": "fp8"
+    "t5_quant_scheme": "fp8-sgl"
 }
@@ -20,7 +20,7 @@
     "dit_quantized": true,
     "dit_quant_scheme": "fp8-sgl",
     "adapter_quantized": true,
-    "adapter_quant_scheme": "fp8",
+    "adapter_quant_scheme": "fp8-sgl",
     "t5_quantized": true,
-    "t5_quant_scheme": "fp8"
+    "t5_quant_scheme": "fp8-sgl"
 }
@@ -24,7 +24,7 @@
     "audio_encoder_cpu_offload": false,
     "audio_adapter_cpu_offload": false,
     "adapter_quantized": true,
-    "adapter_quant_scheme": "fp8",
+    "adapter_quant_scheme": "fp8-q8f",
     "vae_cpu_offload": false,
     "use_tiling_vae": false,
     "dit_quantized": true,
...
@@ -18,12 +18,12 @@
     "t5_cpu_offload": true,
     "t5_offload_granularity": "model",
     "t5_quantized": true,
-    "t5_quant_scheme": "fp8",
+    "t5_quant_scheme": "fp8-sgl",
     "clip_cpu_offload": false,
     "audio_encoder_cpu_offload": false,
     "audio_adapter_cpu_offload": false,
     "adapter_quantized": true,
-    "adapter_quant_scheme": "fp8",
+    "adapter_quant_scheme": "fp8-sgl",
     "vae_cpu_offload": false,
     "use_tiling_vae": false,
     "dit_quantized": true,
...
@@ -20,7 +20,7 @@
     "dit_quantized": true,
     "dit_quant_scheme": "fp8-sgl",
     "adapter_quantized": true,
-    "adapter_quant_scheme": "fp8",
+    "adapter_quant_scheme": "fp8-sgl",
     "t5_quantized": true,
-    "t5_quant_scheme": "fp8"
+    "t5_quant_scheme": "fp8-sgl"
 }
@@ -24,7 +24,7 @@
     "dit_quantized": true,
     "dit_quant_scheme": "fp8-sgl",
     "adapter_quantized": true,
-    "adapter_quant_scheme": "fp8",
+    "adapter_quant_scheme": "fp8-sgl",
     "t5_quantized": true,
-    "t5_quant_scheme": "fp8"
+    "t5_quant_scheme": "fp8-sgl"
 }
@@ -25,7 +25,7 @@
     "dit_quantized": true,
     "dit_quant_scheme": "fp8-sgl",
     "adapter_quantized": true,
-    "adapter_quant_scheme": "fp8",
+    "adapter_quant_scheme": "fp8-sgl",
     "t5_quantized": true,
-    "t5_quant_scheme": "fp8"
+    "t5_quant_scheme": "fp8-sgl"
 }
@@ -56,7 +56,7 @@
     "dit_quantized": true,
     "dit_quant_scheme": "fp8-sgl",
     "adapter_quantized": true,
-    "adapter_quant_scheme": "fp8",
+    "adapter_quant_scheme": "fp8-sgl",
     "t5_quantized": true,
-    "t5_quant_scheme": "fp8"
+    "t5_quant_scheme": "fp8-sgl"
 }
@@ -57,7 +57,7 @@
     "dit_quantized": true,
     "dit_quant_scheme": "fp8-sgl",
     "adapter_quantized": true,
-    "adapter_quant_scheme": "fp8",
+    "adapter_quant_scheme": "fp8-sgl",
     "t5_quantized": true,
-    "t5_quant_scheme": "fp8"
+    "t5_quant_scheme": "fp8-sgl"
 }
@@ -20,9 +20,9 @@
     "dit_quantized": true,
     "dit_quant_scheme": "fp8-sgl",
     "adapter_quantized": true,
-    "adapter_quant_scheme": "fp8",
+    "adapter_quant_scheme": "fp8-sgl",
     "t5_quantized": true,
-    "t5_quant_scheme": "fp8",
+    "t5_quant_scheme": "fp8-sgl",
     "compile": true,
     "compile_shapes": [
         [
...
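Across all of these configs the scheme string follows a <dtype>[-<backend>] pattern: "fp8-sgl" means fp8 weights with the sgl kernels, "fp8-q8f" the q8f kernels, while bare "fp8"/"int8" survive as legacy aliases. A minimal sketch of splitting such a string, assuming only this naming convention (the helper name parse_quant_scheme is illustrative, not from the repo):

def parse_quant_scheme(scheme: str):
    """Split a scheme like "fp8-sgl" into ("fp8", "sgl").

    A bare dtype such as "fp8" yields ("fp8", None), mirroring the
    legacy names that the dispatch code below still accepts.
    """
    dtype, _, backend = scheme.partition("-")
    return dtype, backend or None

assert parse_quant_scheme("fp8-sgl") == ("fp8", "sgl")
assert parse_quant_scheme("int8") == ("int8", None)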
@@ -84,9 +84,9 @@ class T5Attention(nn.Module):
         self.head_dim = dim_attn // num_heads
         if quantized:
-            if quant_scheme == "int8":
+            if quant_scheme in ["int8", "int8-vllm"]:
                 linear_cls = VllmQuantLinearInt8
-            elif quant_scheme == "fp8":
+            elif quant_scheme in ["fp8", "fp8-sgl"]:
                 linear_cls = SglQuantLinearFp8
             elif quant_scheme == "int8-torchao":
                 linear_cls = TorchaoQuantLinearInt8
@@ -94,6 +94,8 @@ class T5Attention(nn.Module):
                 linear_cls = Q8FQuantLinearInt8
             elif quant_scheme == "fp8-q8f":
                 linear_cls = Q8FQuantLinearFp8
+            else:
+                raise NotImplementedError(f"Unsupported T5 quant scheme: {quant_scheme}")
         else:
             linear_cls = nn.Linear
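The same if/elif ladder now appears in four places (T5Attention, T5FeedForward, SelfAttention, AttentionBlock). An equivalent table-driven sketch, using the linear classes named in this diff; the dict, the helper, and the "int8-q8f" pairing for Q8FQuantLinearInt8 (its elif is truncated out of the hunks above) are illustrative assumptions, not how the repo actually factors it:

import torch.nn as nn
# VllmQuantLinearInt8, SglQuantLinearFp8, TorchaoQuantLinearInt8,
# Q8FQuantLinearInt8, Q8FQuantLinearFp8 are the classes imported in these files.

# Maps each accepted scheme string to its quantized linear class; legacy
# names ("int8", "fp8") alias the vllm/sgl-backed defaults.
QUANT_LINEAR_CLS = {
    "int8": VllmQuantLinearInt8,
    "int8-vllm": VllmQuantLinearInt8,
    "fp8": SglQuantLinearFp8,
    "fp8-sgl": SglQuantLinearFp8,
    "int8-torchao": TorchaoQuantLinearInt8,
    "int8-q8f": Q8FQuantLinearInt8,  # assumed pairing, elif not shown in diff
    "fp8-q8f": Q8FQuantLinearFp8,
}

def pick_linear_cls(quantized: bool, quant_scheme: str):
    # Unquantized modules keep the plain nn.Linear, as in the else branches.
    if not quantized:
        return nn.Linear
    try:
        return QUANT_LINEAR_CLS[quant_scheme]
    except KeyError:
        raise NotImplementedError(f"Unsupported quant scheme: {quant_scheme}")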
@@ -151,9 +153,9 @@ class T5FeedForward(nn.Module):
         self.dim_ffn = dim_ffn
         if quantized:
-            if quant_scheme == "int8":
+            if quant_scheme in ["int8", "int8-vllm"]:
                 linear_cls = VllmQuantLinearInt8
-            elif quant_scheme == "fp8":
+            elif quant_scheme in ["fp8", "fp8-sgl"]:
                 linear_cls = SglQuantLinearFp8
             elif quant_scheme == "int8-torchao":
                 linear_cls = TorchaoQuantLinearInt8
@@ -161,6 +163,8 @@ class T5FeedForward(nn.Module):
                 linear_cls = Q8FQuantLinearInt8
             elif quant_scheme == "fp8-q8f":
                 linear_cls = Q8FQuantLinearFp8
+            else:
+                raise NotImplementedError(f"Unsupported T5 quant scheme: {quant_scheme}")
         else:
             linear_cls = nn.Linear
         # layers
...
@@ -59,9 +59,9 @@ class SelfAttention(nn.Module):
         # layers
         if quantized:
-            if quant_scheme == "int8":
+            if quant_scheme in ["int8", "int8-vllm"]:
                 linear_cls = VllmQuantLinearInt8
-            elif quant_scheme == "fp8":
+            elif quant_scheme in ["fp8", "fp8-sgl"]:
                 linear_cls = SglQuantLinearFp8
             elif quant_scheme == "int8-torchao":
                 linear_cls = TorchaoQuantLinearInt8
@@ -69,6 +69,8 @@ class SelfAttention(nn.Module):
                 linear_cls = Q8FQuantLinearInt8
             elif quant_scheme == "fp8-q8f":
                 linear_cls = Q8FQuantLinearFp8
+            else:
+                raise NotImplementedError(f"Unsupported CLIP quant scheme: {quant_scheme}")
         else:
             linear_cls = nn.Linear
@@ -137,9 +139,9 @@ class AttentionBlock(nn.Module):
         # layers
         if quantized:
-            if quant_scheme == "int8":
+            if quant_scheme in ["int8", "int8-vllm"]:
                 linear_cls = VllmQuantLinearInt8
-            elif quant_scheme == "fp8":
+            elif quant_scheme in ["fp8", "fp8-sgl"]:
                 linear_cls = SglQuantLinearFp8
             elif quant_scheme == "int8-torchao":
                 linear_cls = TorchaoQuantLinearInt8
@@ -147,6 +149,8 @@ class AttentionBlock(nn.Module):
                 linear_cls = Q8FQuantLinearInt8
             elif quant_scheme == "fp8-q8f":
                 linear_cls = Q8FQuantLinearFp8
+            else:
+                raise NotImplementedError(f"Unsupported CLIP quant scheme: {quant_scheme}")
         else:
             linear_cls = nn.Linear
...
@@ -28,9 +28,9 @@ class WanAudioModel(WanModel):
     def _load_adapter_ckpt(self):
         if self.config.get("adapter_model_path", None) is None:
             if self.config.get("adapter_quantized", False):
-                if self.config.get("adapter_quant_scheme", None) in ["fp8", "fp8-q8f"]:
+                if self.config.get("adapter_quant_scheme", None) in ["fp8", "fp8-q8f", "fp8-vllm", "fp8-sgl"]:
                     adapter_model_name = "audio_adapter_model_fp8.safetensors"
-                elif self.config.get("adapter_quant_scheme", None) == "int8":
+                elif self.config.get("adapter_quant_scheme", None) in ["int8", "int8-q8f", "int8-vllm", "int8-sgl"]:
                     adapter_model_name = "audio_adapter_model_int8.safetensors"
                 else:
                     raise ValueError(f"Unsupported quant_scheme: {self.config.get('adapter_quant_scheme', None)}")
...
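Since every fp8-* scheme maps to the same checkpoint, and likewise for int8-*, the membership tests above amount to keying on the dtype prefix. A minimal sketch of that equivalence, assuming only the filenames shown in this hunk (the helper name default_adapter_ckpt is illustrative, not from the repo):

def default_adapter_ckpt(quant_scheme: str) -> str:
    # "fp8", "fp8-q8f", "fp8-vllm", "fp8-sgl" all select the fp8 checkpoint,
    # and the int8 variants the int8 one; anything else is rejected.
    dtype = quant_scheme.partition("-")[0]
    if dtype not in ("fp8", "int8"):
        raise ValueError(f"Unsupported quant_scheme: {quant_scheme}")
    return f"audio_adapter_model_{dtype}.safetensors"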