Unverified Commit d0a5c78d authored by Bilang ZHANG, committed by GitHub

convert: adapt to seko_talk (#418)

parent 503c3abc
@@ -32,6 +32,12 @@ class WanAudioModel(WanModel):
adapter_model_name = "audio_adapter_model_fp8.safetensors"
elif self.config.get("adapter_quant_scheme", None) in ["int8", "int8-q8f", "int8-vllm", "int8-sgl"]:
adapter_model_name = "audio_adapter_model_int8.safetensors"
elif self.config.get("adapter_quant_scheme", None) in ["mxfp4"]:
adapter_model_name = "audio_adapter_model_mxfp4.safetensors"
elif self.config.get("adapter_quant_scheme", None) in ["mxfp6", "mxfp6-mxfp8"]:
adapter_model_name = "audio_adapter_model_mxfp6.safetensors"
elif self.config.get("adapter_quant_scheme", None) in ["mxfp8"]:
adapter_model_name = "audio_adapter_model_mxfp8.safetensors"
else:
raise ValueError(f"Unsupported quant_scheme: {self.config.get('adapter_quant_scheme', None)}")
else:
......
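For reference, the new branches reduce to a simple scheme-to-filename lookup. The sketch below is illustrative only and not part of this change: the table and helper name are hypothetical, and it covers the int8 family plus the mx-format schemes visible in the hunk (the fp8 branch's condition is elided above, so it is omitted here).

# Illustrative sketch (not part of this commit): the same mapping as the elif
# chain above, written as a lookup table. Scheme strings and filenames are taken
# from the hunk; the dict and helper name are hypothetical.
_ADAPTER_CKPT_BY_SCHEME = {
    "int8": "audio_adapter_model_int8.safetensors",
    "int8-q8f": "audio_adapter_model_int8.safetensors",
    "int8-vllm": "audio_adapter_model_int8.safetensors",
    "int8-sgl": "audio_adapter_model_int8.safetensors",
    "mxfp4": "audio_adapter_model_mxfp4.safetensors",
    "mxfp6": "audio_adapter_model_mxfp6.safetensors",
    "mxfp6-mxfp8": "audio_adapter_model_mxfp6.safetensors",
    "mxfp8": "audio_adapter_model_mxfp8.safetensors",
}

def adapter_checkpoint_for(scheme):
    # Unknown schemes raise, mirroring the ValueError in the diff.
    try:
        return _ADAPTER_CKPT_BY_SCHEME[scheme]
    except KeyError:
        raise ValueError(f"Unsupported quant_scheme: {scheme}")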
@@ -16,14 +16,15 @@ try:
from lora_loader import LoRALoader
except ImportError:
pass
import sys
from pathlib import Path
from safetensors import safe_open
from safetensors import torch as st
from tqdm import tqdm
try:
from lightx2v.utils.registry_factory import CONVERT_WEIGHT_REGISTER
except ImportError:
pass
sys.path.append(str(Path(__file__).parent.parent.parent))
from lightx2v.utils.registry_factory import CONVERT_WEIGHT_REGISTER
from tools.convert.quant import *
......
@@ -3,8 +3,9 @@ from abc import ABCMeta
import torch
from qtorch.quant import float_quantize
from lightx2v.utils.registry_factory import CONVERT_WEIGHT_REGISTER
try:
from lightx2v.utils.registry_factory import CONVERT_WEIGHT_REGISTER
from lightx2v_kernel.gemm import scaled_mxfp4_quant, scaled_mxfp6_quant, scaled_mxfp8_quant, scaled_nvfp4_quant
except ImportError:
pass
......
@@ -9,6 +9,7 @@ from safetensors.torch import save_file
sys.path.append(str(Path(__file__).parent.parent.parent))
from lightx2v.utils.quant_utils import FloatQuantizer
from tools.convert.quant import *
def main():
@@ -47,12 +48,17 @@ def main():
if key.startswith("ca") and ".to" in key and "weight" in key:
print(f"Converting {key} to FP8, dtype: {state_dict[key].dtype}")
weight = state_dict[key].to(torch.float32).cuda()
## fp8
w_quantizer = FloatQuantizer("e4m3", True, "per_channel")
weight, weight_scale, _ = w_quantizer.real_quant_tensor(weight)
weight = weight.to(torch.float8_e4m3fn)
weight_scale = weight_scale.to(torch.float32)
## Use QuantWeightMxFP4 / QuantWeightMxFP6 / QuantWeightMxFP8 for mxfp4, mxfp6, mxfp8
# weight = state_dict[key].to(torch.bfloat16).cuda()
# quantizer = QuantWeightMxFP4(weight)
# weight, weight_scale, _ = quantizer.weight_quant_func(weight)
new_state_dict[key] = weight.cpu()
new_state_dict[key + "_scale"] = weight_scale.cpu()
else:
......
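The fp8 path and the commented-out mx-format path above differ only in which quantizer is applied to the cross-attention weight. Below is a minimal sketch of that switch, assuming the classes and call signatures shown in the diff (FloatQuantizer.real_quant_tensor, QuantWeightMxFP4.weight_quant_func); the helper name and the scheme argument are hypothetical.

import torch
from lightx2v.utils.quant_utils import FloatQuantizer  # as imported by the script
from tools.convert.quant import *  # provides QuantWeightMxFP4 per the diff

# Minimal sketch, not part of the commit: dispatch between the fp8 path used in
# main() and the commented-out mxfp4 path.
def quantize_ca_weight(weight, scheme="fp8"):
    if scheme == "fp8":
        w = weight.to(torch.float32).cuda()
        w_quantizer = FloatQuantizer("e4m3", True, "per_channel")
        qweight, scale, _ = w_quantizer.real_quant_tensor(w)
        return qweight.to(torch.float8_e4m3fn).cpu(), scale.to(torch.float32).cpu()
    if scheme == "mxfp4":
        w = weight.to(torch.bfloat16).cuda()
        quantizer = QuantWeightMxFP4(w)
        qweight, scale, _ = quantizer.weight_quant_func(w)
        return qweight.cpu(), scale.cpu()
    raise ValueError(f"Unsupported quant_scheme: {scheme}")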