Unverified commit e106ff67 authored by Bilang ZHANG, committed by GitHub

support lightx2v-kernel and update convert.py (#413)

parent 3aab9893
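This commit registers the lightx2v-kernel mx/nv floating-point GEMM paths as new MM weight types and extends convert.py so matching checkpoints can be produced. A minimal sketch of how one of the new schemes might be selected at inference time; the exact config file layout and any keys beyond dit_quantized / dit_quant_scheme / do_mm_calib are assumptions, not part of this diff:

# Hypothetical inference-config fragment (only the quantization keys appear in this diff).
config = {
    "dit_quantized": True,
    "dit_quant_scheme": "nvfp4",  # also accepted now: "mxfp4", "mxfp6-mxfp8", "mxfp8"
    "do_mm_calib": False,         # set True for one run to collect activation absmax into calib.pt
}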
@@ -3,9 +3,27 @@ from abc import ABCMeta, abstractmethod
import torch
from lightx2v.utils.envs import *
from lightx2v.utils.global_paras import CALIB
from lightx2v.utils.quant_utils import FloatQuantizer, IntegerQuantizer
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
try:
from lightx2v_kernel.gemm import (
cutlass_scaled_mxfp4_mm,
cutlass_scaled_mxfp6_mxfp8_mm,
cutlass_scaled_mxfp8_mm,
cutlass_scaled_nvfp4_mm,
scaled_mxfp4_quant,
scaled_mxfp6_quant,
scaled_mxfp8_quant,
scaled_nvfp4_quant,
)
except ImportError:
scaled_nvfp4_quant, cutlass_scaled_nvfp4_mm = None, None
scaled_mxfp4_quant, cutlass_scaled_mxfp4_mm = None, None
scaled_mxfp6_quant, cutlass_scaled_mxfp6_mxfp8_mm = None, None
scaled_mxfp8_quant, cutlass_scaled_mxfp8_mm = None, None
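# lightx2v_kernel is optional: when it is not installed, the quant/GEMM symbols above fall back
# to None and the mxfp4 / mxfp6-mxfp8 / mxfp8 / nvfp4 MM weight classes below cannot be used.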
try:
from vllm import _custom_ops as ops
except ImportError:
@@ -267,6 +285,179 @@ class MMWeightQuantTemplate(MMWeightTemplate):
self.bias = None
self.pin_bias = None
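# The load_mxfp4/6/8 helpers below share one pattern: either quantize a bf16 weight on the fly
# (weight_auto_quant) or take the pre-quantized weight and scale from the checkpoint, staging
# CPU-resident tensors in pinned memory so they can later be copied to the GPU asynchronously.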
def load_mxfp4(self, weight_dict):
if self.config.get("weight_auto_quant", False):
device = weight_dict[self.weight_name].device
self.weight = weight_dict[self.weight_name].cuda().to(torch.bfloat16)
self.weight, self.weight_scale = scaled_mxfp4_quant(self.weight)
self.weight, self.weight_scale = self.weight.to(device), self.weight_scale.to(device)
else:
device = weight_dict[self.weight_name].device
if device.type == "cuda":
self.weight = weight_dict[self.weight_name]
self.weight_scale = weight_dict[self.weight_scale_name]
elif device.type == "cpu":
weight_shape = weight_dict[self.weight_name].shape
weight_dtype = weight_dict[self.weight_name].dtype
self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
self.pin_weight.copy_(weight_dict[self.weight_name])
weight_scale_shape = weight_dict[self.weight_scale_name].shape
weight_scale_dtype = weight_dict[self.weight_scale_name].dtype
self.pin_weight_scale = torch.empty(weight_scale_shape, pin_memory=True, dtype=weight_scale_dtype)
self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name])
del weight_dict[self.weight_name]
else:
raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
if self.bias_name is not None:
device = weight_dict[self.bias_name].device
if device.type == "cuda":
self.bias = weight_dict[self.bias_name]
elif device.type == "cpu":
bias_shape = weight_dict[self.bias_name].shape
bias_dtype = weight_dict[self.bias_name].dtype
self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
self.pin_bias.copy_(weight_dict[self.bias_name])
else:
raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
else:
self.bias = None
self.pin_bias = None
def load_mxfp6(self, weight_dict):
if self.config.get("weight_auto_quant", False):
device = weight_dict[self.weight_name].device
self.weight = weight_dict[self.weight_name].cuda().to(torch.bfloat16)
self.weight, self.weight_scale = scaled_mxfp6_quant(self.weight)
self.weight, self.weight_scale = self.weight.to(device), self.weight_scale.to(device)
else:
device = weight_dict[self.weight_name].device
if device.type == "cuda":
self.weight = weight_dict[self.weight_name]
self.weight_scale = weight_dict[self.weight_scale_name]
elif device.type == "cpu":
weight_shape = weight_dict[self.weight_name].shape
weight_dtype = weight_dict[self.weight_name].dtype
self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
self.pin_weight.copy_(weight_dict[self.weight_name])
weight_scale_shape = weight_dict[self.weight_scale_name].shape
weight_scale_dtype = weight_dict[self.weight_scale_name].dtype
self.pin_weight_scale = torch.empty(weight_scale_shape, pin_memory=True, dtype=weight_scale_dtype)
self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name])
del weight_dict[self.weight_name]
else:
raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
if self.bias_name is not None:
device = weight_dict[self.bias_name].device
if device.type == "cuda":
self.bias = weight_dict[self.bias_name]
elif device.type == "cpu":
bias_shape = weight_dict[self.bias_name].shape
bias_dtype = weight_dict[self.bias_name].dtype
self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
self.pin_bias.copy_(weight_dict[self.bias_name])
else:
raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
else:
self.bias = None
self.pin_bias = None
def load_mxfp8(self, weight_dict):
if self.config.get("weight_auto_quant", False):
device = weight_dict[self.weight_name].device
self.weight = weight_dict[self.weight_name].cuda().to(torch.bfloat16)
self.weight, self.weight_scale = scaled_mxfp8_quant(self.weight)
self.weight, self.weight_scale = self.weight.to(device), self.weight_scale.to(device)
else:
device = weight_dict[self.weight_name].device
if device.type == "cuda":
self.weight = weight_dict[self.weight_name]
self.weight_scale = weight_dict[self.weight_scale_name]
elif device.type == "cpu":
weight_shape = weight_dict[self.weight_name].shape
weight_dtype = weight_dict[self.weight_name].dtype
self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
self.pin_weight.copy_(weight_dict[self.weight_name])
weight_scale_shape = weight_dict[self.weight_scale_name].shape
weight_scale_dtype = weight_dict[self.weight_scale_name].dtype
self.pin_weight_scale = torch.empty(weight_scale_shape, pin_memory=True, dtype=weight_scale_dtype)
self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name])
del weight_dict[self.weight_name]
else:
raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
if self.bias_name is not None:
device = weight_dict[self.bias_name].device
if device.type == "cuda":
self.bias = weight_dict[self.bias_name]
elif device.type == "cpu":
bias_shape = weight_dict[self.bias_name].shape
bias_dtype = weight_dict[self.bias_name].dtype
self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
self.pin_bias.copy_(weight_dict[self.bias_name])
else:
raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
else:
self.bias = None
self.pin_bias = None
def load_nvfp4(self, weight_dict):
device = weight_dict[self.weight_name].device
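# 2688 = 448 (float8_e4m3 max) * 6 (fp4 e2m1 max): the per-tensor global scale maps the calibrated
# activation absmax into the nvfp4 range, and alpha = 1 / (input_global_scale * weight_global_scale)
# undoes both global scales on the GEMM output.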
input_absmax = weight_dict[self.weight_name.replace(".weight", ".input_absmax")]
input_global_scale = (2688.0 / input_absmax).to(torch.float32)
weight_global_scale = weight_dict[f"{self.weight_name}_global_scale"]
alpha = 1.0 / (input_global_scale * weight_global_scale)
if device.type == "cuda":
self.weight = weight_dict[self.weight_name]
self.weight_scale = weight_dict[self.weight_scale_name]
self.input_global_scale = input_global_scale
self.alpha = alpha
elif device.type == "cpu":
weight_shape = weight_dict[self.weight_name].shape
weight_dtype = weight_dict[self.weight_name].dtype
self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
self.pin_weight.copy_(weight_dict[self.weight_name])
weight_scale_shape = weight_dict[self.weight_scale_name].shape
weight_scale_dtype = weight_dict[self.weight_scale_name].dtype
self.pin_weight_scale = torch.empty(weight_scale_shape, pin_memory=True, dtype=weight_scale_dtype)
self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name])
input_global_scale_shape = input_global_scale.shape
input_global_scale_dtype = input_global_scale.dtype
self.pin_input_global_scale = torch.empty(input_global_scale_shape, pin_memory=True, dtype=input_global_scale_dtype)
self.pin_input_global_scale.copy_(input_global_scale)
alpha_shape = alpha.shape
alpha_dtype = alpha.dtype
self.pin_alpha = torch.empty(alpha_shape, pin_memory=True, dtype=alpha_dtype)
self.pin_alpha.copy_(alpha)
del weight_dict[self.weight_name]
else:
raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
if self.bias_name is not None:
device = weight_dict[self.bias_name].device
if device.type == "cuda":
self.bias = weight_dict[self.bias_name]
elif device.type == "cpu":
bias_shape = weight_dict[self.bias_name].shape
bias_dtype = weight_dict[self.bias_name].dtype
self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
self.pin_bias.copy_(weight_dict[self.bias_name])
else:
raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
else:
self.bias = None
self.pin_bias = None
def load_fp8_perblock128_sym(self, weight_dict):
if self.config.get("weight_auto_quant", False):
self.weight = weight_dict[self.weight_name]
@@ -325,6 +516,18 @@ class MMWeightQuantTemplate(MMWeightTemplate):
input_tensor_quant, input_tensor_scale, _ = ops.scaled_int8_quant(x, scale=None, azp=None, symmetric=True)
return input_tensor_quant, input_tensor_scale
def act_quant_nvfp4(self, x):
input_tensor_quant, input_tensor_scale = scaled_nvfp4_quant(x, self.input_global_scale)
return input_tensor_quant, input_tensor_scale
def act_quant_mxfp4(self, x):
input_tensor_quant, input_tensor_scale = scaled_mxfp4_quant(x)
return input_tensor_quant, input_tensor_scale
def act_quant_mxfp8(self, x):
input_tensor_quant, input_tensor_scale = scaled_mxfp8_quant(x)
return input_tensor_quant, input_tensor_scale
def act_quant_fp8_perchannelgroup128_sym_deepgemm(self, x):
assert x.dim() == 2 and x.size(1) % 128 == 0
m, n = x.shape
@@ -431,6 +634,170 @@ class MMWeightWint8channelAint8channeldynamicVllm(MMWeightQuantTemplate):
return output_tensor
@MM_WEIGHT_REGISTER("mxfp4")
class MMWeightWmxfp4Amxfp4dynamic(MMWeightQuantTemplate):
"""
Name: W-mxfp4-A-mxfp4-dynamic
Quant MM:
Weight: mxfp4
Act: mxfp4
"""
def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
self.load_func = self.load_mxfp4
self.weight_need_transpose = False
self.act_quant_func = self.act_quant_mxfp4
self.set_alpha()
def set_alpha(self):
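# The mx block scales are handed to the kernel separately (weight_scale / input_tensor_scale),
# so alpha is left as a unit output multiplier here.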
self.alpha = torch.tensor(1.0, dtype=torch.float32)
def apply(self, input_tensor):
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
self.alpha = self.alpha.to(self.weight.device)
output_tensor = cutlass_scaled_mxfp4_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, alpha=self.alpha, bias=self.bias)
return output_tensor
@MM_WEIGHT_REGISTER("mxfp6-mxfp8")
class MMWeightWmxfp6Amxfp8dynamic(MMWeightQuantTemplate):
"""
Name: W-mxfp6-A-mxfp8-dynamic
Quant MM:
Weight: mxfp6
Act: mxfp8
"""
def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
self.load_func = self.load_mxfp6
self.weight_need_transpose = False
self.act_quant_func = self.act_quant_mxfp8
self.set_alpha()
def set_alpha(self):
self.alpha = torch.tensor(1.0, dtype=torch.float32)
def apply(self, input_tensor):
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
self.alpha = self.alpha.to(self.weight.device)
output_tensor = cutlass_scaled_mxfp6_mxfp8_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, alpha=self.alpha, bias=self.bias)
return output_tensor
@MM_WEIGHT_REGISTER("mxfp8")
class MMWeightWmxfp8Amxfp8dynamic(MMWeightQuantTemplate):
"""
Name: W-mxfp8-A-mxfp8-dynamic
Quant MM:
Weight: mxfp8
Act: mxfp8
"""
def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
self.load_func = self.load_mxfp8
self.weight_need_transpose = False
self.act_quant_func = self.act_quant_mxfp8
self.set_alpha()
def set_alpha(self):
self.alpha = torch.tensor(1.0, dtype=torch.float32)
def apply(self, input_tensor):
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
self.alpha = self.alpha.to(self.weight.device)
output_tensor = cutlass_scaled_mxfp8_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, alpha=self.alpha, bias=self.bias)
return output_tensor
@MM_WEIGHT_REGISTER("nvfp4")
class MMWeightWnvfp4Anvfp4dynamic(MMWeightQuantTemplate):
"""
Name: W-nvfp4-A-nvfp4-dynamic
Quant MM:
Weight: nvfp4
Act: nvfp4
"""
def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
self.load_func = self.load_nvfp4
self.weight_need_transpose = False
self.act_quant_func = self.act_quant_nvfp4
def apply(self, input_tensor):
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
output_tensor = cutlass_scaled_nvfp4_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, alpha=self.alpha, bias=self.bias)
return output_tensor
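# to_cuda/to_cpu are overridden so the nvfp4-specific input_global_scale and alpha are offloaded
# and restored together with the pinned weight, scale and bias buffers.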
def to_cuda(self, non_blocking=False):
self.weight = self.pin_weight.cuda(non_blocking=non_blocking)
if hasattr(self, "pin_weight_scale"):
self.weight_scale = self.pin_weight_scale.cuda(non_blocking=non_blocking)
self.input_global_scale = self.pin_input_global_scale.cuda(non_blocking=non_blocking)
self.alpha = self.pin_alpha.cuda(non_blocking=non_blocking)
if hasattr(self, "pin_bias") and self.pin_bias is not None:
self.bias = self.pin_bias.cuda(non_blocking=non_blocking)
def to_cpu(self, non_blocking=False):
if hasattr(self, "pin_weight"):
self.weight = self.pin_weight.copy_(self.weight, non_blocking=non_blocking).cpu()
if hasattr(self, "weight_scale_name"):
self.weight_scale = self.pin_weight_scale.copy_(self.weight_scale, non_blocking=non_blocking).cpu()
self.input_global_scale = self.pin_input_global_scale.copy_(self.input_global_scale, non_blocking=non_blocking).cpu()
self.alpha = self.pin_alpha.copy_(self.alpha, non_blocking=non_blocking).cpu()
if self.bias is not None:
self.bias = self.pin_bias.copy_(self.bias, non_blocking=non_blocking).cpu()
else:
self.weight = self.weight.to("cpu", non_blocking=non_blocking)
if hasattr(self, "weight_scale"):
self.weight_scale = self.weight_scale.to("cpu", non_blocking=non_blocking)
self.input_global_scale = self.input_global_scale.to("cpu", non_blocking=non_blocking)
self.alpha = self.alpha.to("cpu", non_blocking=non_blocking)
if hasattr(self, "bias") and self.bias is not None:
self.bias = self.bias.to("cpu", non_blocking=non_blocking)
@MM_WEIGHT_REGISTER("Calib")
class MMCalibNvfp4(MMWeight):
"""
Name: calib
Calib:
absmax: torch.max(torch.abs(input_tensor))
"""
def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
self.running_absmax = None
self.count = 0
self.decay = 0.9
def apply(self, input_tensor):
shape = (input_tensor.shape[0], self.weight.shape[1])
dtype, device = input_tensor.dtype, input_tensor.device
current_absmax = torch.max(torch.abs(input_tensor)).to("cpu")
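# Track an exponential moving average of the per-layer activation absmax (updated on every other
# call); the runner saves CALIB to calib.pt and load_nvfp4 reads it back as '.input_absmax'.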
if self.count % 2 == 0:
if self.running_absmax is None:
self.running_absmax = current_absmax
else:
self.running_absmax = self.decay * self.running_absmax + (1 - self.decay) * current_absmax
CALIB["absmax"][self.weight_name] = self.running_absmax
self.count = self.count + 1
output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
if self.bias is None:
return torch.mm(input_tensor, self.weight, out=output_tensor)
return torch.addmm(self.bias, input_tensor, self.weight, out=output_tensor)
@MM_WEIGHT_REGISTER("fp8-q8f") @MM_WEIGHT_REGISTER("fp8-q8f")
class MMWeightWfp8channelAfp8channeldynamicQ8F(MMWeightQuantTemplate): class MMWeightWfp8channelAfp8channeldynamicQ8F(MMWeightQuantTemplate):
""" """
......
@@ -60,7 +60,21 @@ class WanModel(CompiledMethodsMixin):
self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
self.dit_quantized = self.config.get("dit_quantized", False)
if self.dit_quantized:
assert self.config.get("dit_quant_scheme", "Default") in [
"Default-Force-FP32",
"fp8-vllm",
"int8-vllm",
"fp8-q8f",
"int8-q8f",
"fp8-b128-deepgemm",
"fp8-sgl",
"int8-sgl",
"int8-torchao",
"nvfp4",
"mxfp4",
"mxfp6-mxfp8",
"mxfp8",
]
self.device = device
self._init_infer_class()
self._init_weights()
@@ -169,6 +183,7 @@ class WanModel(CompiledMethodsMixin):
safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
else:
safetensors_files = [safetensors_path]
safetensors_path = os.path.dirname(safetensors_path)
weight_dict = {}
for safetensor_path in safetensors_files:
@@ -192,6 +207,13 @@ class WanModel(CompiledMethodsMixin):
else:
weight_dict[k] = f.get_tensor(k).to(self.device)
if self.config.get("dit_quant_scheme", "Default") == "nvfp4":
calib_path = os.path.join(safetensors_path, "calib.pt")
logger.info(f"[CALIB] Loaded calibration data from: {calib_path}")
calib_data = torch.load(calib_path, map_location="cpu")
for k, v in calib_data["absmax"].items():
weight_dict[k.replace(".weight", ".input_absmax")] = v.to(self.device)
return weight_dict
def _load_quant_split_ckpt(self, unified_dtype, sensitive_layer):  # Need rewrite
...
@@ -21,6 +21,8 @@ class WanTransformerWeights(WeightModule):
self.mm_type = config.get("dit_quant_scheme", "Default")
if self.mm_type != "Default":
assert config.get("dit_quantized") is True
if config.get("do_mm_calib", False):
self.mm_type = "Calib"
self.blocks = WeightModuleList([WanTransformerAttentionBlock(i, self.task, self.mm_type, self.config) for i in range(self.blocks_num)])
self.add_module("blocks", self.blocks)
...
@@ -11,6 +11,7 @@ from requests.exceptions import RequestException
from lightx2v.server.metrics import monitor_cli
from lightx2v.utils.envs import *
from lightx2v.utils.generate_task_id import generate_task_id
from lightx2v.utils.global_paras import CALIB
from lightx2v.utils.memory_profiler import peak_memory_decorator
from lightx2v.utils.profiler import *
from lightx2v.utils.utils import save_to_video, vae_to_comfyui_image
@@ -176,6 +177,10 @@ class DefaultRunner(BaseRunner):
self.model.transformer_weights.clear()
self.model.pre_weight.clear()
del self.model
if self.config.get("do_mm_calib", False):
calib_path = os.path.join(os.getcwd(), "calib.pt")
torch.save(CALIB, calib_path)
logger.info(f"[CALIB] Saved calibration data successfully to: {calib_path}")
torch.cuda.empty_cache()
gc.collect()
@@ -258,6 +263,7 @@ class DefaultRunner(BaseRunner):
def init_run(self):
self.gen_video_final = None
self.get_video_segment_num()
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.model = self.load_transformer()
...
CALIB = {"absmax": {}}
@@ -51,5 +51,6 @@ LN_WEIGHT_REGISTER = Register()
CONV3D_WEIGHT_REGISTER = Register()
CONV2D_WEIGHT_REGISTER = Register()
TENSOR_REGISTER = Register()
CONVERT_WEIGHT_REGISTER = Register()
RUNNER_REGISTER = Register()
@@ -11,11 +11,21 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
import torch
from loguru import logger
try:
from lora_loader import LoRALoader
except ImportError:
pass
from safetensors import safe_open
from safetensors import torch as st
from tqdm import tqdm
try:
from lightx2v.utils.registry_factory import CONVERT_WEIGHT_REGISTER
except ImportError:
pass
from tools.convert.quant import *
def get_key_mapping_rules(direction, model_type):
if model_type == "wan_dit":
@@ -349,7 +359,17 @@ def quantize_tensor(w, w_bit=8, dtype=torch.int8, comfyui_mode=False):
def quantize_model(
weights,
w_bit=8,
target_keys=["attn", "ffn"],
adapter_keys=None,
key_idx=2,
ignore_key=None,
linear_dtype=torch.int8,
non_linear_dtype=torch.float,
comfyui_mode=False,
comfyui_keys=[],
linear_quant_type=None,
):
"""
Quantize model weights in-place
@@ -414,7 +434,13 @@ def quantize_model(
original_size += original_tensor_size
# Quantize tensor and store results
if linear_quant_type:
quantizer = CONVERT_WEIGHT_REGISTER[linear_quant_type](tensor)
w_q, scales, extra = quantizer.weight_quant_func(tensor)
weight_global_scale = extra.get("weight_global_scale", None) # For nvfp4
else:
w_q, scales = quantize_tensor(tensor, w_bit, linear_dtype, comfyui_mode)
weight_global_scale = None
# Replace original tensor and store scales
weights[key] = w_q
@@ -422,6 +448,8 @@
weights[key.replace(".weight", ".scale_weight")] = scales
else:
weights[key + "_scale"] = scales
if weight_global_scale is not None:
weights[key + "_global_scale"] = weight_global_scale
quantized_tensor_size = w_q.numel() * w_q.element_size()
scale_size = scales.numel() * scales.element_size()
@@ -622,6 +650,7 @@ def convert_weights(args):
non_linear_dtype=args.non_linear_dtype,
comfyui_mode=args.comfyui_mode,
comfyui_keys=args.comfyui_keys,
linear_quant_type=args.linear_quant_type,
)
os.makedirs(args.output, exist_ok=True)
@@ -793,6 +822,12 @@ def main():
choices=["torch.int8", "torch.float8_e4m3fn"],
help="Data type for linear",
)
parser.add_argument(
"--linear_quant_type",
type=str,
choices=["INT8", "FP8", "NVFP4", "MXFP4", "MXFP6", "MXFP8"],
help="Data type for linear",
)
parser.add_argument(
"--non_linear_dtype",
type=str,
...
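With the flag plumbed through, quantize_model dispatches on linear_quant_type instead of the plain dtype path. A hedged sketch of the call and the extra keys it leaves in the weight dict (non-ComfyUI mode); the tensor key name is illustrative, not taken from this diff:

# Illustrative only: "weights" maps checkpoint keys such as "blocks.0.attn.q.weight" to bf16 tensors.
quantize_model(weights, linear_quant_type="NVFP4")
# After the in-place pass, each quantized linear contributes:
#   "<key>"               -> quantized weight
#   "<key>_scale"         -> block scales
#   "<key>_global_scale"  -> per-tensor global scale (written only for NVFP4)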
from abc import ABCMeta
import torch
from qtorch.quant import float_quantize
try:
from lightx2v.utils.registry_factory import CONVERT_WEIGHT_REGISTER
from lightx2v_kernel.gemm import scaled_mxfp4_quant, scaled_mxfp6_quant, scaled_mxfp8_quant, scaled_nvfp4_quant
except ImportError:
pass
class QuantTemplate(metaclass=ABCMeta):
def __init__(self, weight):
if weight.dim() != 2:
raise ValueError(f"Only 2D tensors supported. Got {weight.dim()}D tensor")
if torch.isnan(weight).any():
raise ValueError("Tensor contains NaN values")
self.weight_quant_func = None
self.extra = {}
@CONVERT_WEIGHT_REGISTER("INT8")
class QuantWeightINT8(QuantTemplate):
def __init__(self, weight):
super().__init__(weight)
self.weight_quant_func = self.load_int8_weight
@torch.no_grad()
def load_int8_weight(self, w):
org_w_shape = w.shape
max_val = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)
qmin, qmax = -128, 127
scales = max_val / qmax
w_q = torch.clamp(torch.round(w / scales), qmin, qmax).to(torch.int8)
assert torch.isnan(scales).sum() == 0
assert torch.isnan(w_q).sum() == 0
scales = scales.view(org_w_shape[0], -1)
w_q = w_q.reshape(org_w_shape)
return w_q, scales, self.extra
@CONVERT_WEIGHT_REGISTER("FP8")
class QuantWeightFP8(QuantTemplate):
def __init__(self, weight):
super().__init__(weight)
self.weight_quant_func = self.load_fp8_weight
@torch.no_grad()
def load_fp8_weight(self, w):
org_w_shape = w.shape
max_val = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)
finfo = torch.finfo(torch.float8_e4m3fn)
qmin, qmax = finfo.min, finfo.max
scales = max_val / qmax
scaled_tensor = w / scales
scaled_tensor = torch.clip(scaled_tensor, qmin, qmax)
w_q = float_quantize(scaled_tensor.float(), 4, 3, rounding="nearest").to(torch.float8_e4m3fn)
assert torch.isnan(scales).sum() == 0
assert torch.isnan(w_q).sum() == 0
scales = scales.view(org_w_shape[0], -1)
w_q = w_q.reshape(org_w_shape)
return w_q, scales, self.extra
@CONVERT_WEIGHT_REGISTER("MXFP4")
class QuantWeightMxFP4(QuantTemplate):
def __init__(self, weight):
super().__init__(weight)
self.weight_quant_func = self.load_mxfp4_weight
@torch.no_grad()
def load_mxfp4_weight(self, w):
device = w.device
w = w.cuda().to(torch.bfloat16)
w_q, scales = scaled_mxfp4_quant(w)
w_q, scales = w_q.to(device), scales.to(device)
return w_q, scales, self.extra
@CONVERT_WEIGHT_REGISTER("MXFP6")
class QuantWeightMxFP6(QuantTemplate):
def __init__(self, weight):
super().__init__(weight)
self.weight_quant_func = self.load_mxfp6_weight
@torch.no_grad()
def load_mxfp6_weight(self, w):
device = w.device
w = w.cuda().to(torch.bfloat16)
w_q, scales = scaled_mxfp6_quant(w)
w_q, scales = w_q.to(device), scales.to(device)
return w_q, scales, self.extra
@CONVERT_WEIGHT_REGISTER("MXFP8")
class QuantWeightMxFP8(QuantTemplate):
def __init__(self, weight):
super().__init__(weight)
self.weight_quant_func = self.load_mxfp8_weight
@torch.no_grad()
def load_mxfp8_weight(self, w):
device = w.device
w = w.cuda().to(torch.bfloat16)
w_q, scales = scaled_mxfp8_quant(w)
w_q, scales = w_q.to(device), scales.to(device)
return w_q, scales, self.extra
@CONVERT_WEIGHT_REGISTER("NVFP4")
class QuantWeightNVFP4(QuantTemplate):
def __init__(self, weight):
super().__init__(weight)
self.weight_quant_func = self.load_fp4_weight
@torch.no_grad()
def load_fp4_weight(self, w):
device = w.device
w = w.cuda().to(torch.bfloat16)
weight_global_scale = (2688.0 / torch.max(torch.abs(w))).to(torch.float32)
w_q, scales = scaled_nvfp4_quant(w, weight_global_scale)
w_q, scales = w_q.to(device), scales.to(device)
self.extra["weight_global_scale"] = weight_global_scale.to(device)
return w_q, scales, self.extra
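For reference, convert.py drives these classes roughly as below (a sketch mirroring the call site added above, not an additional API; weight_bf16 stands for any 2D bf16 weight tensor):

quantizer = CONVERT_WEIGHT_REGISTER["NVFP4"](weight_bf16)       # validates the 2D weight tensor
w_q, scales, extra = quantizer.weight_quant_func(weight_bf16)   # quantized weight + block scales
weight_global_scale = extra.get("weight_global_scale")          # populated only by the NVFP4 path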