Commit f21528e7 authored by gushiqiao, committed by Yang Yong(雍洋)

Support q8f kernel and fix bugs. (#6)


Co-authored-by: gushiqiao <gushiqiao@sensetime.com>
parent bd0f840f
@@ -21,7 +21,7 @@ docker run --gpus all -itd --ipc=host --name [name] -v /mnt:/mnt --entrypoint /b
```
git clone https://gitlab.bj.sensetime.com/video-gen/lightx2v.git
cd lightx2v
cd lightx2v/scripts
# Modify the parameters of the running script
bash run_hunyuan_t2v.sh
......
@@ -3,6 +3,10 @@ from abc import ABCMeta, abstractmethod
from vllm import _custom_ops as ops
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
from lightx2v.utils.quant_utils import IntegerQuantizer, FloatQuantizer
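# Optional dependency: Q8F stays None when q8_kernels is not installed, and the
# Q8F-backed weight classes below can only be used when this import succeeds.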
try:
import q8_kernels.functional as Q8F
except ImportError:
Q8F = None
class MMWeightTemplate(metaclass=ABCMeta):
@@ -113,7 +117,7 @@ class MMWeightWfp8channelAfp8channeldynamicVllm(MMWeightTemplate):
@MM_WEIGHT_REGISTER('W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm')
class MMWeightWfp8channelAfp8channeldynamicVllm(MMWeightTemplate):
class MMWeightWint8channelAint8channeldynamicVllm(MMWeightTemplate):
'''
Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm
@@ -159,6 +163,92 @@ class MMWeightWfp8channelAfp8channeldynamicVllm(MMWeightTemplate):
self.bias = self.bias.cuda()
@MM_WEIGHT_REGISTER('W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F')
class MMWeightWint8channelAint8channeldynamicQ8F(MMWeightTemplate):
'''
Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F
Quant MM:
Weight: int8 perchannel sym
Act: int8 perchannel dynamic sym
Kernel: Q8F
'''
def __init__(self, weight_name, bias_name):
super().__init__(weight_name, bias_name)
def load(self, weight_dict):
if self.config.get('weight_auto_quant', True):
self.weight = weight_dict[self.weight_name].cuda()
w_quantizer = IntegerQuantizer(8, True, 'channel')
self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
self.weight = self.weight.to(torch.int8)
self.weight_scale = self.weight_scale.to(torch.float32)
else:
self.weight = weight_dict[self.weight_name].cuda()
self.weight_scale = weight_dict[self.weight_name.removesuffix('.weight') + '.weight_scale'].cuda()  # removesuffix, not rstrip: rstrip strips a character set, not a literal suffix
self.bias = weight_dict[self.bias_name].float().cuda() if self.bias_name is not None else None
def apply(self, input_tensor, act=None):
qinput, x_scale, _ = ops.scaled_int8_quant(input_tensor, scale=None, azp=None, symmetric=True)
output_tensor = Q8F.linear.q8_linear(qinput, self.weight, self.bias, x_scale, self.weight_scale, fuse_gelu=False, out_dtype=torch.bfloat16)
return output_tensor.squeeze(0)
def to_cpu(self):
self.weight = self.weight.cpu()
self.weight_scale = self.weight_scale.cpu()
if self.bias is not None:
self.bias = self.bias.cpu()
def to_cuda(self):
self.weight = self.weight.cuda()
self.weight_scale = self.weight_scale.cuda()
if self.bias is not None:
self.bias = self.bias.cuda()
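# For reference, a minimal sketch of what per-channel symmetric int8 quantization
# computes (assuming the usual scale = amax / 127 convention; IntegerQuantizer's
# exact implementation may differ):
#   scale = weight.abs().amax(dim=1, keepdim=True) / 127.0  # one scale per output channel
#   q = torch.clamp(torch.round(weight / scale), -128, 127).to(torch.int8)
#   weight ~= q.float() * scale  # dequantization recovers the weight up to rounding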
@MM_WEIGHT_REGISTER('W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Q8F')
class MMWeightWfp8channelAfp8channeldynamicQ8F(MMWeightTemplate):
'''
Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Q8F
Quant MM:
Weight: fp8 perchannel sym
Act: fp8 perchannel dynamic sym
Kernel: Q8F
'''
def __init__(self, weight_name, bias_name):
super().__init__(weight_name, bias_name)
def load(self, weight_dict):
if self.config.get('weight_auto_quant', True):
self.weight = weight_dict[self.weight_name].cuda()
w_quantizer = FloatQuantizer('e4m3', True, 'channel')
self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
self.weight = self.weight.to(torch.float8_e4m3fn)
self.weight_scale = self.weight_scale.to(torch.float32)
else:
self.weight = weight_dict[self.weight_name].cuda()
self.weight_scale = weight_dict[self.weight_name.removesuffix('.weight') + '.weight_scale'].cuda()  # removesuffix, not rstrip: rstrip strips a character set, not a literal suffix
self.bias = weight_dict[self.bias_name].float().cuda() if self.bias_name is not None else None
def apply(self, input_tensor):
qinput, x_scale = ops.scaled_fp8_quant(input_tensor, None, scale_ub=None, use_per_token_if_dynamic=True)
output_tensor = Q8F.linear.fp8_linear(qinput, self.weight, self.bias, x_scale, self.weight_scale, out_dtype=torch.bfloat16)
return output_tensor.squeeze(0)
def to_cpu(self):
self.weight = self.weight.cpu()
self.weight_scale = self.weight_scale.cpu()
if self.bias is not None:
self.bias = self.bias.cpu()
def to_cuda(self):
self.weight = self.weight.cuda()
self.weight_scale = self.weight_scale.cuda()
if self.bias is not None:
self.bias = self.bias.cuda()
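# Numerically, either Q8F path should approximate the unquantized layer:
#   (qinput * x_scale) @ (q_weight * weight_scale).T + bias ~= input @ weight.T + bias
# since the per-token and per-channel scales undo the quantization up to rounding
# error (a sanity check for these kernels, not part of this commit).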
if __name__ == '__main__':
weight_dict = {
'xx.weight': torch.randn(8192, 4096).to(torch.float8_e4m3fn),
......
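Taken together, the new entries plug into the existing registry flow. A minimal
usage sketch (hypothetical tensor shapes; requires CUDA plus the vllm and
q8_kernels packages, quantizing the weight on load via weight_auto_quant):

import torch
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER

mm_cls = MM_WEIGHT_REGISTER['W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F']
mm = mm_cls('xx.weight', None)              # no bias in this sketch
mm.set_config({'weight_auto_quant': True})  # quantize the full-precision weight on load
mm.load({'xx.weight': torch.randn(8192, 4096, dtype=torch.bfloat16)})
out = mm.apply(torch.randn(16, 4096, dtype=torch.bfloat16).cuda())  # bf16 output, shape (16, 8192)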
import torch
from .utils import compute_freqs, compute_freqs_dist, apply_rotary_emb, rms_norm
from .utils import compute_freqs, compute_freqs_dist, apply_rotary_emb
from lightx2v.attentions import attention
@@ -60,15 +60,8 @@ class WanTransformerInfer:
norm1_out = (norm1_out * (1 + embed0[1]) + embed0[0]).squeeze(0)
s, n, d = *norm1_out.shape[:1], self.num_heads, self.head_dim
q = rms_norm(
weights.self_attn_q.apply(norm1_out), weights.self_attn_norm_q_weight, 1e-6
).view(s, n, d)
k = rms_norm(
weights.self_attn_k.apply(norm1_out), weights.self_attn_norm_k_weight, 1e-6
).view(s, n, d)
q = weights.self_attn_norm_q.apply(weights.self_attn_q.apply(norm1_out)).view(s, n, d)
k = weights.self_attn_norm_k.apply(weights.self_attn_k.apply(norm1_out)).view(s, n, d)
v = weights.self_attn_v.apply(norm1_out).view(s, n, d)
if not self.parallel_attention:
@@ -114,21 +107,12 @@ class WanTransformerInfer:
context = context[257:]
n, d = self.num_heads, self.head_dim
q = rms_norm(
weights.cross_attn_q.apply(norm3_out), weights.cross_attn_norm_q_weight, 1e-6
).view(-1, n, d)
k = rms_norm(
weights.cross_attn_k.apply(context), weights.cross_attn_norm_k_weight, 1e-6
).view(-1, n, d)
q = weights.cross_attn_norm_q.apply(weights.cross_attn_q.apply(norm3_out)).view(-1, n, d)
k = weights.cross_attn_norm_k.apply(weights.cross_attn_k.apply(context)).view(-1, n, d)
v = weights.cross_attn_v.apply(context).view(-1, n, d)
if self.task == 'i2v':
k_img = rms_norm(
weights.cross_attn_k_img.apply(context_img), weights.cross_attn_norm_k_img_weight, 1e-6
).view(-1, n, d)
k_img = weights.cross_attn_norm_k_img.apply(weights.cross_attn_k_img.apply(context_img)).view(-1, n, d)
v_img = weights.cross_attn_v_img.apply(context_img).view(-1, n, d)
cu_seqlens_q, cu_seqlens_k, lq, lk = self._calculate_q_k_len(
......
@@ -4,14 +4,6 @@ import torch.cuda.amp as amp
import torch.distributed as dist
def rms_norm(x, weight, eps):
x = x.contiguous()
orig_shape = x.shape
x = x.view(-1, orig_shape[-1])
x = sgl_kernel.rmsnorm(x, weight, eps).view(orig_shape)
return x
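# For context, the removed helper wraps sgl_kernel.rmsnorm; in plain PyTorch the
# same operation is the standard RMSNorm (a sketch, not a drop-in kernel):
#   x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * weight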
def compute_freqs(c, grid_sizes, freqs):
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
f, h, w = grid_sizes[0].tolist()
......
@@ -42,16 +42,16 @@ class WanPreWeights:
self.weight_list.append(self.proj_4)
for mm_weight in self.weight_list:
if isinstance(mm_weight, MMWeightTemplate) or isinstance(mm_weight, LNWeightTemplate) or isinstance(mm_weight, Conv3dWeightTemplate):
if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate, Conv3dWeightTemplate)):
mm_weight.set_config(self.config['mm_config'])
mm_weight.load(weight_dict)
def to_cpu(self):
for mm_weight in self.weight_list:
if isinstance(mm_weight, MMWeightTemplate) or isinstance(mm_weight, LNWeightTemplate) or isinstance(mm_weight, Conv3dWeightTemplate):
if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate, Conv3dWeightTemplate)):
mm_weight.to_cpu()
def to_cuda(self):
for mm_weight in self.weight_list:
if isinstance(mm_weight, MMWeightTemplate) or isinstance(mm_weight, LNWeightTemplate) or isinstance(mm_weight, Conv3dWeightTemplate):
if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate, Conv3dWeightTemplate)):
mm_weight.to_cuda()
\ No newline at end of file
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, LN_WEIGHT_REGISTER
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, LN_WEIGHT_REGISTER, RMS_WEIGHT_REGISTER
from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
from lightx2v.common.ops.norm.layer_norm_weight import LNWeightTemplate
from lightx2v.common.ops.norm.rms_norm_weight import RMSWeightTemplate
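# The *_REGISTER objects are dict-like registries mapping a string key (the
# configured mm_type, 'sgl-kernel', 'Default', ...) to a weight-wrapper class;
# indexing the registry returns the class, which is then instantiated with the
# checkpoint tensor names it should load.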
class WanTransformerWeights:
@@ -42,15 +43,17 @@ class WanTransformerAttentionBlock:
self.self_attn_k = MM_WEIGHT_REGISTER[self.mm_type](f'blocks.{self.block_index}.self_attn.k.weight',f'blocks.{self.block_index}.self_attn.k.bias')
self.self_attn_v = MM_WEIGHT_REGISTER[self.mm_type](f'blocks.{self.block_index}.self_attn.v.weight',f'blocks.{self.block_index}.self_attn.v.bias')
self.self_attn_o = MM_WEIGHT_REGISTER[self.mm_type](f'blocks.{self.block_index}.self_attn.o.weight',f'blocks.{self.block_index}.self_attn.o.bias')
self.self_attn_norm_q_weight = weight_dict[f'blocks.{self.block_index}.self_attn.norm_q.weight']
self.self_attn_norm_k_weight = weight_dict[f'blocks.{self.block_index}.self_attn.norm_k.weight']
self.norm3 = LN_WEIGHT_REGISTER[self.mm_type](f'blocks.{self.block_index}.norm3.weight',f'blocks.{self.block_index}.norm3.bias',eps = 1e-6)
self.self_attn_norm_q = RMS_WEIGHT_REGISTER['sgl-kernel'](f'blocks.{self.block_index}.self_attn.norm_q.weight')
self.self_attn_norm_k = RMS_WEIGHT_REGISTER['sgl-kernel'](f'blocks.{self.block_index}.self_attn.norm_k.weight')
self.norm3 = LN_WEIGHT_REGISTER['Default'](f'blocks.{self.block_index}.norm3.weight',f'blocks.{self.block_index}.norm3.bias',eps = 1e-6)
self.cross_attn_q = MM_WEIGHT_REGISTER[self.mm_type](f'blocks.{self.block_index}.cross_attn.q.weight',f'blocks.{self.block_index}.cross_attn.q.bias')
self.cross_attn_k = MM_WEIGHT_REGISTER[self.mm_type](f'blocks.{self.block_index}.cross_attn.k.weight',f'blocks.{self.block_index}.cross_attn.k.bias')
self.cross_attn_v = MM_WEIGHT_REGISTER[self.mm_type](f'blocks.{self.block_index}.cross_attn.v.weight',f'blocks.{self.block_index}.cross_attn.v.bias')
self.cross_attn_o = MM_WEIGHT_REGISTER[self.mm_type](f'blocks.{self.block_index}.cross_attn.o.weight',f'blocks.{self.block_index}.cross_attn.o.bias')
self.cross_attn_norm_q_weight = weight_dict[f'blocks.{self.block_index}.cross_attn.norm_q.weight']
self.cross_attn_norm_k_weight = weight_dict[f'blocks.{self.block_index}.cross_attn.norm_k.weight']
self.cross_attn_norm_q = RMS_WEIGHT_REGISTER['sgl-kernel'](f'blocks.{self.block_index}.cross_attn.norm_q.weight')
self.cross_attn_norm_k = RMS_WEIGHT_REGISTER['sgl-kernel'](f'blocks.{self.block_index}.cross_attn.norm_k.weight')
self.ffn_0 = MM_WEIGHT_REGISTER[self.mm_type](f'blocks.{self.block_index}.ffn.0.weight',f'blocks.{self.block_index}.ffn.0.bias')
self.ffn_2 = MM_WEIGHT_REGISTER[self.mm_type](f'blocks.{self.block_index}.ffn.2.weight',f'blocks.{self.block_index}.ffn.2.bias')
self.modulation = weight_dict[f'blocks.{self.block_index}.modulation']
@@ -60,15 +63,15 @@ class WanTransformerAttentionBlock:
self.self_attn_k,
self.self_attn_v,
self.self_attn_o,
self.self_attn_norm_q_weight,
self.self_attn_norm_k_weight,
self.self_attn_norm_q,
self.self_attn_norm_k,
self.norm3,
self.cross_attn_q,
self.cross_attn_k,
self.cross_attn_v,
self.cross_attn_o,
self.cross_attn_norm_q_weight,
self.cross_attn_norm_k_weight,
self.cross_attn_norm_q,
self.cross_attn_norm_k,
self.ffn_0,
self.ffn_2,
self.modulation,
@@ -77,26 +80,27 @@ class WanTransformerAttentionBlock:
if self.task == 'i2v':
self.cross_attn_k_img = MM_WEIGHT_REGISTER[self.mm_type](f'blocks.{self.block_index}.cross_attn.k_img.weight',f'blocks.{self.block_index}.cross_attn.k_img.bias')
self.cross_attn_v_img = MM_WEIGHT_REGISTER[self.mm_type](f'blocks.{self.block_index}.cross_attn.v_img.weight',f'blocks.{self.block_index}.cross_attn.v_img.bias')
self.cross_attn_norm_k_img_weight = weight_dict[f'blocks.{self.block_index}.cross_attn.norm_k_img.weight']
# self.cross_attn_norm_k_img_weight = weight_dict[f'blocks.{self.block_index}.cross_attn.norm_k_img.weight']
self.cross_attn_norm_k_img = RMS_WEIGHT_REGISTER['sgl-kernel'](f'blocks.{self.block_index}.cross_attn.norm_k_img.weight')
self.weight_list.append(self.cross_attn_k_img)
self.weight_list.append(self.cross_attn_v_img)
self.weight_list.append(self.cross_attn_norm_k_img_weight)
self.weight_list.append(self.cross_attn_norm_k_img)
for mm_weight in self.weight_list:
if isinstance(mm_weight, MMWeightTemplate) or isinstance(mm_weight, LNWeightTemplate):
if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate, RMSWeightTemplate)):
mm_weight.set_config(self.config['mm_config'])
mm_weight.load(weight_dict)
def to_cpu(self):
for mm_weight in self.weight_list:
if isinstance(mm_weight, MMWeightTemplate) or isinstance(mm_weight, LNWeightTemplate):
if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate, RMSWeightTemplate)):
mm_weight.to_cpu()
else:
mm_weight.cpu()
def to_cuda(self):
for mm_weight in self.weight_list:
if isinstance(mm_weight, MMWeightTemplate) or isinstance(mm_weight, LNWeightTemplate):
if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate, RMSWeightTemplate)):
mm_weight.to_cuda()
else:
mm_weight.cuda()
\ No newline at end of file
@@ -16,41 +16,58 @@ class WanSchedulerFeatureCaching(WanScheduler):
self.previous_residual_odd = None
self.use_ret_steps = self.args.use_ret_steps
if self.use_ret_steps:
if self.args.target_width == 480 or self.args.target_height == 480:
self.coefficients = [
2.57151496e05,
-3.54229917e04,
1.40286849e03,
-1.35890334e01,
1.32517977e-01,
]
if self.args.target_width == 720 or self.args.target_height == 720:
self.coefficients = [
8.10705460e03,
2.13393892e03,
-3.72934672e02,
1.66203073e01,
-4.17769401e-02,
]
self.ret_steps = 5 * 2
self.cutoff_steps = self.args.infer_steps * 2
else:
if self.args.target_width == 480 or self.args.target_height == 480:
self.coefficients = [
-3.02331670e02,
2.23948934e02,
-5.25463970e01,
5.87348440e00,
-2.01973289e-01,
]
if self.args.target_width == 720 or self.args.target_height == 720:
self.coefficients = [
-114.36346466,
65.26524496,
-18.82220707,
4.91518089,
-0.23412683,
]
self.ret_steps = 1 * 2
self.cutoff_steps = self.args.infer_steps * 2 - 2
\ No newline at end of file
if self.args.task == 'i2v':
if self.use_ret_steps:
if self.args.target_width == 480 or self.args.target_height == 480:
self.coefficients = [
2.57151496e05,
-3.54229917e04,
1.40286849e03,
-1.35890334e01,
1.32517977e-01,
]
if self.args.target_width == 720 or self.args.target_height == 720:
self.coefficients = [
8.10705460e03,
2.13393892e03,
-3.72934672e02,
1.66203073e01,
-4.17769401e-02,
]
self.ret_steps = 5 * 2
self.cutoff_steps = self.args.infer_steps * 2
else:
if self.args.target_width == 480 or self.args.target_height == 480:
self.coefficients = [
-3.02331670e02,
2.23948934e02,
-5.25463970e01,
5.87348440e00,
-2.01973289e-01,
]
if self.args.target_width == 720 or self.args.target_height == 720:
self.coefficients = [
-114.36346466,
65.26524496,
-18.82220707,
4.91518089,
-0.23412683,
]
self.ret_steps = 1 * 2
self.cutoff_steps = self.args.infer_steps * 2 - 2
elif self.args.task == 't2v':
if self.use_ret_steps:
if '1.3B' in self.args.model_path:
self.coefficients = [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02]
if '14B' in self.args.model_path:
self.coefficients = [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01]
self.ret_steps = 5 * 2
self.cutoff_steps = self.args.infer_steps * 2
else:
if '1.3B' in self.args.model_path:
self.coefficients = [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01]
if '14B' in self.args.model_path:
self.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404]
self.ret_steps = 1 * 2
self.cutoff_steps = self.args.infer_steps * 2 - 2
\ No newline at end of file
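The coefficient tables above follow TeaCache's rescaled-residual heuristic: a
degree-4 polynomial, fitted per model size and resolution, maps the relative
change of the timestep-modulated input to a caching score. A sketch of how such
coefficients are typically evaluated (assuming the standard TeaCache
formulation; the actual accumulation logic lives in the caching scheduler, not
shown in this hunk):

import numpy as np

# 14B t2v coefficients from the table above; np.poly1d takes highest order first.
rescale = np.poly1d([-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404])

def should_recompute(rel_l1_change, accumulated, thresh=0.2):
    # Accumulate the rescaled change; reuse the cached residual while the
    # accumulated score stays under the threshold (cf. --teacache_thresh below).
    accumulated += abs(rescale(rel_l1_change))
    if accumulated < thresh:
        return False, accumulated  # skip the transformer, apply cached residual
    return True, 0.0               # recompute and reset the accumulator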
@@ -4,7 +4,7 @@ model_path=/workspace/ckpts_link # H800-14
export CUDA_VISIBLE_DEVICES=0,1,2,3
torchrun --nproc_per_node=4 main.py \
torchrun --nproc_per_node=4 ../main.py \
--model_cls hunyuan \
--model_path $model_path \
--prompt "A cat walks on the grass, realistic style." \
......
#!/bin/bash
# model_path=/mnt/nvme1/yongyang/models/hy/ckpts # H800-13
# model_path=/workspace/wan/Wan2.1-T2V-1.3B # H800-14
# config_path=/workspace/wan/Wan2.1-T2V-1.3B/config.json
model_path=/mnt/nvme0/yongyang/projects/wan/Wan2.1-T2V-1.3B # H800-14
config_path=/mnt/nvme0/yongyang/projects/wan/Wan2.1-T2V-1.3B/config.json
model_path=/workspace/wan/Wan2.1-T2V-1.3B # H800-14
config_path=/workspace/wan/Wan2.1-T2V-1.3B/config.json
export CUDA_VISIBLE_DEVICES=0
python main.py \
python ../main.py \
--model_cls wan2.1 \
--task t2v \
--model_path $model_path \
@@ -20,7 +17,10 @@ python main.py \
--seed 42 \
--sample_neg_promp 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
--config_path $config_path \
--save_video_path ./output_lightx2v_seed42.mp4 \
--save_video_path ./output_lightx2v_seed42_q8f1_teacache.mp4 \
--sample_guide_scale 6 \
--sample_shift 8
# --mm_config '{"mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm", "weight_auto_quant": true}'
\ No newline at end of file
--sample_shift 8 \
# --mm_config '{"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F", "weight_auto_quant": true}' \
# --feature_caching Tea \
# --use_ret_steps \
# --teacache_thresh 0.2
\ No newline at end of file