Commit 165ec807 authored by gushiqiao, committed by GitHub

Support loading advanced PTQ models. (#33)



* Support loading advanced PTQ models.

* Update run_wan_i2v_advanced_ptq.sh

---------
Co-authored-by: gushiqiao <gushiqiao@sensetime.com>
Co-authored-by: Yang Yong(雍洋) <yongyang1030@163.com>
parent f2a3c894
configs/advanced_ptq/wan_i2v.json (new file; path as referenced by the run script below):

{
    "infer_steps": 50,
    "target_video_length": 81,
    "text_len": 512,
    "target_height": 480,
    "target_width": 832,
    "attention_type": "flash_attn3",
    "seed": 42,
    "sample_guide_scale": 5,
    "sample_shift": 5,
    "enable_cfg": true,
    "cpu_offload": false,
    "mm_config": {
        "mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm",
        "quant_method": "smoothquant"
    },
    "naive_quant_path": "/path/to/int8_model"
}
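A note on the two quantization keys: mm_type spells out the scheme (int8 weights, per-channel, symmetric; int8 per-channel symmetric dynamic activations; Vllm kernels), and quant_method names the PTQ algorithm whose extra tensors the blocks must load. Below is a minimal sketch of how such a config selects the pre-quantized checkpoint, mirroring the WanModel branch in this commit (the config path is assumed from the run script; the RUNNING_FLAG == "save_naive_quant" case is ignored for brevity):

import json

with open("configs/advanced_ptq/wan_i2v.json") as f:
    config = json.load(f)

mm_config = config.get("mm_config", {})
# A non-"Default" mm_type without weight_auto_quant means weights come from a
# pre-quantized checkpoint instead of being quantized at load time.
if mm_config.get("mm_type", "Default") != "Default" and not mm_config.get("weight_auto_quant", False):
    assert config.get("naive_quant_path") is not None, "naive_quant_path is None"
    print(f"loading int8 weights from {config['naive_quant_path']}")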
@@ -107,7 +107,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
     def load_quantized(self, weight_dict):
         self.weight = weight_dict[self.weight_name].cuda()
-        self.weight_scale = weight_dict[self.weight_name.rstrip(".weight") + ".weight_scale"].cuda()
+        self.weight_scale = weight_dict[self.weight_name.removesuffix(".weight") + ".weight_scale"].float().cuda()

     def load_fp8_perchannel_sym(self, weight_dict):
         if GET_RUNNING_FLAG() == "save_naive_quant" or self.config.get("weight_auto_quant", False):
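The change from rstrip to removesuffix above is a real bug fix, not a cosmetic rename: str.rstrip treats its argument as a set of characters to strip, not as a suffix, so it can eat into the parameter name itself. A quick illustration with a made-up layer name (str.removesuffix requires Python 3.9+):

name = "blocks.0.cross_attn.gate.weight"  # hypothetical parameter name
# rstrip removes any trailing run of the characters {. w e i g h t}, so after
# consuming ".weight" it also strips the trailing "te" of "gate":
print(name.rstrip(".weight"))        # blocks.0.cross_attn.ga
print(name.removesuffix(".weight"))  # blocks.0.cross_attn.gate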
@@ -192,7 +192,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
         if self.bias is not None:
             destination[self.bias_name] = self.bias.cpu().detach().clone()
         if hasattr(self, "weight_scale"):
-            destination[self.weight_name.rstrip(".weight") + ".weight_scale"] = self.weight_scale.cpu().detach().clone()
+            destination[self.weight_name.removesuffix(".weight") + ".weight_scale"] = self.weight_scale.cpu().detach().clone()
         return destination
@@ -86,8 +86,15 @@ class WanTransformerInfer:
         elif embed0.dim() == 2:
             embed0 = (weights.modulation.tensor + embed0).chunk(6, dim=1)

+        if hasattr(weights, "smooth_norm1_weight"):
+            norm1_weight = (1 + embed0[1]) * weights.smooth_norm1_weight.tensor
+            norm1_bias = embed0[0] * weights.smooth_norm1_bias.tensor
+        else:
+            norm1_weight = 1 + embed0[1]
+            norm1_bias = embed0[0]
+
         norm1_out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6)
-        norm1_out = (norm1_out * (1 + embed0[1]) + embed0[0]).squeeze(0)
+        norm1_out = (norm1_out * norm1_weight + norm1_bias).squeeze(0)

         s, n, d = *norm1_out.shape[:1], self.num_heads, self.head_dim

         q = weights.self_attn_norm_q.apply(weights.self_attn_q.apply(norm1_out)).view(s, n, d)
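The hasattr branch folds the smoothquant/awq per-channel smoothing scale into the modulated LayerNorm affine instead of applying it as a separate elementwise multiply before the quantized linear. The folding is exact by distributivity; a small sanity check, assuming affine_norm1.weight and affine_norm1.bias carry one per-channel scale s (the same trick is applied to norm2 before the FFN further down):

import torch

torch.manual_seed(0)
x = torch.randn(4, 64)
e0, e1 = torch.randn(64), torch.randn(64)  # shift/scale chunks from embed0
s = torch.rand(64) + 0.5                   # hypothetical per-channel smoothing scale

ln = torch.nn.functional.layer_norm(x, (64,), None, None, 1e-6)
unfused = (ln * (1 + e1) + e0) * s    # smooth the activation after modulation
fused = ln * ((1 + e1) * s) + e0 * s  # scale folded into weight/bias, as in the diff
assert torch.allclose(unfused, fused, atol=1e-5)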
@@ -105,7 +112,16 @@ class WanTransformerInfer:
         cu_seqlens_q, cu_seqlens_k, lq, lk = self._calculate_q_k_len(q, k, k_lens=seq_lens)

         if not self.parallel_attention:
-            attn_out = weights.self_attn_1.apply(q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_k, max_seqlen_q=lq, max_seqlen_kv=lk, model_cls=self.config["model_cls"])
+            attn_out = weights.self_attn_1.apply(
+                q=q,
+                k=k,
+                v=v,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_kv=cu_seqlens_k,
+                max_seqlen_q=lq,
+                max_seqlen_kv=lk,
+                model_cls=self.config["model_cls"],
+            )
         else:
             attn_out = self.parallel_attention(
                 attention_type=self.attention_type,
@@ -134,7 +150,16 @@ class WanTransformerInfer:
         cu_seqlens_q, cu_seqlens_k, lq, lk = self._calculate_q_k_len(q, k, k_lens=torch.tensor([k.size(0)], dtype=torch.int32, device=k.device))

-        attn_out = weights.cross_attn_1.apply(q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_k, max_seqlen_q=lq, max_seqlen_kv=lk, model_cls=self.config["model_cls"])
+        attn_out = weights.cross_attn_1.apply(
+            q=q,
+            k=k,
+            v=v,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_kv=cu_seqlens_k,
+            max_seqlen_q=lq,
+            max_seqlen_kv=lk,
+            model_cls=self.config["model_cls"],
+        )

         if self.task == "i2v":
             k_img = weights.cross_attn_norm_k_img.apply(weights.cross_attn_k_img.apply(context_img)).view(-1, n, d)
@@ -147,7 +172,14 @@ class WanTransformerInfer:
             )

             img_attn_out = weights.cross_attn_2.apply(
-                attention_type=self.attention_type, q=q, k=k_img, v=v_img, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_k, max_seqlen_q=lq, max_seqlen_kv=lk, model_cls=self.config["model_cls"]
+                q=q,
+                k=k_img,
+                v=v_img,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_kv=cu_seqlens_k,
+                max_seqlen_q=lq,
+                max_seqlen_kv=lk,
+                model_cls=self.config["model_cls"],
             )

             attn_out = attn_out + img_attn_out
@@ -155,8 +187,17 @@ class WanTransformerInfer:
         attn_out = weights.cross_attn_o.apply(attn_out)
         x = x + attn_out

+        if hasattr(weights, "smooth_norm2_weight"):
+            norm2_weight = (1 + embed0[4].squeeze(0)) * weights.smooth_norm2_weight.tensor
+            norm2_bias = embed0[3].squeeze(0) * weights.smooth_norm2_bias.tensor
+        else:
+            norm2_weight = 1 + embed0[4].squeeze(0)
+            norm2_bias = embed0[3].squeeze(0)
+
         norm2_out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6)
-        y = weights.ffn_0.apply(norm2_out * (1 + embed0[4].squeeze(0)) + embed0[3].squeeze(0))
+        y = weights.ffn_0.apply(norm2_out * norm2_weight + norm2_bias)
         y = torch.nn.functional.gelu(y, approximate="tanh")
         y = weights.ffn_2.apply(y)
         x = x + y * embed0[5].squeeze(0)
@@ -2,6 +2,7 @@ import os
 import sys
 import torch
 import glob
+import json
 from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
 from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
 from lightx2v.models.networks.wan.weights.transformer_weights import (
@@ -83,7 +84,7 @@ class WanModel:
         weight_dict.update(file_weights)
         return weight_dict

-    def _load_ckpt_quant_model(self):
+    def _load_quant_ckpt(self):
         assert self.config.get("naive_quant_path") is not None, "naive_quant_path is None"
         ckpt_path = self.config.naive_quant_path
         logger.info(f"Loading quant model from {ckpt_path}")
@@ -107,9 +108,12 @@ class WanModel:
         weight_dict = {}
         for filename in set(index_data["weight_map"].values()):
             safetensor_path = os.path.join(ckpt_path, filename)
-            logger.info(f"Loading weights from {safetensor_path}")
-            partial_weights = load_file(safetensor_path, device=self.device)
-            weight_dict.update(partial_weights)
+            with safe_open(safetensor_path, framework="pt", device=str(self.device)) as f:
+                logger.info(f"Loading weights from {safetensor_path}")
+                for k in f.keys():
+                    weight_dict[k] = f.get_tensor(k)
+                    if weight_dict[k].dtype == torch.float:
+                        weight_dict[k] = weight_dict[k].to(torch.bfloat16)

         return weight_dict
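Switching from load_file to safe_open reads the shard tensor-by-tensor, which makes the per-tensor dtype handling above cheap: float32 tensors (typically the per-channel scales) are held as bfloat16 in the in-memory dict, and load_quantized converts weight_scale back to float32 where precision matters. A standalone sketch of the same pattern (function name and path are placeholders):

import torch
from safetensors import safe_open

def load_shard(path, device="cuda"):
    tensors = {}
    with safe_open(path, framework="pt", device=device) as f:
        for k in f.keys():
            t = f.get_tensor(k)
            # Cast float32 tensors (e.g. scales) to bfloat16 on load.
            tensors[k] = t.to(torch.bfloat16) if t.dtype == torch.float else t
    return tensors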
@@ -118,7 +122,7 @@ class WanModel:
             if GET_RUNNING_FLAG() == "save_naive_quant" or self.config["mm_config"].get("weight_auto_quant", False) or self.config["mm_config"].get("mm_type", "Default") == "Default":
                 self.original_weight_dict = self._load_ckpt()
             else:
-                self.original_weight_dict = self._load_ckpt_quant_model()
+                self.original_weight_dict = self._load_quant_ckpt()
         else:
             self.original_weight_dict = weight_dict
         # init weights
@@ -24,6 +24,7 @@ class WanTransformerAttentionBlock(WeightModule):
         self.mm_type = mm_type
         self.task = task
         self.config = config
+        self.quant_method = config["mm_config"].get("quant_method", None)
         self.add_module("self_attn_q", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.self_attn.q.weight", f"blocks.{self.block_index}.self_attn.q.bias"))
         self.add_module("self_attn_k", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.self_attn.k.weight", f"blocks.{self.block_index}.self_attn.k.bias"))
@@ -68,4 +69,13 @@ class WanTransformerAttentionBlock(WeightModule):
             # do not load weights
             pass

+        # For smoothquant or awq
+        if self.quant_method in ["smoothquant", "awq"]:
+            self.register_parameter("smooth_norm1_weight", TENSOR_REGISTER["Default"](f"blocks.{self.block_index}.affine_norm1.weight"))
+            self.register_parameter("smooth_norm1_bias", TENSOR_REGISTER["Default"](f"blocks.{self.block_index}.affine_norm1.bias"))
+            self.register_parameter("smooth_norm2_weight", TENSOR_REGISTER["Default"](f"blocks.{self.block_index}.affine_norm3.weight"))
+            self.register_parameter("smooth_norm2_bias", TENSOR_REGISTER["Default"](f"blocks.{self.block_index}.affine_norm3.bias"))
+        elif self.quant_method is not None:
+            raise NotImplementedError(f"This {self.quant_method} method is not implemented yet.")
+
         self.register_parameter("modulation", TENSOR_REGISTER["Default"](f"blocks.{self.block_index}.modulation"))
run_wan_i2v_advanced_ptq.sh (new file):

#!/bin/bash
# set paths first
lightx2v_path=
model_path=
# check required variables
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=0
echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
if [ -z "${lightx2v_path}" ]; then
echo "Error: lightx2v_path is not set. Please set this variable first."
exit 1
fi
if [ -z "${model_path}" ]; then
echo "Error: model_path is not set. Please set this variable first."
exit 1
fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
# =========================
# load quantization weight and inference
# =========================
export RUNNING_FLAG=infer
python -m lightx2v.infer \
--model_cls wan2.1 \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/advanced_ptq/wan_i2v.json \
--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \
--negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v.mp4
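Note: this script only covers the inference pass (RUNNING_FLAG=infer) and assumes the int8 checkpoint at naive_quant_path already exists. Judging from the GET_RUNNING_FLAG() == "save_naive_quant" branches in the diff above, that checkpoint would be produced beforehand by running the same entry point with RUNNING_FLAG=save_naive_quant.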