Commit 420fec7f authored by helloyongyang

Support save_naive_quant and loading quantized weights

parent a81ad1e5
File mode changed from 100644 to 100755
{
    "infer_steps": 20,
    "target_video_length": 33,
    "i2v_resolution": "720p",
    "attention_type": "flash_attn3",
    "seed": 0,
    "mm_config": {
        "mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm"
    },
    "naive_quant_path": "./hy_i2v_quant_model"
}
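
For reference, `naive_quant_path` points at a directory holding a single `quant_weights.pth` file produced by the new `save_naive_quant` run mode (see the HunyuanModel changes further down in this commit). A minimal sketch of how such a file is read back, mirroring `_load_ckpt_quant_model` below; the surrounding driver code is hypothetical:

```python
import os
import torch

naive_quant_path = "./hy_i2v_quant_model"  # matches the config entry above
quant_weights_path = os.path.join(naive_quant_path, "quant_weights.pth")

# One flat dict of tensors: "<layer>.weight", "<layer>.weight_scale", "<layer>.bias", ...
weight_dict = torch.load(quant_weights_path, map_location="cpu", weights_only=True)
print(list(weight_dict.keys())[:5])
```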
File mode changed from 100644 to 100755
{
    "infer_steps": 20,
    "target_video_length": 33,
    "target_height": 720,
    "target_width": 1280,
    "attention_type": "flash_attn3",
    "seed": 42,
    "mm_config": {
        "mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm"
    },
    "naive_quant_path": "./hy_t2v_quant_model"
}
File mode changed from 100644 to 100755
{
    "infer_steps": 40,
    "target_video_length": 81,
    "target_height": 480,
    "target_width": 832,
    "attention_type": "flash_attn3",
    "seed": 42,
    "sample_guide_scale": 5,
    "sample_shift": 5,
    "enable_cfg": true,
    "cpu_offload": false,
    "mm_config": {
        "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl"
    },
    "naive_quant_path": "./wan_i2v_quant_model"
}
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
{
    "infer_steps": 50,
    "target_video_length": 81,
    "text_len": 512,
    "target_height": 480,
    "target_width": 832,
    "attention_type": "flash_attn3",
    "seed": 42,
    "sample_guide_scale": 6,
    "sample_shift": 8,
    "enable_cfg": true,
    "cpu_offload": false,
    "mm_config": {
        "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl"
    },
    "naive_quant_path": "./wan_t2v_quant_model"
}
@@ -24,15 +24,15 @@ class WeightModule:
             if hasattr(parameter, "load"):
                 parameter.load(weight_dict)

-    def state_dict(self, destination=None, prefix=""):
+    def state_dict(self, destination=None):
         if destination is None:
             destination = {}
         for name, param in self._parameters.items():
             if param is not None:
-                destination[prefix + name] = param.detach().cpu().clone()
+                param.state_dict(destination)
         for name, module in self._modules.items():
             if module is not None:
-                module.state_dict(destination, prefix + name + ".")
+                module.state_dict(destination)
         return destination

     def named_parameters(self, prefix=""):
......
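
The reworked `state_dict` above no longer threads a `prefix` through the module tree: each leaf weight object stores its fully qualified name (`weight_name`, `bias_name`, `tensor_name`) and writes its own entries into the shared `destination` dict, so the container only delegates. A small self-contained sketch of that delegation pattern, using toy classes rather than the lightx2v ones:

```python
import torch

class LeafWeight:
    """Toy stand-in for MMWeight/LNWeight: it knows its own full key name."""
    def __init__(self, weight_name, weight):
        self.weight_name = weight_name
        self.weight = weight

    def state_dict(self, destination=None):
        if destination is None:
            destination = {}
        destination[self.weight_name] = self.weight.detach().cpu().clone()
        return destination

class Container:
    """Toy stand-in for WeightModule: delegates, no prefix bookkeeping."""
    def __init__(self, leaves):
        self._parameters = leaves

    def state_dict(self, destination=None):
        if destination is None:
            destination = {}
        for leaf in self._parameters.values():
            leaf.state_dict(destination)  # each leaf writes its fully qualified key
        return destination

m = Container({"qkv": LeafWeight("blocks.0.attn.qkv.weight", torch.randn(4, 4))})
print(m.state_dict().keys())  # dict_keys(['blocks.0.attn.qkv.weight'])
```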
@@ -48,3 +48,11 @@ class Conv2dWeight(Conv2dWeightTemplate):
         self.weight = self.weight.cuda(non_blocking=non_blocking)
         if self.bias is not None:
             self.bias = self.bias.cuda(non_blocking=non_blocking)
+
+    def state_dict(self, destination=None):
+        if destination is None:
+            destination = {}
+        destination[self.weight_name] = self.weight.cpu().detach().clone()
+        if self.bias is not None:
+            destination[self.bias_name] = self.bias.cpu().detach().clone()
+        return destination
@@ -49,6 +49,14 @@ class Conv3dWeight(Conv3dWeightTemplate):
         if self.bias is not None:
             self.bias = self.bias.cuda()
+
+    def state_dict(self, destination=None):
+        if destination is None:
+            destination = {}
+        destination[self.weight_name] = self.weight.cpu().detach().clone()
+        if self.bias is not None:
+            destination[self.bias_name] = self.bias.cpu().detach().clone()
+        return destination

 @CONV3D_WEIGHT_REGISTER("Defaultt-Force-BF16")
 class Conv3dWeightForceBF16(Conv3dWeight):
......
@@ -4,6 +4,7 @@ from vllm import _custom_ops as ops
 import sgl_kernel
 from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
 from lightx2v.utils.quant_utils import IntegerQuantizer, FloatQuantizer
+from lightx2v.utils.envs import *
 from loguru import logger

 try:
@@ -31,9 +32,8 @@ class MMWeightTemplate(metaclass=ABCMeta):
     def apply(self, input_tensor):
         pass

-    def set_config(self, config=None):
-        if config is not None:
-            self.config = config
+    def set_config(self, config={}):
+        self.config = config

     def to_cpu(self, non_blocking=False):
         self.weight = self.weight.to("cpu", non_blocking=non_blocking)
@@ -49,6 +49,14 @@ class MMWeightTemplate(metaclass=ABCMeta):
         if self.bias is not None:
             self.bias = self.bias.cuda(non_blocking=non_blocking)
+
+    def state_dict(self, destination=None):
+        if destination is None:
+            destination = {}
+        destination[self.weight_name] = self.weight.cpu().detach().clone()
+        if self.bias is not None:
+            destination[self.bias_name] = self.bias.cpu().detach().clone()
+        return destination

 @MM_WEIGHT_REGISTER("Default")
 class MMWeight(MMWeightTemplate):
@@ -56,8 +64,12 @@ class MMWeight(MMWeightTemplate):
         super().__init__(weight_name, bias_name)

     def load(self, weight_dict):
-        self.weight = weight_dict[self.weight_name].t().cuda()
-        self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None
+        if GET_RUNNING_FLAG() == "save_naive_quant" or self.config.get("weight_auto_quant", False):
+            self.weight = weight_dict[self.weight_name].t().cuda()
+            self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None
+        else:
+            self.weight = weight_dict[self.weight_name].cuda()
+            self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None

     def apply(self, input_tensor):
         shape = (input_tensor.shape[0], self.weight.shape[1])
@@ -94,39 +106,43 @@ class MMWeightQuantTemplate(MMWeightTemplate):
     def load(self, weight_dict):
         self.load_func(weight_dict)
-        if self.weight_need_transpose:
-            self.weight = self.weight.t()

     def load_quantized(self, weight_dict):
         self.weight = weight_dict[self.weight_name].cuda()
         self.weight_scale = weight_dict[self.weight_name.rstrip(".weight") + ".weight_scale"].cuda()

     def load_fp8_perchannel_sym(self, weight_dict):
-        if self.config.get("weight_auto_quant", True):
+        if GET_RUNNING_FLAG() == "save_naive_quant" or self.config.get("weight_auto_quant", False):
             self.weight = weight_dict[self.weight_name].to(torch.float32).cuda()
             w_quantizer = FloatQuantizer("e4m3", True, "per_channel")
             self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
             self.weight = self.weight.to(torch.float8_e4m3fn)
             self.weight_scale = self.weight_scale.to(torch.float32)
+            if self.weight_need_transpose:
+                self.weight = self.weight.t()
         else:
             self.load_quantized(weight_dict)
         self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None

     def load_int8_perchannel_sym(self, weight_dict):
-        if self.config.get("weight_auto_quant", True):
+        if GET_RUNNING_FLAG() == "save_naive_quant" or self.config.get("weight_auto_quant", False):
             self.weight = weight_dict[self.weight_name].to(torch.float32).cuda()
             w_quantizer = IntegerQuantizer(8, True, "per_channel")
             self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
             self.weight = self.weight.to(torch.int8)
             self.weight_scale = self.weight_scale.to(torch.float32)
+            if self.weight_need_transpose:
+                self.weight = self.weight.t()
         else:
             self.load_quantized(weight_dict)
         self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None

     def load_fp8_perblock128_sym(self, weight_dict):
-        if self.config.get("weight_auto_quant", True):
+        if GET_RUNNING_FLAG() == "save_naive_quant" or self.config.get("weight_auto_quant", False):
             self.weight = weight_dict[self.weight_name].cuda()
             self.weight, self.weight_scale = self.per_block_cast_to_fp8(self.weight)
+            if self.weight_need_transpose:
+                self.weight = self.weight.t()
         else:
             self.load_quantized(weight_dict)
         self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None
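
For context on the auto-quant branches above: `IntegerQuantizer(8, True, "per_channel")` and `FloatQuantizer("e4m3", True, "per_channel")` produce symmetric per-output-channel quantization, one scale per row plus a low-precision weight tensor. The sketch below shows the int8 case conceptually; it is an assumption about the quantizers' math, not the library's actual implementation (the fp8 e4m3 case works the same way with the format's max of 448 in place of 127):

```python
import torch

def int8_per_channel_sym(w: torch.Tensor):
    # One scale per output channel (row); symmetric, so no zero point.
    amax = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8)
    scale = amax / 127.0
    q = torch.round(w / scale).clamp(-127, 127).to(torch.int8)
    return q, scale.to(torch.float32)

w = torch.randn(8, 16)
q, scale = int8_per_channel_sym(w)
w_hat = q.to(torch.float32) * scale   # dequantize
print((w - w_hat).abs().max())        # small round-off error
```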
@@ -174,6 +190,16 @@ class MMWeightQuantTemplate(MMWeightTemplate):
         sgl_kernel.sgl_per_token_group_quant_fp8(x, input_tensor_quant, input_tensor_scale, group_size=128, eps=1e-10, fp8_min=-448.0, fp8_max=448.0)
         return input_tensor_quant, input_tensor_scale
+
+    def state_dict(self, destination=None):
+        if destination is None:
+            destination = {}
+        destination[self.weight_name] = self.weight.cpu().detach().clone()
+        if self.bias is not None:
+            destination[self.bias_name] = self.bias.cpu().detach().clone()
+        if hasattr(self, "weight_scale"):
+            destination[self.weight_name.rstrip(".weight") + ".weight_scale"] = self.weight_scale.cpu().detach().clone()
+        return destination

 @MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm")
 class MMWeightWfp8channelAfp8channeldynamicVllm(MMWeightQuantTemplate):
@@ -452,7 +478,7 @@ class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
 if __name__ == "__main__":
     weight_dict = {
-        "xx.weight": torch.randn(8192, 4096).to(torch.float8_e4m3fn),
+        "xx.weight": torch.randn(8192, 4096).to(torch.float8_e4m3fn).t(),
         "xx.bias": torch.randn(8192).to(torch.bfloat16),
         "xx.weight_scale": torch.randn(8192, 1).to(torch.float32),
     }
......
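
Putting the pieces together, a quantized linear layer round-trips through three keys: `<name>.weight` (int8/fp8, stored already transposed when `weight_need_transpose` is set, which is why the `__main__` test weight above now carries a `.t()`), `<name>.weight_scale` (fp32, one value per output channel), and optionally `<name>.bias`. A hedged sketch of that save/load round trip with plain tensors; the layer name and shapes here are illustrative, not taken from the library:

```python
import torch

# Save side: what MMWeightQuantTemplate.state_dict collects for one layer.
# The weight is stored in (in_features, out_features) layout, i.e. already transposed.
save_dict = {
    "blocks.0.attn.qkv.weight": torch.randint(-127, 128, (128, 256), dtype=torch.int8),
    "blocks.0.attn.qkv.weight_scale": torch.rand(256, 1),
    "blocks.0.attn.qkv.bias": torch.randn(256, dtype=torch.bfloat16),
}
torch.save(save_dict, "quant_weights.pth")

# Load side: what load_quantized expects to find, used as-is with no further transpose.
weight_dict = torch.load("quant_weights.pth", weights_only=True)
weight = weight_dict["blocks.0.attn.qkv.weight"]
weight_scale = weight_dict["blocks.0.attn.qkv.weight_scale"]
```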
@@ -34,6 +34,15 @@ class LNWeightTemplate(metaclass=ABCMeta):
         if self.bias is not None:
             self.bias = self.bias.cuda(non_blocking=non_blocking)
+
+    def state_dict(self, destination=None):
+        if destination is None:
+            destination = {}
+        if self.weight is not None:
+            destination[self.weight_name] = self.weight.cpu().detach().clone()
+        if self.bias is not None:
+            destination[self.bias_name] = self.bias.cpu().detach().clone()
+        return destination

 @LN_WEIGHT_REGISTER("Default")
 class LNWeight(LNWeightTemplate):
......
@@ -38,6 +38,12 @@ class RMSWeight(RMSWeightTemplate):
         input_tensor = input_tensor * self.weight
         return input_tensor
+
+    def state_dict(self, destination=None):
+        if destination is None:
+            destination = {}
+        destination[self.weight_name] = self.weight.cpu().detach().clone()
+        return destination

 @RMS_WEIGHT_REGISTER("FP32")
 class RMSWeightFP32(RMSWeight):
......
@@ -14,3 +14,9 @@ class DefaultTensor:
     def to_cuda(self, non_blocking=False):
         self.tensor = self.tensor.cuda(non_blocking=non_blocking)
+
+    def state_dict(self, destination=None):
+        if destination is None:
+            destination = {}
+        destination[self.tensor_name] = self.tensor.cpu().detach().clone()
+        return destination
import os
import sys
import torch
from lightx2v.models.networks.hunyuan.weights.pre_weights import HunyuanPreWeights
from lightx2v.models.networks.hunyuan.weights.post_weights import HunyuanPostWeights
@@ -10,6 +11,8 @@ from lightx2v.models.networks.hunyuan.infer.feature_caching.transformer_infer im
 import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap
 import lightx2v.attentions.distributed.ring.wrap as ring_dist_wrap
+from lightx2v.utils.envs import *
+from loguru import logger

 class HunyuanModel:
@@ -24,6 +27,11 @@ class HunyuanModel:
         self.args = args
         self._init_infer_class()
         self._init_weights()
+
+        if GET_RUNNING_FLAG() == "save_naive_quant":
+            assert self.config.get("naive_quant_path") is not None, "naive_quant_path is None"
+            self.save_weights(self.config.naive_quant_path)
+            sys.exit(0)
         self._init_infer()

         if config["parallel_attn_type"]:
@@ -57,8 +65,18 @@ class HunyuanModel:
         weight_dict = torch.load(ckpt_path, map_location=self.device, weights_only=True)["module"]
         return weight_dict

+    def _load_ckpt_quant_model(self):
+        assert self.config.get("naive_quant_path") is not None, "naive_quant_path is None"
+        logger.info(f"Loading quant model from {self.config.naive_quant_path}")
+        quant_weights_path = os.path.join(self.config.naive_quant_path, "quant_weights.pth")
+        weight_dict = torch.load(quant_weights_path, map_location=self.device, weights_only=True)
+        return weight_dict
+
     def _init_weights(self):
-        weight_dict = self._load_ckpt()
+        if GET_RUNNING_FLAG() == "save_naive_quant" or self.config["mm_config"].get("weight_auto_quant", False):
+            weight_dict = self._load_ckpt()
+        else:
+            weight_dict = self._load_ckpt_quant_model()
         # init weights
         self.pre_weight = self.pre_weight_class(self.config)
         self.post_weight = self.post_weight_class(self.config)
@@ -73,6 +91,28 @@ class HunyuanModel:
         self.post_infer = self.post_infer_class(self.config)
         self.transformer_infer = self.transformer_infer_class(self.config)

+    def save_weights(self, save_path):
+        if not os.path.exists(save_path):
+            os.makedirs(save_path)
+
+        pre_state_dict = self.pre_weight.state_dict()
+        logger.info(pre_state_dict.keys())
+
+        post_state_dict = self.post_weight.state_dict()
+        logger.info(post_state_dict.keys())
+
+        transformer_state_dict = self.transformer_weights.state_dict()
+        logger.info(transformer_state_dict.keys())
+
+        save_dict = {}
+        save_dict.update(pre_state_dict)
+        save_dict.update(post_state_dict)
+        save_dict.update(transformer_state_dict)
+
+        save_path = os.path.join(save_path, "quant_weights.pth")
+        torch.save(save_dict, save_path)
+        logger.info(f"Save weights to {save_path}")
+
     def set_scheduler(self, scheduler):
         self.scheduler = scheduler
         self.pre_infer.set_scheduler(scheduler)
......
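
Taken together, the HunyuanModel changes give a two-phase workflow: a run with the running flag set to `save_naive_quant` quantizes the weights while loading, dumps them to `naive_quant_path/quant_weights.pth`, and exits; later runs with `weight_auto_quant` unset (or false) load that file through `_load_ckpt_quant_model` instead of the original checkpoint. A rough sketch, assuming `GET_RUNNING_FLAG()` reads a `RUNNING_FLAG` environment variable; the env-var name and the driver comments are assumptions, only the flag value and config keys come from this commit:

```python
import os

# Phase 1: quantize once and export; the process exits after save_weights().
os.environ["RUNNING_FLAG"] = "save_naive_quant"   # assumed env var behind GET_RUNNING_FLAG()
# ... build a config with "naive_quant_path": "./hy_t2v_quant_model" and initialize the model ...

# Phase 2: normal inference; with weight_auto_quant false the model reads
# ./hy_t2v_quant_model/quant_weights.pth via _load_ckpt_quant_model().
os.environ.pop("RUNNING_FLAG", None)              # back to the default running mode (assumed)
# ... run inference as usual ...
```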