"megatron/git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "82491e4b56230bbf22d1ebf75186551d2aa726db"
Commit 420fec7f authored by helloyongyang

Support save_naive_quant and loading quantized weights

parent a81ad1e5
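For orientation, here is a minimal sketch of the two-phase flow this commit introduces: one run with the running flag set to "save_naive_quant" quantizes the weights at load time, writes quant_weights.pth under naive_quant_path, and exits; later runs load that file directly. The helper below is hypothetical and only mirrors the branching visible in the diff (GET_RUNNING_FLAG() and the config keys); the exact environment mechanism behind GET_RUNNING_FLAG() is not shown in this commit.

# Hypothetical illustration only; mirrors the branching added in this commit.
import os

def resolve_weight_source(running_flag, mm_config, naive_quant_path):
    # Same condition HunyuanModel._init_weights uses to pick a checkpoint source.
    if running_flag == "save_naive_quant" or mm_config.get("weight_auto_quant", False):
        # Original checkpoint: weights are quantized on the fly while loading.
        return "original checkpoint (auto-quantized at load time)"
    # Otherwise reuse the artifact a previous save_naive_quant run produced.
    return os.path.join(naive_quant_path, "quant_weights.pth")

mm_config = {"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm"}
# Phase 1: quantize, save to ./hy_i2v_quant_model/quant_weights.pth, then exit.
print(resolve_weight_source("save_naive_quant", mm_config, "./hy_i2v_quant_model"))
# Phase 2: normal inference run that loads the saved quantized weights.
print(resolve_weight_source("", mm_config, "./hy_i2v_quant_model"))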
File mode changed from 100644 to 100755
{
    "infer_steps": 20,
    "target_video_length": 33,
    "i2v_resolution": "720p",
    "attention_type": "flash_attn3",
    "seed": 0,
    "mm_config": {
        "mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm"
    },
    "naive_quant_path": "./hy_i2v_quant_model"
}
File mode changed from 100644 to 100755
{
    "infer_steps": 20,
    "target_video_length": 33,
    "target_height": 720,
    "target_width": 1280,
    "attention_type": "flash_attn3",
    "seed": 42,
    "mm_config": {
        "mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm"
    },
    "naive_quant_path": "./hy_t2v_quant_model"
}
File mode changed from 100644 to 100755
{
    "infer_steps": 40,
    "target_video_length": 81,
    "target_height": 480,
    "target_width": 832,
    "attention_type": "flash_attn3",
    "seed": 42,
    "sample_guide_scale": 5,
    "sample_shift": 5,
    "enable_cfg": true,
    "cpu_offload": false,
    "mm_config": {
        "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl"
    },
    "naive_quant_path": "./wan_i2v_quant_model"
}
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
{
    "infer_steps": 50,
    "target_video_length": 81,
    "text_len": 512,
    "target_height": 480,
    "target_width": 832,
    "attention_type": "flash_attn3",
    "seed": 42,
    "sample_guide_scale": 6,
    "sample_shift": 8,
    "enable_cfg": true,
    "cpu_offload": false,
    "mm_config": {
        "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl"
    },
    "naive_quant_path": "./wan_t2v_quant_model"
}
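All four configs above share the two keys the quantization path reads: mm_config.mm_type, which selects the registered MM weight implementation, and naive_quant_path, the directory where quant_weights.pth is written and later read. A small sketch, not part of the commit (the config file name is hypothetical), of pulling those keys out:

import json

with open("wan_t2v_quant_config.json") as f:  # hypothetical file name
    cfg = json.load(f)

mm_type = cfg["mm_config"]["mm_type"]   # e.g. "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl"
quant_dir = cfg["naive_quant_path"]     # e.g. "./wan_t2v_quant_model"
print(mm_type, quant_dir)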
@@ -24,15 +24,15 @@ class WeightModule:
             if hasattr(parameter, "load"):
                 parameter.load(weight_dict)

-    def state_dict(self, destination=None, prefix=""):
+    def state_dict(self, destination=None):
         if destination is None:
             destination = {}
         for name, param in self._parameters.items():
             if param is not None:
-                destination[prefix + name] = param.detach().cpu().clone()
+                param.state_dict(destination)
         for name, module in self._modules.items():
             if module is not None:
-                module.state_dict(destination, prefix + name + ".")
+                module.state_dict(destination)
         return destination

     def named_parameters(self, prefix=""):
...
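A self-contained sketch of the delegation pattern the WeightModule change above adopts (toy classes, not the library's): each leaf now writes its own fully qualified key from its stored weight_name, so the container no longer threads a prefix through the recursion.

import torch

class LeafWeight:  # stand-in for Conv2dWeight / MMWeight / etc.
    def __init__(self, weight_name):
        self.weight_name = weight_name
        self.weight = torch.randn(4, 4)

    def state_dict(self, destination=None):
        if destination is None:
            destination = {}
        destination[self.weight_name] = self.weight.cpu().detach().clone()
        return destination

class Container:  # stand-in for WeightModule
    def __init__(self):
        self._parameters = {"proj": LeafWeight("blocks.0.proj.weight")}
        self._modules = {}

    def state_dict(self, destination=None):
        if destination is None:
            destination = {}
        for param in self._parameters.values():
            if param is not None:
                param.state_dict(destination)  # leaf decides its own key
        for module in self._modules.values():
            if module is not None:
                module.state_dict(destination)
        return destination

print(list(Container().state_dict().keys()))  # ['blocks.0.proj.weight']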
@@ -48,3 +48,11 @@ class Conv2dWeight(Conv2dWeightTemplate):
         self.weight = self.weight.cuda(non_blocking=non_blocking)
         if self.bias is not None:
             self.bias = self.bias.cuda(non_blocking=non_blocking)
+
+    def state_dict(self, destination=None):
+        if destination is None:
+            destination = {}
+        destination[self.weight_name] = self.weight.cpu().detach().clone()
+        if self.bias is not None:
+            destination[self.bias_name] = self.bias.cpu().detach().clone()
+        return destination
@@ -49,6 +49,14 @@ class Conv3dWeight(Conv3dWeightTemplate):
         if self.bias is not None:
             self.bias = self.bias.cuda()
+
+    def state_dict(self, destination=None):
+        if destination is None:
+            destination = {}
+        destination[self.weight_name] = self.weight.cpu().detach().clone()
+        if self.bias is not None:
+            destination[self.bias_name] = self.bias.cpu().detach().clone()
+        return destination

 @CONV3D_WEIGHT_REGISTER("Defaultt-Force-BF16")
 class Conv3dWeightForceBF16(Conv3dWeight):
...
@@ -4,6 +4,7 @@ from vllm import _custom_ops as ops
 import sgl_kernel
 from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
 from lightx2v.utils.quant_utils import IntegerQuantizer, FloatQuantizer
+from lightx2v.utils.envs import *
 from loguru import logger

 try:
@@ -31,9 +32,8 @@ class MMWeightTemplate(metaclass=ABCMeta):
     def apply(self, input_tensor):
         pass

-    def set_config(self, config=None):
-        if config is not None:
-            self.config = config
+    def set_config(self, config={}):
+        self.config = config

     def to_cpu(self, non_blocking=False):
         self.weight = self.weight.to("cpu", non_blocking=non_blocking)
@@ -49,6 +49,14 @@ class MMWeightTemplate(metaclass=ABCMeta):
         if self.bias is not None:
             self.bias = self.bias.cuda(non_blocking=non_blocking)
+
+    def state_dict(self, destination=None):
+        if destination is None:
+            destination = {}
+        destination[self.weight_name] = self.weight.cpu().detach().clone()
+        if self.bias is not None:
+            destination[self.bias_name] = self.bias.cpu().detach().clone()
+        return destination

 @MM_WEIGHT_REGISTER("Default")
 class MMWeight(MMWeightTemplate):
@@ -56,8 +64,12 @@ class MMWeight(MMWeightTemplate):
         super().__init__(weight_name, bias_name)

     def load(self, weight_dict):
-        self.weight = weight_dict[self.weight_name].t().cuda()
-        self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None
+        if GET_RUNNING_FLAG() == "save_naive_quant" or self.config.get("weight_auto_quant", False):
+            self.weight = weight_dict[self.weight_name].t().cuda()
+            self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None
+        else:
+            self.weight = weight_dict[self.weight_name].cuda()
+            self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None

     def apply(self, input_tensor):
         shape = (input_tensor.shape[0], self.weight.shape[1])
@@ -94,39 +106,43 @@ class MMWeightQuantTemplate(MMWeightTemplate):
     def load(self, weight_dict):
         self.load_func(weight_dict)
-        if self.weight_need_transpose:
-            self.weight = self.weight.t()

     def load_quantized(self, weight_dict):
         self.weight = weight_dict[self.weight_name].cuda()
         self.weight_scale = weight_dict[self.weight_name.rstrip(".weight") + ".weight_scale"].cuda()

     def load_fp8_perchannel_sym(self, weight_dict):
-        if self.config.get("weight_auto_quant", True):
+        if GET_RUNNING_FLAG() == "save_naive_quant" or self.config.get("weight_auto_quant", False):
             self.weight = weight_dict[self.weight_name].to(torch.float32).cuda()
             w_quantizer = FloatQuantizer("e4m3", True, "per_channel")
             self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
             self.weight = self.weight.to(torch.float8_e4m3fn)
             self.weight_scale = self.weight_scale.to(torch.float32)
+            if self.weight_need_transpose:
+                self.weight = self.weight.t()
         else:
             self.load_quantized(weight_dict)
         self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None

     def load_int8_perchannel_sym(self, weight_dict):
-        if self.config.get("weight_auto_quant", True):
+        if GET_RUNNING_FLAG() == "save_naive_quant" or self.config.get("weight_auto_quant", False):
             self.weight = weight_dict[self.weight_name].to(torch.float32).cuda()
             w_quantizer = IntegerQuantizer(8, True, "per_channel")
             self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
             self.weight = self.weight.to(torch.int8)
             self.weight_scale = self.weight_scale.to(torch.float32)
+            if self.weight_need_transpose:
+                self.weight = self.weight.t()
         else:
             self.load_quantized(weight_dict)
         self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None

     def load_fp8_perblock128_sym(self, weight_dict):
-        if self.config.get("weight_auto_quant", True):
+        if GET_RUNNING_FLAG() == "save_naive_quant" or self.config.get("weight_auto_quant", False):
             self.weight = weight_dict[self.weight_name].cuda()
             self.weight, self.weight_scale = self.per_block_cast_to_fp8(self.weight)
+            if self.weight_need_transpose:
+                self.weight = self.weight.t()
         else:
             self.load_quantized(weight_dict)
         self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None
@@ -174,6 +190,16 @@ class MMWeightQuantTemplate(MMWeightTemplate):
         sgl_kernel.sgl_per_token_group_quant_fp8(x, input_tensor_quant, input_tensor_scale, group_size=128, eps=1e-10, fp8_min=-448.0, fp8_max=448.0)
         return input_tensor_quant, input_tensor_scale

+    def state_dict(self, destination=None):
+        if destination is None:
+            destination = {}
+        destination[self.weight_name] = self.weight.cpu().detach().clone()
+        if self.bias is not None:
+            destination[self.bias_name] = self.bias.cpu().detach().clone()
+        if hasattr(self, "weight_scale"):
+            destination[self.weight_name.rstrip(".weight") + ".weight_scale"] = self.weight_scale.cpu().detach().clone()
+        return destination

 @MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm")
 class MMWeightWfp8channelAfp8channeldynamicVllm(MMWeightQuantTemplate):
@@ -452,7 +478,7 @@ class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
 if __name__ == "__main__":
     weight_dict = {
-        "xx.weight": torch.randn(8192, 4096).to(torch.float8_e4m3fn),
+        "xx.weight": torch.randn(8192, 4096).to(torch.float8_e4m3fn).t(),
         "xx.bias": torch.randn(8192).to(torch.bfloat16),
         "xx.weight_scale": torch.randn(8192, 1).to(torch.float32),
     }
...
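To make the round trip above concrete, a sketch (illustrative names and shapes, not the commit's code) of the entries MMWeightQuantTemplate.state_dict() writes and load_quantized() later reads back. Note that str.rstrip strips a trailing character set rather than a literal suffix, but both the save and load sides derive the weight_scale key with the same expression, so they stay consistent with each other.

import torch

weight_name = "blocks.0.attn.qkv.weight"
bias_name = "blocks.0.attn.qkv.bias"
scale_key = weight_name.rstrip(".weight") + ".weight_scale"  # "blocks.0.attn.qkv.weight_scale"

saved = {
    weight_name: torch.randint(-128, 127, (64, 128), dtype=torch.int8),  # stored already transposed
    scale_key: torch.randn(128, 1, dtype=torch.float32),
    bias_name: torch.randn(128, dtype=torch.bfloat16),
}

# load_quantized() recovers the scale with the identical key expression.
assert weight_name.rstrip(".weight") + ".weight_scale" in saved
print(sorted(saved.keys()))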
@@ -34,6 +34,15 @@ class LNWeightTemplate(metaclass=ABCMeta):
         if self.bias is not None:
             self.bias = self.bias.cuda(non_blocking=non_blocking)
+
+    def state_dict(self, destination=None):
+        if destination is None:
+            destination = {}
+        if self.weight is not None:
+            destination[self.weight_name] = self.weight.cpu().detach().clone()
+        if self.bias is not None:
+            destination[self.bias_name] = self.bias.cpu().detach().clone()
+        return destination

 @LN_WEIGHT_REGISTER("Default")
 class LNWeight(LNWeightTemplate):
...
@@ -38,6 +38,12 @@ class RMSWeight(RMSWeightTemplate):
         input_tensor = input_tensor * self.weight
         return input_tensor

+    def state_dict(self, destination=None):
+        if destination is None:
+            destination = {}
+        destination[self.weight_name] = self.weight.cpu().detach().clone()
+        return destination

 @RMS_WEIGHT_REGISTER("FP32")
 class RMSWeightFP32(RMSWeight):
...
@@ -14,3 +14,9 @@ class DefaultTensor:
     def to_cuda(self, non_blocking=False):
         self.tensor = self.tensor.cuda(non_blocking=non_blocking)
+
+    def state_dict(self, destination=None):
+        if destination is None:
+            destination = {}
+        destination[self.tensor_name] = self.tensor.cpu().detach().clone()
+        return destination
 import os
+import sys
 import torch
 from lightx2v.models.networks.hunyuan.weights.pre_weights import HunyuanPreWeights
 from lightx2v.models.networks.hunyuan.weights.post_weights import HunyuanPostWeights
@@ -10,6 +11,8 @@ from lightx2v.models.networks.hunyuan.infer.feature_caching.transformer_infer im
 import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap
 import lightx2v.attentions.distributed.ring.wrap as ring_dist_wrap
+from lightx2v.utils.envs import *
+from loguru import logger

 class HunyuanModel:
@@ -24,6 +27,11 @@ class HunyuanModel:
         self.args = args
         self._init_infer_class()
         self._init_weights()
+        if GET_RUNNING_FLAG() == "save_naive_quant":
+            assert self.config.get("naive_quant_path") is not None, "naive_quant_path is None"
+            self.save_weights(self.config.naive_quant_path)
+            sys.exit(0)
         self._init_infer()

         if config["parallel_attn_type"]:
@@ -57,8 +65,18 @@ class HunyuanModel:
         weight_dict = torch.load(ckpt_path, map_location=self.device, weights_only=True)["module"]
         return weight_dict

+    def _load_ckpt_quant_model(self):
+        assert self.config.get("naive_quant_path") is not None, "naive_quant_path is None"
+        logger.info(f"Loading quant model from {self.config.naive_quant_path}")
+        quant_weights_path = os.path.join(self.config.naive_quant_path, "quant_weights.pth")
+        weight_dict = torch.load(quant_weights_path, map_location=self.device, weights_only=True)
+        return weight_dict
+
     def _init_weights(self):
-        weight_dict = self._load_ckpt()
+        if GET_RUNNING_FLAG() == "save_naive_quant" or self.config["mm_config"].get("weight_auto_quant", False):
+            weight_dict = self._load_ckpt()
+        else:
+            weight_dict = self._load_ckpt_quant_model()

         # init weights
         self.pre_weight = self.pre_weight_class(self.config)
         self.post_weight = self.post_weight_class(self.config)
@@ -73,6 +91,28 @@ class HunyuanModel:
         self.post_infer = self.post_infer_class(self.config)
         self.transformer_infer = self.transformer_infer_class(self.config)

+    def save_weights(self, save_path):
+        if not os.path.exists(save_path):
+            os.makedirs(save_path)
+
+        pre_state_dict = self.pre_weight.state_dict()
+        logger.info(pre_state_dict.keys())
+
+        post_state_dict = self.post_weight.state_dict()
+        logger.info(post_state_dict.keys())
+
+        transformer_state_dict = self.transformer_weights.state_dict()
+        logger.info(transformer_state_dict.keys())
+
+        save_dict = {}
+        save_dict.update(pre_state_dict)
+        save_dict.update(post_state_dict)
+        save_dict.update(transformer_state_dict)
+
+        save_path = os.path.join(save_path, "quant_weights.pth")
+        torch.save(save_dict, save_path)
+        logger.info(f"Save weights to {save_path}")

     def set_scheduler(self, scheduler):
         self.scheduler = scheduler
         self.pre_infer.set_scheduler(scheduler)
...
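Finally, a sketch of the on-disk artifact this flow produces and consumes: save_weights() merges the pre/post/transformer state dicts into a single quant_weights.pth under naive_quant_path, and _load_ckpt_quant_model() loads it back with torch.load on a later run. Key names and shapes below are placeholders, not the real Hunyuan parameter names.

import os
import torch

quant_dir = "./hy_i2v_quant_model"  # matches naive_quant_path in the configs above
os.makedirs(quant_dir, exist_ok=True)

# Save side: what save_weights() assembles from the weight modules' state_dict().
save_dict = {}
save_dict.update({"pre.example.weight": torch.randn(4, 4)})
save_dict.update({"post.example.weight": torch.randn(4, 4)})
save_dict.update({
    "blocks.0.example.weight": torch.randint(-128, 127, (4, 4), dtype=torch.int8),
    "blocks.0.example.weight_scale": torch.randn(4, 1),
})
torch.save(save_dict, os.path.join(quant_dir, "quant_weights.pth"))

# Load side: what _load_ckpt_quant_model() does on a later run.
weight_dict = torch.load(os.path.join(quant_dir, "quant_weights.pth"), map_location="cpu", weights_only=True)
print(sorted(weight_dict.keys()))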