Unverified Commit 772fdb14 authored by gushiqiao, committed by GitHub

update config (#379)

parent 4a9f0df5
@@ -14,11 +14,11 @@
     "cpu_offload": false,
     "use_31_block": false,
     "clip_quantized": true,
-    "clip_quant_scheme": "fp8",
+    "clip_quant_scheme": "fp8-sgl",
     "dit_quantized": true,
     "dit_quant_scheme": "fp8-sgl",
     "adapter_quantized": true,
-    "adapter_quant_scheme": "fp8",
+    "adapter_quant_scheme": "fp8-sgl",
     "t5_quantized": true,
-    "t5_quant_scheme": "fp8"
+    "t5_quant_scheme": "fp8-sgl"
 }
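The rename in these configs is mechanical: every bare "fp8" scheme becomes the backend-qualified "fp8-sgl". As a minimal sketch of what such backend-qualified strings encode (the helper `parse_quant_scheme` and its default-backend table are hypothetical, not from this repo):

```python
# Hypothetical sketch: split a backend-qualified scheme such as "fp8-sgl"
# into (dtype, backend). Bare "fp8"/"int8" fall back to a default backend,
# which is one way to keep the legacy spellings working as aliases.
def parse_quant_scheme(scheme: str) -> tuple[str, str]:
    default_backend = {"fp8": "sgl", "int8": "vllm"}
    dtype, _, backend = scheme.partition("-")
    return dtype, backend or default_backend.get(dtype, "native")

assert parse_quant_scheme("fp8-sgl") == ("fp8", "sgl")
assert parse_quant_scheme("fp8") == ("fp8", "sgl")        # legacy alias
assert parse_quant_scheme("int8-torchao") == ("int8", "torchao")
```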
@@ -20,7 +20,7 @@
     "dit_quantized": true,
     "dit_quant_scheme": "fp8-sgl",
     "adapter_quantized": true,
-    "adapter_quant_scheme": "fp8",
+    "adapter_quant_scheme": "fp8-sgl",
     "t5_quantized": true,
-    "t5_quant_scheme": "fp8"
+    "t5_quant_scheme": "fp8-sgl"
 }
@@ -24,7 +24,7 @@
     "audio_encoder_cpu_offload": false,
     "audio_adapter_cpu_offload": false,
     "adapter_quantized": true,
-    "adapter_quant_scheme": "fp8",
+    "adapter_quant_scheme": "fp8-q8f",
     "vae_cpu_offload": false,
     "use_tiling_vae": false,
     "dit_quantized": true,
......
@@ -18,12 +18,12 @@
     "t5_cpu_offload": true,
     "t5_offload_granularity": "model",
     "t5_quantized": true,
-    "t5_quant_scheme": "fp8",
+    "t5_quant_scheme": "fp8-sgl",
     "clip_cpu_offload": false,
     "audio_encoder_cpu_offload": false,
     "audio_adapter_cpu_offload": false,
     "adapter_quantized": true,
-    "adapter_quant_scheme": "fp8",
+    "adapter_quant_scheme": "fp8-sgl",
     "vae_cpu_offload": false,
     "use_tiling_vae": false,
     "dit_quantized": true,
......
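Because the same `*_quant_scheme` keys recur across all of these JSON configs, a small validation pass can catch a stale spelling before model load. A hedged sketch, using only the scheme strings visible in this commit (the helper `check_quant_schemes` is illustrative, not part of the repo):

```python
import json

# Schemes that appear in this commit's configs and dispatch ladders.
SUPPORTED_SCHEMES = {"int8", "int8-vllm", "int8-torchao", "int8-q8f",
                     "fp8", "fp8-sgl", "fp8-q8f"}

def check_quant_schemes(config: dict) -> None:
    # Every key ending in "_quant_scheme" must hold a known scheme string.
    for key, value in config.items():
        if key.endswith("_quant_scheme") and value not in SUPPORTED_SCHEMES:
            raise ValueError(f"{key}: unsupported quant scheme {value!r}")

check_quant_schemes(json.loads('{"t5_quantized": true, "t5_quant_scheme": "fp8-sgl"}'))
```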
@@ -20,7 +20,7 @@
     "dit_quantized": true,
     "dit_quant_scheme": "fp8-sgl",
     "adapter_quantized": true,
-    "adapter_quant_scheme": "fp8",
+    "adapter_quant_scheme": "fp8-sgl",
     "t5_quantized": true,
-    "t5_quant_scheme": "fp8"
+    "t5_quant_scheme": "fp8-sgl"
 }
@@ -24,7 +24,7 @@
     "dit_quantized": true,
     "dit_quant_scheme": "fp8-sgl",
     "adapter_quantized": true,
-    "adapter_quant_scheme": "fp8",
+    "adapter_quant_scheme": "fp8-sgl",
     "t5_quantized": true,
-    "t5_quant_scheme": "fp8"
+    "t5_quant_scheme": "fp8-sgl"
 }
@@ -25,7 +25,7 @@
     "dit_quantized": true,
     "dit_quant_scheme": "fp8-sgl",
     "adapter_quantized": true,
-    "adapter_quant_scheme": "fp8",
+    "adapter_quant_scheme": "fp8-sgl",
     "t5_quantized": true,
-    "t5_quant_scheme": "fp8"
+    "t5_quant_scheme": "fp8-sgl"
 }
@@ -56,7 +56,7 @@
     "dit_quantized": true,
     "dit_quant_scheme": "fp8-sgl",
     "adapter_quantized": true,
-    "adapter_quant_scheme": "fp8",
+    "adapter_quant_scheme": "fp8-sgl",
     "t5_quantized": true,
-    "t5_quant_scheme": "fp8"
+    "t5_quant_scheme": "fp8-sgl"
 }
@@ -57,7 +57,7 @@
     "dit_quantized": true,
     "dit_quant_scheme": "fp8-sgl",
     "adapter_quantized": true,
-    "adapter_quant_scheme": "fp8",
+    "adapter_quant_scheme": "fp8-sgl",
     "t5_quantized": true,
-    "t5_quant_scheme": "fp8"
+    "t5_quant_scheme": "fp8-sgl"
 }
@@ -20,9 +20,9 @@
     "dit_quantized": true,
     "dit_quant_scheme": "fp8-sgl",
     "adapter_quantized": true,
-    "adapter_quant_scheme": "fp8",
+    "adapter_quant_scheme": "fp8-sgl",
     "t5_quantized": true,
-    "t5_quant_scheme": "fp8",
+    "t5_quant_scheme": "fp8-sgl",
     "compile": true,
     "compile_shapes": [
         [
......
@@ -84,9 +84,9 @@ class T5Attention(nn.Module):
         self.head_dim = dim_attn // num_heads
         if quantized:
-            if quant_scheme == "int8":
+            if quant_scheme in ["int8", "int8-vllm"]:
                 linear_cls = VllmQuantLinearInt8
-            elif quant_scheme == "fp8":
+            elif quant_scheme in ["fp8", "fp8-sgl"]:
                 linear_cls = SglQuantLinearFp8
             elif quant_scheme == "int8-torchao":
                 linear_cls = TorchaoQuantLinearInt8
@@ -94,6 +94,8 @@ class T5Attention(nn.Module):
                 linear_cls = Q8FQuantLinearInt8
             elif quant_scheme == "fp8-q8f":
                 linear_cls = Q8FQuantLinearFp8
+            else:
+                raise NotImplementedError(f"Unsupported T5 quant scheme: {quant_scheme}")
         else:
             linear_cls = nn.Linear
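The same scheme ladder is extended identically in four classes (T5Attention above; T5FeedForward, SelfAttention, and AttentionBlock below). A table-driven form expresses the mapping once; this is a hedged sketch with stand-in classes so it runs standalone, and `QUANT_LINEAR_CLS`/`resolve_linear_cls` are a hypothetical refactoring, not code from this commit:

```python
import torch.nn as nn

# Stand-ins so the sketch is self-contained; in the repo these are the
# imported quantized linear classes named in the diff.
class VllmQuantLinearInt8(nn.Linear): ...
class SglQuantLinearFp8(nn.Linear): ...
class TorchaoQuantLinearInt8(nn.Linear): ...
class Q8FQuantLinearInt8(nn.Linear): ...
class Q8FQuantLinearFp8(nn.Linear): ...

QUANT_LINEAR_CLS = {
    "int8": VllmQuantLinearInt8,       # legacy alias
    "int8-vllm": VllmQuantLinearInt8,
    "fp8": SglQuantLinearFp8,          # legacy alias
    "fp8-sgl": SglQuantLinearFp8,
    "int8-torchao": TorchaoQuantLinearInt8,
    "int8-q8f": Q8FQuantLinearInt8,
    "fp8-q8f": Q8FQuantLinearFp8,
}

def resolve_linear_cls(quantized: bool, quant_scheme: str) -> type:
    # Mirrors the if/elif ladder from the diff as a single lookup.
    if not quantized:
        return nn.Linear
    try:
        return QUANT_LINEAR_CLS[quant_scheme]
    except KeyError:
        raise NotImplementedError(f"Unsupported quant scheme: {quant_scheme}") from None
```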
@@ -151,9 +153,9 @@ class T5FeedForward(nn.Module):
         self.dim_ffn = dim_ffn
         if quantized:
-            if quant_scheme == "int8":
+            if quant_scheme in ["int8", "int8-vllm"]:
                 linear_cls = VllmQuantLinearInt8
-            elif quant_scheme == "fp8":
+            elif quant_scheme in ["fp8", "fp8-sgl"]:
                 linear_cls = SglQuantLinearFp8
             elif quant_scheme == "int8-torchao":
                 linear_cls = TorchaoQuantLinearInt8
@@ -161,6 +163,8 @@ class T5FeedForward(nn.Module):
                 linear_cls = Q8FQuantLinearInt8
             elif quant_scheme == "fp8-q8f":
                 linear_cls = Q8FQuantLinearFp8
+            else:
+                raise NotImplementedError(f"Unsupported T5 quant scheme: {quant_scheme}")
         else:
             linear_cls = nn.Linear
         # layers
......
@@ -59,9 +59,9 @@ class SelfAttention(nn.Module):
         # layers
         if quantized:
-            if quant_scheme == "int8":
+            if quant_scheme in ["int8", "int8-vllm"]:
                 linear_cls = VllmQuantLinearInt8
-            elif quant_scheme == "fp8":
+            elif quant_scheme in ["fp8", "fp8-sgl"]:
                 linear_cls = SglQuantLinearFp8
             elif quant_scheme == "int8-torchao":
                 linear_cls = TorchaoQuantLinearInt8
@@ -69,6 +69,8 @@
                 linear_cls = Q8FQuantLinearInt8
             elif quant_scheme == "fp8-q8f":
                 linear_cls = Q8FQuantLinearFp8
+            else:
+                raise NotImplementedError(f"Unsupported CLIP quant scheme: {quant_scheme}")
         else:
             linear_cls = nn.Linear
@@ -137,9 +139,9 @@ class AttentionBlock(nn.Module):
         # layers
         if quantized:
-            if quant_scheme == "int8":
+            if quant_scheme in ["int8", "int8-vllm"]:
                 linear_cls = VllmQuantLinearInt8
-            elif quant_scheme == "fp8":
+            elif quant_scheme in ["fp8", "fp8-sgl"]:
                 linear_cls = SglQuantLinearFp8
             elif quant_scheme == "int8-torchao":
                 linear_cls = TorchaoQuantLinearInt8
@@ -147,6 +149,8 @@
                 linear_cls = Q8FQuantLinearInt8
             elif quant_scheme == "fp8-q8f":
                 linear_cls = Q8FQuantLinearFp8
+            else:
+                raise NotImplementedError(f"Unsupported CLIP quant scheme: {quant_scheme}")
         else:
             linear_cls = nn.Linear
......
@@ -28,9 +28,9 @@ class WanAudioModel(WanModel):
     def _load_adapter_ckpt(self):
         if self.config.get("adapter_model_path", None) is None:
             if self.config.get("adapter_quantized", False):
-                if self.config.get("adapter_quant_scheme", None) in ["fp8", "fp8-q8f"]:
+                if self.config.get("adapter_quant_scheme", None) in ["fp8", "fp8-q8f", "fp8-vllm", "fp8-sgl"]:
                     adapter_model_name = "audio_adapter_model_fp8.safetensors"
-                elif self.config.get("adapter_quant_scheme", None) == "int8":
+                elif self.config.get("adapter_quant_scheme", None) in ["int8", "int8-q8f", "int8-vllm", "int8-sgl"]:
                     adapter_model_name = "audio_adapter_model_int8.safetensors"
                 else:
                     raise ValueError(f"Unsupported quant_scheme: {self.config.get('adapter_quant_scheme', None)}")
......
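The last hunk widens the scheme lists that pick the prequantized adapter checkpoint: any fp8-family scheme maps to the fp8 file, any int8-family scheme to the int8 file. Condensed into a standalone helper for clarity (a sketch; `pick_adapter_ckpt` is hypothetical, while the filenames and scheme lists come from the diff):

```python
FP8_SCHEMES = {"fp8", "fp8-q8f", "fp8-vllm", "fp8-sgl"}
INT8_SCHEMES = {"int8", "int8-q8f", "int8-vllm", "int8-sgl"}

def pick_adapter_ckpt(quant_scheme: str) -> str:
    # All backends of one dtype share a single prequantized checkpoint.
    if quant_scheme in FP8_SCHEMES:
        return "audio_adapter_model_fp8.safetensors"
    if quant_scheme in INT8_SCHEMES:
        return "audio_adapter_model_int8.safetensors"
    raise ValueError(f"Unsupported quant_scheme: {quant_scheme}")

assert pick_adapter_ckpt("fp8-sgl") == "audio_adapter_model_fp8.safetensors"
```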