Commit 8ae9e71d authored by gushiqiao, committed by GitHub

Merge pull request #95 from ModelTC/dev_FIX

Fix bugs
parents b4496e64 01036b01
@@ -111,6 +111,12 @@ def is_fp8_supported_gpu():
     return (major == 8 and minor == 9) or (major >= 9)

+def update_precision_mode(dit_quant_scheme):
+    if dit_quant_scheme != "bf16":
+        return "bf16"
+    return "fp32"
+
 global_runner = None
 current_config = None
@@ -261,19 +267,15 @@ def run_inference(
     is_dit_quant = dit_quant_scheme != "bf16"
     is_t5_quant = t5_quant_scheme != "bf16"
     if is_t5_quant:
-        if t5_quant_scheme == "int8":
-            t5_quant_ckpt = os.path.join(model_path, "models_t5_umt5-xxl-enc-int8.pth")
-        else:
-            t5_quant_ckpt = os.path.join(model_path, "models_t5_umt5-xxl-enc-fp8.pth")
+        t5_path = os.path.join(model_path, t5_quant_scheme)
+        t5_quant_ckpt = os.path.join(t5_path, f"models_t5_umt5-xxl-enc-{t5_quant_scheme}.pth")
     else:
         t5_quant_ckpt = None

     is_clip_quant = clip_quant_scheme != "fp16"
     if is_clip_quant:
-        if clip_quant_scheme == "int8":
-            clip_quant_ckpt = os.path.join(model_path, "clip-int8.pth")
-        else:
-            clip_quant_ckpt = os.path.join(model_path, "clip-fp8.pth")
+        clip_path = os.path.join(model_path, clip_quant_scheme)
+        clip_quant_ckpt = os.path.join(clip_path, f"clip-{clip_quant_scheme}.pth")
     else:
         clip_quant_ckpt = None
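The rewritten branches above derive the encoder checkpoint paths from the selected quantization scheme instead of hard-coding separate int8/fp8 names. A minimal sketch of the resulting layout, assuming a hypothetical `model_path` and the scheme values offered by the demo:

```python
import os

# Hypothetical model root; subdirectory names follow the scheme value ("int8" or "fp8").
model_path = "/path/to/Wan2.1-I2V-14B-480P"

for scheme in ("int8", "fp8"):
    t5_ckpt = os.path.join(model_path, scheme, f"models_t5_umt5-xxl-enc-{scheme}.pth")
    clip_ckpt = os.path.join(model_path, scheme, f"clip-{scheme}.pth")
    print(t5_ckpt)   # e.g. /path/to/Wan2.1-I2V-14B-480P/int8/models_t5_umt5-xxl-enc-int8.pth
    print(clip_ckpt) # e.g. /path/to/Wan2.1-I2V-14B-480P/int8/clip-int8.pth
```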
@@ -297,16 +299,22 @@ def run_inference(
         mm_type = f"W-{dit_quant_scheme}-channel-sym-A-{dit_quant_scheme}-channel-sym-dynamic-Q8F"
         dit_quantized_ckpt = os.path.join(model_path, dit_quant_scheme)
+        if os.path.exists(os.path.join(dit_quantized_ckpt, "config.json")):
+            with open(os.path.join(dit_quantized_ckpt, "config.json"), "r") as f:
+                quant_model_config = json.load(f)
     else:
         mm_type = "Default"
         dit_quantized_ckpt = None
+        quant_model_config = {}

     config = {
         "infer_steps": infer_steps,
         "target_video_length": num_frames,
         "target_width": int(resolution.split("x")[0]),
         "target_height": int(resolution.split("x")[1]),
-        "attention_type": attention_type,
+        "self_attn_1_type": attention_type,
+        "cross_attn_1_type": attention_type,
+        "cross_attn_2_type": attention_type,
         "seed": seed,
         "enable_cfg": enable_cfg,
         "sample_guide_scale": cfg_scale,
@@ -364,6 +372,7 @@ def run_inference(
     config = EasyDict(config)
     config["mode"] = "infer"
     config.update(model_config)
+    config.update(quant_model_config)

     logger.info(f"Using model: {model_path}")
     logger.info(f"Inference configuration:\n{json.dumps(config, indent=4, ensure_ascii=False)}")
@@ -588,13 +597,7 @@ def auto_configure(enable_auto_config, model_type, resolution):
         (32, {"dit_quant_scheme_val": quant_type, "lazy_load_val": True}),
         (
             16,
-            {
-                "dit_quant_scheme_val": quant_type,
-                "t5_quant_scheme_val": quant_type,
-                "clip_quant_scheme_val": quant_type,
-                "lazy_load_val": True,
-                "dit_quant_scheme_val": quant_type,
-            },
+            {"dit_quant_scheme_val": quant_type, "t5_quant_scheme_val": quant_type, "clip_quant_scheme_val": quant_type, "lazy_load_val": True},
         ),
     ]
@@ -784,7 +787,7 @@ def main():
                     elem_classes=["output-video"],
                 )
                 infer_btn = gr.Button("Generate Video", variant="primary", size="lg")

             with gr.Tab("⚙️ Advanced Options", id=2):
                 with gr.Group(elem_classes="advanced-options"):
@@ -894,10 +897,10 @@ def main():
                         info="Quantization precision for the Clip Encoder",
                     )
                     precision_mode = gr.Dropdown(
-                        label="Precision Mode",
+                        label="Precision Mode for Sensitive Layers",
                         choices=["fp32", "bf16"],
                         value="fp32",
-                        info="Select the numerical precision used for sensitive layers.",
+                        info="Select the numerical precision for critical model components like normalization and embedding layers. FP32 offers higher accuracy, while BF16 improves performance on compatible hardware.",
                     )
                 gr.Markdown("### Variational Autoencoder (VAE)")
@@ -960,6 +963,8 @@ def main():
             ],
         )

+        dit_quant_scheme.change(fn=update_precision_mode, inputs=[dit_quant_scheme], outputs=[precision_mode])
+
         infer_btn.click(
             fn=run_inference,
             inputs=[
......
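The `update_precision_mode` helper added at the top of the file is wired to the DiT quantization dropdown via `dit_quant_scheme.change(...)`, so picking a quantized scheme flips the sensitive-layer precision to bf16 automatically. A stripped-down sketch of the same pattern, assuming only that `gradio` is installed:

```python
import gradio as gr

def update_precision_mode(dit_quant_scheme):
    # Quantized DiT weights pair with bf16 sensitive layers; plain bf16 keeps fp32.
    if dit_quant_scheme != "bf16":
        return "bf16"
    return "fp32"

with gr.Blocks() as demo:
    dit_quant_scheme = gr.Dropdown(choices=["bf16", "int8", "fp8"], value="bf16", label="DiT Quantization")
    precision_mode = gr.Dropdown(choices=["fp32", "bf16"], value="fp32", label="Precision Mode for Sensitive Layers")
    # Whenever the quantization scheme changes, the precision dropdown is updated.
    dit_quant_scheme.change(fn=update_precision_mode, inputs=[dit_quant_scheme], outputs=[precision_mode])

# demo.launch()  # uncomment to run this sketch standalone
```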
@@ -112,6 +112,12 @@ def is_fp8_supported_gpu():
     return (major == 8 and minor == 9) or (major >= 9)

+def update_precision_mode(dit_quant_scheme):
+    if dit_quant_scheme != "bf16":
+        return "bf16"
+    return "fp32"
+
 global_runner = None
 current_config = None
@@ -262,19 +268,15 @@ def run_inference(
     is_dit_quant = dit_quant_scheme != "bf16"
     is_t5_quant = t5_quant_scheme != "bf16"
     if is_t5_quant:
-        if t5_quant_scheme == "int8":
-            t5_quant_ckpt = os.path.join(model_path, "models_t5_umt5-xxl-enc-int8.pth")
-        else:
-            t5_quant_ckpt = os.path.join(model_path, "models_t5_umt5-xxl-enc-fp8.pth")
+        t5_path = os.path.join(model_path, t5_quant_scheme)
+        t5_quant_ckpt = os.path.join(t5_path, f"models_t5_umt5-xxl-enc-{t5_quant_scheme}.pth")
     else:
         t5_quant_ckpt = None

     is_clip_quant = clip_quant_scheme != "fp16"
     if is_clip_quant:
-        if clip_quant_scheme == "int8":
-            clip_quant_ckpt = os.path.join(model_path, "clip-int8.pth")
-        else:
-            clip_quant_ckpt = os.path.join(model_path, "clip-fp8.pth")
+        clip_path = os.path.join(model_path, clip_quant_scheme)
+        clip_quant_ckpt = os.path.join(clip_path, f"clip-{clip_quant_scheme}.pth")
     else:
         clip_quant_ckpt = None
@@ -298,16 +300,22 @@ def run_inference(
         mm_type = f"W-{dit_quant_scheme}-channel-sym-A-{dit_quant_scheme}-channel-sym-dynamic-Q8F"
         dit_quantized_ckpt = os.path.join(model_path, dit_quant_scheme)
+        if os.path.exists(os.path.join(dit_quantized_ckpt, "config.json")):
+            with open(os.path.join(dit_quantized_ckpt, "config.json"), "r") as f:
+                quant_model_config = json.load(f)
     else:
         mm_type = "Default"
         dit_quantized_ckpt = None
+        quant_model_config = {}

     config = {
         "infer_steps": infer_steps,
         "target_video_length": num_frames,
         "target_width": int(resolution.split("x")[0]),
         "target_height": int(resolution.split("x")[1]),
-        "attention_type": attention_type,
+        "self_attn_1_type": attention_type,
+        "cross_attn_1_type": attention_type,
+        "cross_attn_2_type": attention_type,
         "seed": seed,
         "enable_cfg": enable_cfg,
         "sample_guide_scale": cfg_scale,
@@ -365,6 +373,7 @@ def run_inference(
     config = EasyDict(config)
     config["mode"] = "infer"
     config.update(model_config)
+    config.update(quant_model_config)

     logger.info(f"使用模型: {model_path}")
     logger.info(f"推理配置:\n{json.dumps(config, indent=4, ensure_ascii=False)}")
@@ -588,13 +597,7 @@ def auto_configure(enable_auto_config, model_type, resolution):
         (32, {"dit_quant_scheme_val": quant_type, "lazy_load_val": True}),
         (
             16,
-            {
-                "dit_quant_scheme_val": quant_type,
-                "t5_quant_scheme_val": quant_type,
-                "clip_quant_scheme_val": quant_type,
-                "lazy_load_val": True,
-                "dit_quant_scheme_val": quant_type,
-            },
+            {"dit_quant_scheme_val": quant_type, "t5_quant_scheme_val": quant_type, "clip_quant_scheme_val": quant_type, "lazy_load_val": True},
         ),
     ]
@@ -784,7 +787,7 @@ def main():
                     elem_classes=["output-video"],
                 )
                 infer_btn = gr.Button("生成视频", variant="primary", size="lg")

             with gr.Tab("⚙️ 高级选项", id=2):
                 with gr.Group(elem_classes="advanced-options"):
@@ -894,10 +897,10 @@ def main():
                         info="Clip编码器的推理精度",
                     )
                     precision_mode = gr.Dropdown(
-                        label="精度模式",
+                        label="敏感层精度",
                         choices=["fp32", "bf16"],
                         value="fp32",
-                        info="部分敏感层的推理精度",
+                        info="选择用于敏感层(如norm层和embedding层)的数值精度",
                     )
                 gr.Markdown("### 变分自编码器(VAE)")
@@ -960,6 +963,8 @@ def main():
             ],
         )

+        dit_quant_scheme.change(fn=update_precision_mode, inputs=[dit_quant_scheme], outputs=[precision_mode])
+
         infer_btn.click(
             fn=run_inference,
             inputs=[
......
 #!/bin/bash
-lightx2v_path=/path/to/lightx2v
-model_path=/path/to/wan
-export CUDA_VISIBLE_DEVICES=0
+lightx2v_path=/mtc/gushiqiao/llmc_workspace/lightx2v_new/lightx2v
+model_path=/data/nvme0/gushiqiao/models/Wan2.1-I2V-14B-480P-Lightx2v
+export CUDA_VISIBLE_DEVICES=7
 export CUDA_LAUNCH_BLOCKING=1
 export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
 export ENABLE_PROFILING_DEBUG=true
......
{
    "infer_steps": 50,
    "target_video_length": 81,
    "text_len": 512,
    "target_height": 480,
    "target_width": 832,
    "self_attn_1_type": "flash_attn3",
    "cross_attn_1_type": "flash_attn3",
    "cross_attn_2_type": "flash_attn3",
    "seed": 42,
    "sample_guide_scale": 5,
    "sample_shift": 5,
    "enable_cfg": true,
    "cpu_offload": false,
    "mm_config": {
        "mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm",
        "quant_method": "smoothquant"
    },
    "dit_quantized_ckpt": "/path/to/dit_int8"
}
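The `mm_type` string in this new config encodes the weight and activation quantization per matmul kernel. A hedged sketch that splits the name purely for illustration, assuming only the naming convention visible in this config and in the demo's f-string (not a parser that exists in the repo):

```python
# Hypothetical parser for the mm_type naming convention, e.g.
# "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm".
def parse_mm_type(mm_type: str) -> dict:
    parts = mm_type.split("-")
    return {
        "weight_dtype": parts[1],        # "int8" or "fp8"
        "weight_granularity": parts[2],  # "channel"
        "act_dtype": parts[5],           # activation dtype
        "backend": parts[-1],            # "Vllm", "Q8F", ...
    }

print(parse_mm_type("W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm"))
```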
 import torch
-import flashinfer
+try:
+    import flashinfer
+except ImportError:
+    flashinfer = None

 ###
 ### Code from radial-attention
......
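Making `flashinfer` an optional dependency means downstream call sites need a guard before using it. A minimal sketch of the guard pattern, with `rope_backend_available` as a hypothetical helper, not a function from this repo:

```python
try:
    import flashinfer
except ImportError:
    flashinfer = None

def rope_backend_available() -> bool:
    # Guard used before any flashinfer-backed code path; callers can fall back
    # to a pure-PyTorch implementation when the optional dependency is missing.
    return flashinfer is not None

if not rope_backend_available():
    print("flashinfer not installed; falling back to the default attention path")
```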
@@ -131,7 +131,18 @@ class SageAttn2Weight(AttnWeightTemplate):
     def __init__(self):
         self.config = {}

-    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None):
+    def apply(
+        self,
+        q,
+        k,
+        v,
+        cu_seqlens_q=None,
+        cu_seqlens_kv=None,
+        max_seqlen_q=None,
+        max_seqlen_kv=None,
+        model_cls=None,
+        mask_map=None,
+    ):
         q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
         if model_cls == "hunyuan":
             x1 = sageattn(
......
@@ -120,7 +120,7 @@ class WanModel:
     def _load_quant_split_ckpt(self, use_bf16, skip_bf16):
         lazy_load_model_path = self.config.dit_quantized_ckpt
         logger.info(f"Loading splited quant model from {lazy_load_model_path}")
-        pre_post_weight_dict, transformer_weight_dict = {}, {}
+        pre_post_weight_dict = {}
         safetensor_path = os.path.join(lazy_load_model_path, "non_block.safetensors")
         with safe_open(safetensor_path, framework="pt", device="cpu") as f:
......
@@ -32,7 +32,7 @@ class WanTransformerAttentionBlock(WeightModule):
         self.mm_type = mm_type
         self.task = task
         self.config = config
-        self.quant_method = config["mm_config"].get("quant_method", None)
+        self.quant_method = config.get("quant_method", None)
         self.sparge = config.get("sparge", False)
         self.lazy_load = self.config.get("lazy_load", False)
@@ -89,7 +89,7 @@ class WanModulation(WeightModule):
         self.mm_type = mm_type
         self.task = task
         self.config = config
-        self.quant_method = config["mm_config"].get("quant_method", None)
+        self.quant_method = config.get("quant_method", None)
         self.sparge = config.get("sparge", False)
         self.lazy_load = lazy_load
@@ -112,7 +112,7 @@ class WanSelfAttention(WeightModule):
         self.mm_type = mm_type
         self.task = task
         self.config = config
-        self.quant_method = config["mm_config"].get("quant_method", None)
+        self.quant_method = config.get("quant_method", None)
         self.sparge = config.get("sparge", False)
         self.lazy_load = lazy_load
@@ -185,7 +185,7 @@ class WanSelfAttention(WeightModule):
             self.self_attn_1.load(sparge_ckpt)
         else:
             self.add_module("self_attn_1", ATTN_WEIGHT_REGISTER[self.config["self_attn_1_type"]]())
-        if self.quant_method in ["smoothquant", "awq"]:
+        if self.quant_method in ["advanced_ptq"]:
             self.add_module(
                 "smooth_norm1_weight",
                 TENSOR_REGISTER["Default"](
@@ -314,7 +314,7 @@ class WanFFN(WeightModule):
         self.mm_type = mm_type
         self.task = task
         self.config = config
-        self.quant_method = config["mm_config"].get("quant_method", None)
+        self.quant_method = config.get("quant_method", None)
         self.lazy_load = lazy_load
         self.lazy_load_file = lazy_load_file
@@ -342,7 +342,7 @@ class WanFFN(WeightModule):
             ),
         )
-        if self.quant_method in ["smoothquant", "awq"]:
+        if self.quant_method in ["advanced_ptq"]:
             self.add_module(
                 "smooth_norm2_weight",
                 TENSOR_REGISTER["Default"](
......
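These hunks change where the weight modules look for `quant_method`: it is now read from the top level of the config rather than from `config["mm_config"]`, and the smooth-scale branch keys off an `advanced_ptq` value instead of `smoothquant`/`awq`. A small sketch of the lookup difference, with placeholder dicts rather than the real runtime config:

```python
# Old layout: quant_method nested under mm_config.
old_config = {"mm_config": {"mm_type": "W-int8-...-Vllm", "quant_method": "smoothquant"}}
# New layout assumed by the updated modules: quant_method at the top level.
new_config = {"mm_config": {"mm_type": "W-int8-...-Vllm"}, "quant_method": "advanced_ptq"}

# Before: config["mm_config"].get("quant_method", None)
print(old_config["mm_config"].get("quant_method", None))  # "smoothquant"
# After: config.get("quant_method", None)
print(new_config.get("quant_method", None))               # "advanced_ptq"

# The smooth-scale tensors are only registered for the advanced PTQ path.
if new_config.get("quant_method", None) in ["advanced_ptq"]:
    print("register smooth_norm1_weight / smooth_norm2_weight")
```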
@@ -10,7 +10,8 @@ Facilitates mutual conversion between diffusers architecture and lightx2v architecture
 python converter.py \
     --source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P \
     --output /Path/To/Wan2.1-I2V-14B-480P-Diffusers \
-    --direction forward
+    --direction forward \
+    --save_by_block
 ```

 ### Diffusers->Lightx2v
@@ -18,7 +19,8 @@ python converter.py \
 python converter.py \
     --source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P-Diffusers \
     --output /Path/To/Wan2.1-I2V-14B-480P \
-    --direction backward
+    --direction backward \
+    --save_by_block
 ```

@@ -30,31 +32,32 @@ This tool supports converting fp32/fp16/bf16 model weights to INT8/FP8 type.
 ```bash
 python converter.py \
-    --quantized \
     --source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/ \
     --output /Path/To/output \
     --output_ext .safetensors \
     --output_name wan_int8 \
     --dtype torch.int8 \
-    --model_type wan_dit
+    --model_type wan_dit \
+    --quantized \
+    --save_by_block
 ```

 ```bash
 python converter.py \
-    --quantized \
     --source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/ \
     --output /Path/To/output \
     --output_ext .safetensors \
     --output_name wan_fp8 \
     --dtype torch.float8_e4m3fn \
-    --model_type wan_dit
+    --model_type wan_dit \
+    --quantized \
+    --save_by_block
 ```

 ### Wan DiT + LoRA
 ```bash
 python converter.py \
-    --quantized \
     --source /Path/To/Wan-AI/Wan2.1-T2V-14B/ \
     --output /Path/To/output \
     --output_ext .safetensors \
@@ -62,31 +65,33 @@ python converter.py \
     --dtype torch.int8 \
     --model_type wan_dit \
     --lora_path /Path/To/LoRA1/ /Path/To/LoRA2/ \
-    --lora_alpha 1.0 1.0
+    --lora_alpha 1.0 1.0 \
+    --quantized \
+    --save_by_block
 ```

 ### Hunyuan DIT
 ```bash
 python converter.py \
-    --quantized \
     --source /Path/To/hunyuan/lightx2v_format/i2v/ \
     --output /Path/To/output \
     --output_ext .safetensors \
     --output_name hunyuan_int8 \
     --dtype torch.int8 \
-    --model_type hunyuan_dit
+    --model_type hunyuan_dit \
+    --quantized
 ```

 ```bash
 python converter.py \
-    --quantized \
     --source /Path/To/hunyuan/lightx2v_format/i2v/ \
     --output /Path/To/output \
     --output_ext .safetensors \
     --output_name hunyuan_fp8 \
     --dtype torch.float8_e4m3fn \
-    --model_type hunyuan_dit
+    --model_type hunyuan_dit \
+    --quantized
 ```

@@ -94,24 +99,24 @@ python converter.py \
 ```bash
 python converter.py \
-    --quantized \
     --source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/models_t5_umt5-xxl-enc-bf16.pth \
     --output /Path/To/output \
     --output_ext .pth \
     --output_name models_t5_umt5-xxl-enc-int8 \
     --dtype torch.int8 \
-    --model_type wan_t5
+    --model_type wan_t5 \
+    --quantized
 ```

 ```bash
 python converter.py \
-    --quantized \
     --source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/models_t5_umt5-xxl-enc-bf16.pth \
     --output /Path/To/output \
     --output_ext .pth \
     --output_name models_t5_umt5-xxl-enc-fp8 \
     --dtype torch.float8_e4m3fn \
-    --model_type wan_t5
+    --model_type wan_t5 \
+    --quantized
 ```

@@ -120,21 +125,21 @@ python converter.py \
 ```bash
 python converter.py \
     --source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth \
-    --quantized \
     --output /Path/To/output \
     --output_ext .pth \
-    --output_name clip_int8 \
+    --output_name clip-int8 \
     --dtype torch.int8 \
-    --model_type wan_clip
+    --model_type wan_clip \
+    --quantized
 ```

 ```bash
 python converter.py \
     --source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth \
-    --quantized \
     --output /Path/To/output \
     --output_ext .pth \
-    --output_name clip_fp8 \
+    --output_name clip-fp8 \
     --dtype torch.float8_e4m3fn \
-    --model_type wan_clip
+    --model_type wan_clip \
+    --quantized
 ```
@@ -10,7 +10,8 @@
 python converter.py \
     --source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P \
     --output /Path/To/Wan2.1-I2V-14B-480P-Diffusers \
-    --direction forward
+    --direction forward \
+    --save_by_block
 ```

 ### Diffusers->Lightx2v
@@ -18,7 +19,8 @@ python converter.py \
 python converter.py \
     --source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P-Diffusers \
     --output /Path/To/Wan2.1-I2V-14B-480P \
-    --direction backward
+    --direction backward \
+    --save_by_block
 ```

@@ -30,31 +32,32 @@ python converter.py \
 ```bash
 python converter.py \
-    --quantized \
     --source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/ \
     --output /Path/To/output \
     --output_ext .safetensors \
     --output_name wan_int8 \
     --dtype torch.int8 \
-    --model_type wan_dit
+    --model_type wan_dit \
+    --quantized \
+    --save_by_block
 ```

 ```bash
 python converter.py \
-    --quantized \
     --source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/ \
     --output /Path/To/output \
     --output_ext .safetensors \
     --output_name wan_fp8 \
     --dtype torch.float8_e4m3fn \
-    --model_type wan_dit
+    --model_type wan_dit \
+    --quantized \
+    --save_by_block
 ```

 ### Wan DiT + LoRA
 ```bash
 python converter.py \
-    --quantized \
     --source /Path/To/Wan-AI/Wan2.1-T2V-14B/ \
     --output /Path/To/output \
     --output_ext .safetensors \
@@ -62,31 +65,33 @@ python converter.py \
     --dtype torch.int8 \
     --model_type wan_dit \
     --lora_path /Path/To/LoRA1/ /Path/To/LoRA2/ \
-    --lora_alpha 1.0 1.0
+    --lora_alpha 1.0 1.0 \
+    --quantized \
+    --save_by_block
 ```

 ### Hunyuan DIT
 ```bash
 python converter.py \
-    --quantized \
     --source /Path/To/hunyuan/lightx2v_format/i2v/ \
     --output /Path/To/output \
     --output_ext .safetensors \
     --output_name hunyuan_int8 \
     --dtype torch.int8 \
-    --model_type hunyuan_dit
+    --model_type hunyuan_dit \
+    --quantized
 ```

 ```bash
 python converter.py \
-    --quantized \
     --source /Path/To/hunyuan/lightx2v_format/i2v/ \
     --output /Path/To/output \
     --output_ext .safetensors \
     --output_name hunyuan_fp8 \
     --dtype torch.float8_e4m3fn \
-    --model_type hunyuan_dit
+    --model_type hunyuan_dit \
+    --quantized
 ```

@@ -94,24 +99,24 @@ python converter.py \
 ```bash
 python converter.py \
-    --quantized \
     --source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/models_t5_umt5-xxl-enc-bf16.pth \
     --output /Path/To/output \
     --output_ext .pth \
     --output_name models_t5_umt5-xxl-enc-int8 \
     --dtype torch.int8 \
-    --model_type wan_t5
+    --model_type wan_t5 \
+    --quantized
 ```

 ```bash
 python converter.py \
-    --quantized \
     --source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/models_t5_umt5-xxl-enc-bf16.pth \
     --output /Path/To/output \
     --output_ext .pth \
     --output_name models_t5_umt5-xxl-enc-fp8 \
     --dtype torch.float8_e4m3fn \
-    --model_type wan_t5
+    --model_type wan_t5 \
+    --quantized
 ```

@@ -120,21 +125,21 @@ python converter.py \
 ```bash
 python converter.py \
     --source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth \
-    --quantized \
     --output /Path/To/output \
     --output_ext .pth \
-    --output_name clip_int8 \
+    --output_name clip-int8 \
     --dtype torch.int8 \
-    --model_type wan_clip
+    --model_type wan_clip \
+    --quantized
 ```

 ```bash
 python converter.py \
     --source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth \
-    --quantized \
     --output /Path/To/output \
     --output_ext .pth \
-    --output_name clip_fp8 \
+    --output_name clip-fp8 \
     --dtype torch.float8_e4m3fn \
-    --model_type wan_clip
+    --model_type wan_clip \
+    --quantized
 ```