"llama/llama.cpp/include/llama.h" did not exist on "18ffeeec4518fd8345bb113c0099a3ff269da996"
Commit ad051778 authored by gushiqiao, committed by GitHub

Fix (#80)



* Fix

* Fix

* Fix

---------
Co-authored-by: gushiqiao <gushiqiao@sensetime.com>
parent fb69083e
......@@ -24,7 +24,7 @@ class WanDistillModel(WanModel):
ckpt_path = os.path.join(self.model_path, "distill_model.pt")
if not os.path.exists(ckpt_path):
# File does not exist; fall back to the parent class's _load_ckpt method
return super()._load_ckpt()
return super()._load_ckpt(use_bf16, skip_bf16)
weight_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
weight_dict = {key: (weight_dict[key].to(torch.bfloat16) if use_bf16 or all(s not in key for s in skip_bf16) else weight_dict[key]).pin_memory().to(self.device) for key in weight_dict.keys()}
......
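Note: the distill checkpoint loader above mirrors the base loader's selective bf16 cast before pinning and moving weights to the GPU. A minimal sketch of that pattern, with illustrative names (cast_and_move is not part of the repo):

import torch

def cast_and_move(weight_dict, device, use_bf16=True, skip_bf16=()):
    out = {}
    for key, tensor in weight_dict.items():
        # Cast to bfloat16 unless the key matches one of the skip patterns
        # and bf16 is not forced globally.
        if use_bf16 or all(s not in key for s in skip_bf16):
            tensor = tensor.to(torch.bfloat16)
        # Pin host memory so the host-to-device copy can be asynchronous.
        out[key] = tensor.pin_memory().to(device)
    return out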
......@@ -8,6 +8,7 @@ class WanPostInfer:
def __init__(self, config):
self.out_dim = config["out_dim"]
self.patch_size = (1, 2, 2)
self.clean_cuda_cache = config.get("clean_cuda_cache", False)
def set_scheduler(self, scheduler):
self.scheduler = scheduler
......@@ -21,16 +22,21 @@ class WanPostInfer:
e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
e = [ei.squeeze(1) for ei in e]
norm_out = weights.norm.apply(x)
x = weights.norm.apply(x)
if GET_DTYPE() != "BF16":
norm_out = norm_out.float()
out = norm_out * (1 + e[1].squeeze(0)) + e[0].squeeze(0)
x = x.float()
x.mul_(1 + e[1].squeeze(0)).add_(e[0].squeeze(0))
if GET_DTYPE() != "BF16":
out = out.to(torch.bfloat16)
x = x.to(torch.bfloat16)
x = weights.head.apply(out)
x = weights.head.apply(x)
x = self.unpatchify(x, grid_sizes)
if self.clean_cuda_cache:
del e, grid_sizes
torch.cuda.empty_cache()
return [u.float() for u in x]
def unpatchify(self, x, grid_sizes):
......
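Note: the post-infer changes above replace the out-of-place `norm_out * (1 + scale) + shift` with in-place ops on `x`. A minimal sketch of the pattern, using plain callables in place of the project's weight wrappers:

import torch

def modulated_head(x, norm, head, shift, scale, bf16=True, clean_cuda_cache=False):
    x = norm(x)
    if not bf16:
        x = x.float()
    # In-place multiply/add avoids allocating a second activation tensor.
    x.mul_(1 + scale).add_(shift)
    if not bf16:
        x = x.to(torch.bfloat16)
    x = head(x)
    if clean_cuda_cache:
        torch.cuda.empty_cache()
    return x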
......@@ -7,7 +7,7 @@ class WanPreInfer:
def __init__(self, config):
assert (config["dim"] % config["num_heads"]) == 0 and (config["dim"] // config["num_heads"]) % 2 == 0
d = config["dim"] // config["num_heads"]
self.clean_cuda_cache = config.get("clean_cuda_cache", False)
self.task = config["task"]
self.freqs = torch.cat(
[
......@@ -87,6 +87,9 @@ class WanPreInfer:
out = weights.text_embedding_0.apply(stacked.squeeze(0))
out = torch.nn.functional.gelu(out, approximate="tanh")
context = weights.text_embedding_2.apply(out)
if self.clean_cuda_cache:
del out, stacked
torch.cuda.empty_cache()
if self.task == "i2v":
context_clip = weights.proj_0.apply(clip_fea)
......@@ -95,7 +98,9 @@ class WanPreInfer:
context_clip = weights.proj_3.apply(context_clip)
context_clip = weights.proj_4.apply(context_clip)
context = torch.concat([context_clip, context], dim=0)
if self.clean_cuda_cache:
del context_clip, clip_fea
torch.cuda.empty_cache()
return (
embed,
grid_sizes,
......
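Note: `clean_cuda_cache` is the recurring flag in this commit: intermediates are deleted and the allocator cache is flushed only when it is enabled, since `torch.cuda.empty_cache()` is itself costly. A one-line sketch of the guard:

import torch

def maybe_empty_cache(enabled: bool) -> None:
    # empty_cache() returns unused cached blocks to the driver; gate it behind
    # an explicit config flag instead of calling it unconditionally.
    if enabled:
        torch.cuda.empty_cache()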
import torch
from .utils import compute_freqs, compute_freqs_dist, apply_rotary_emb
from .utils import compute_freqs, compute_freqs_dist, apply_rotary_emb, apply_rotary_emb_chunk
from lightx2v.common.offload.manager import (
WeightAsyncStreamManager,
LazyWeightAsyncStreamManager,
......@@ -14,11 +14,13 @@ class WanTransformerInfer(BaseTransformerInfer):
self.task = config["task"]
self.attention_type = config.get("attention_type", "flash_attn2")
self.blocks_num = config["num_layers"]
self.phases_num = 3
self.phases_num = 4
self.num_heads = config["num_heads"]
self.head_dim = config["dim"] // config["num_heads"]
self.window_size = config.get("window_size", (-1, -1))
self.parallel_attention = None
self.apply_rotary_emb_func = apply_rotary_emb_chunk if config.get("rotary_chunk", False) else apply_rotary_emb
self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
if self.config["cpu_offload"]:
if "offload_ratio" in self.config:
offload_ratio = self.config["offload_ratio"]
......@@ -92,10 +94,6 @@ class WanTransformerInfer(BaseTransformerInfer):
def _infer_with_phases_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
for block_idx in range(weights.blocks_num):
weights.blocks[block_idx].modulation.to_cuda()
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_phase_1(weights.blocks[block_idx], grid_sizes, embed, x, embed0, seq_lens, freqs, context)
for phase_idx in range(self.phases_num):
if block_idx == 0 and phase_idx == 0:
phase = weights.blocks[block_idx].compute_phases[phase_idx]
......@@ -105,12 +103,23 @@ class WanTransformerInfer(BaseTransformerInfer):
with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
cur_phase_idx, cur_phase = self.weights_stream_mgr.active_weights[0]
if cur_phase_idx == 0:
y_out = self.infer_phase_2(cur_phase, grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa)
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_modulation(cur_phase, embed0)
elif cur_phase_idx == 1:
attn_out = self.infer_phase_3(cur_phase, x, context, y_out, gate_msa)
y_out = self.infer_self_attn(
cur_phase,
grid_sizes,
x,
seq_lens,
freqs,
shift_msa,
scale_msa,
)
elif cur_phase_idx == 2:
y = self.infer_phase_4(cur_phase, x, attn_out, c_shift_msa, c_scale_msa)
x = self.infer_phase_5(x, y, c_gate_msa)
attn_out = self.infer_cross_attn(cur_phase, x, context, y_out, gate_msa)
elif cur_phase_idx == 3:
y = self.infer_ffn(cur_phase, x, attn_out, c_shift_msa, c_scale_msa)
x = self.post_process(x, y, c_gate_msa)
is_last_phase = block_idx == weights.blocks_num - 1 and phase_idx == self.phases_num - 1
if not is_last_phase:
......@@ -120,8 +129,6 @@ class WanTransformerInfer(BaseTransformerInfer):
self.weights_stream_mgr.swap_phases()
weights.blocks[block_idx].modulation.to_cpu()
torch.cuda.empty_cache()
return x
......@@ -130,11 +137,6 @@ class WanTransformerInfer(BaseTransformerInfer):
self.weights_stream_mgr.prefetch_weights_from_disk(weights)
for block_idx in range(weights.blocks_num):
with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
weights.blocks[block_idx].modulation.to_cuda()
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_phase_1(weights.blocks[block_idx], grid_sizes, embed, x, embed0, seq_lens, freqs, context)
for phase_idx in range(self.weights_stream_mgr.phases_num):
if block_idx == 0 and phase_idx == 0:
obj_key = (block_idx, phase_idx)
......@@ -152,12 +154,25 @@ class WanTransformerInfer(BaseTransformerInfer):
) = self.weights_stream_mgr.active_weights[0]
if cur_phase_idx == 0:
y_out = self.infer_phase_2(cur_phase, grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa)
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_modulation(
cur_phase,
embed0,
)
elif cur_phase_idx == 1:
attn_out = self.infer_phase_3(cur_phase, x, context, y_out, gate_msa)
y_out = self.infer_self_attn(
cur_phase,
grid_sizes,
x,
seq_lens,
freqs,
shift_msa,
scale_msa,
)
elif cur_phase_idx == 2:
y = self.infer_phase_4(cur_phase, x, attn_out, c_shift_msa, c_scale_msa)
x = self.infer_phase_5(x, y, c_gate_msa)
attn_out = self.infer_cross_attn(cur_phase, x, context, y_out, gate_msa)
elif cur_phase_idx == 3:
y = self.infer_ffn(cur_phase, x, attn_out, c_shift_msa, c_scale_msa)
x = self.post_process(x, y, c_gate_msa)
if not (block_idx == weights.blocks_num - 1 and phase_idx == self.phases_num - 1):
next_block_idx = block_idx + 1 if phase_idx == self.phases_num - 1 else block_idx
......@@ -166,10 +181,16 @@ class WanTransformerInfer(BaseTransformerInfer):
self.weights_stream_mgr.swap_phases()
weights.blocks[block_idx].modulation.to_cpu()
self.weights_stream_mgr._async_prefetch_block(weights)
torch.cuda.empty_cache()
if self.clean_cuda_cache:
del attn_out, y_out, y
torch.cuda.empty_cache()
if self.clean_cuda_cache:
del shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa
del grid_sizes, embed, embed0, seq_lens, freqs, context
torch.cuda.empty_cache()
return x
......@@ -188,36 +209,51 @@ class WanTransformerInfer(BaseTransformerInfer):
return x
def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_phase_1(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
y_out = self.infer_phase_2(weights.compute_phases[0], grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa)
attn_out = self.infer_phase_3(weights.compute_phases[1], x, context, y_out, gate_msa)
y = self.infer_phase_4(weights.compute_phases[2], x, attn_out, c_shift_msa, c_scale_msa)
x = self.infer_phase_5(x, y, c_gate_msa)
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_modulation(
weights.compute_phases[0],
embed0,
)
y_out = self.infer_self_attn(
weights.compute_phases[1],
grid_sizes,
x,
seq_lens,
freqs,
shift_msa,
scale_msa,
)
attn_out = self.infer_cross_attn(weights.compute_phases[2], x, context, y_out, gate_msa)
y = self.infer_ffn(weights.compute_phases[3], x, attn_out, c_shift_msa, c_scale_msa)
x = self.post_process(x, y, c_gate_msa)
return x
def infer_phase_1(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
def infer_modulation(self, weights, embed0):
if embed0.dim() == 3:
modulation = weights.modulation.tensor.unsqueeze(2)
embed0 = (modulation + embed0).chunk(6, dim=1)
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = [ei.squeeze(1) for ei in embed0]
elif embed0.dim() == 2:
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (weights.modulation.tensor + embed0).chunk(6, dim=1)
if self.clean_cuda_cache:
del embed0
torch.cuda.empty_cache()
return shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa
def infer_phase_2(self, weights, grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa):
def infer_self_attn(self, weights, grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa):
if hasattr(weights, "smooth_norm1_weight"):
norm1_weight = (1 + scale_msa) * weights.smooth_norm1_weight.tensor
norm1_bias = shift_msa * weights.smooth_norm1_bias.tensor
norm1_weight = (1 + scale_msa.squeeze(0)) * weights.smooth_norm1_weight.tensor
norm1_bias = shift_msa.squeeze(0) * weights.smooth_norm1_bias.tensor
else:
norm1_weight = 1 + scale_msa
norm1_bias = shift_msa
norm1_weight = 1 + scale_msa.squeeze(0)
norm1_bias = shift_msa.squeeze(0)
norm1_out = weights.norm1.apply(x)
if GET_DTYPE() != "BF16":
norm1_out = norm1_out.float()
norm1_out = (norm1_out * norm1_weight + norm1_bias).squeeze(0)
norm1_out.mul_(norm1_weight).add_(norm1_bias)
if GET_DTYPE() != "BF16":
norm1_out = norm1_out.to(torch.bfloat16)
......@@ -233,8 +269,8 @@ class WanTransformerInfer(BaseTransformerInfer):
else:
freqs_i = compute_freqs_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs)
q = apply_rotary_emb(q, freqs_i)
k = apply_rotary_emb(k, freqs_i)
q = self.apply_rotary_emb_func(q, freqs_i)
k = self.apply_rotary_emb_func(k, freqs_i)
cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(q, k_lens=seq_lens)
......@@ -260,9 +296,14 @@ class WanTransformerInfer(BaseTransformerInfer):
)
y = weights.self_attn_o.apply(attn_out)
if self.clean_cuda_cache:
del q, k, v, attn_out, freqs_i, norm1_out, norm1_weight, norm1_bias
torch.cuda.empty_cache()
return y
def infer_phase_3(self, weights, x, context, y_out, gate_msa):
def infer_cross_attn(self, weights, x, context, y_out, gate_msa):
if GET_DTYPE() != "BF16":
x = x.float() + y_out.float() * gate_msa.squeeze(0)
else:
......@@ -319,13 +360,26 @@ class WanTransformerInfer(BaseTransformerInfer):
max_seqlen_kv=k_img.size(0),
model_cls=self.config["model_cls"],
)
attn_out = attn_out + img_attn_out
attn_out.add_(img_attn_out)
if self.clean_cuda_cache:
del k_img, v_img, img_attn_out
torch.cuda.empty_cache()
attn_out = weights.cross_attn_o.apply(attn_out)
if self.clean_cuda_cache:
del q, k, v, norm3_out, context, context_img
torch.cuda.empty_cache()
return attn_out
def infer_phase_4(self, weights, x, attn_out, c_shift_msa, c_scale_msa):
def infer_ffn(self, weights, x, attn_out, c_shift_msa, c_scale_msa):
x.add_(attn_out)
if self.clean_cuda_cache:
del attn_out
torch.cuda.empty_cache()
if hasattr(weights, "smooth_norm2_weight"):
norm2_weight = (1 + c_scale_msa.squeeze(0)) * weights.smooth_norm2_weight.tensor
norm2_bias = c_shift_msa.squeeze(0) * weights.smooth_norm2_bias.tensor
......@@ -333,21 +387,30 @@ class WanTransformerInfer(BaseTransformerInfer):
norm2_weight = 1 + c_scale_msa.squeeze(0)
norm2_bias = c_shift_msa.squeeze(0)
norm2_out = weights.norm2.apply(x)
x = weights.norm2.apply(x)
if GET_DTYPE() != "BF16":
norm2_out = norm2_out.float()
norm2_out = norm2_out * norm2_weight + norm2_bias
x = x.float()
x.mul_(norm2_weight).add_(norm2_bias)
if GET_DTYPE() != "BF16":
norm2_out = norm2_out.to(torch.bfloat16)
x = x.to(torch.bfloat16)
y = weights.ffn_0.apply(norm2_out)
y = torch.nn.functional.gelu(y, approximate="tanh")
y = weights.ffn_2.apply(y)
return y
x = weights.ffn_0.apply(x)
if self.clean_cuda_cache:
torch.cuda.empty_cache()
x = torch.nn.functional.gelu(x, approximate="tanh")
if self.clean_cuda_cache:
torch.cuda.empty_cache()
x = weights.ffn_2.apply(x)
def infer_phase_5(self, x, y, c_gate_msa):
return x
def post_process(self, x, y, c_gate_msa):
if GET_DTYPE() != "BF16":
x = x.float() + y.float() * c_gate_msa.squeeze(0)
else:
x.add_(y * c_gate_msa.squeeze(0))
if self.clean_cuda_cache:
del y, c_gate_msa
torch.cuda.empty_cache()
return x
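Note: with `phases_num` raised to 4, each block is split into modulation, self-attention, cross-attention, and FFN phases that the stream manager prefetches and offloads independently. A toy sketch of that double-buffered idea (ToyPhase and stream_phases are illustrative, not the project's `WeightAsyncStreamManager` API):

import torch

class ToyPhase:
    def __init__(self, dim):
        self.weight = torch.randn(dim, dim)              # resident on CPU between uses

    def to_cuda(self):
        self.weight = self.weight.to("cuda", non_blocking=True)

    def to_cpu(self):
        self.weight = self.weight.to("cpu", non_blocking=True)

    def run(self, x):
        return x @ self.weight

def stream_phases(phases, x):
    copy_stream = torch.cuda.Stream()
    compute_stream = torch.cuda.Stream()
    with torch.cuda.stream(copy_stream):
        phases[0].to_cuda()                              # prefetch the first phase
    for i, phase in enumerate(phases):
        compute_stream.wait_stream(copy_stream)          # weights for phase i are ready
        with torch.cuda.stream(compute_stream):
            x = phase.run(x)
        if i + 1 < len(phases):
            with torch.cuda.stream(copy_stream):         # overlaps with phase i compute
                phases[i + 1].to_cuda()
        compute_stream.synchronize()                     # don't offload weights still in use
        phase.to_cpu()
    return x

# e.g. stream_phases([ToyPhase(64) for _ in range(8)], torch.randn(16, 64, device="cuda"))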
......@@ -75,6 +75,39 @@ def apply_rotary_emb(x, freqs_i):
return x_i.to(torch.bfloat16)
def apply_rotary_emb_chunk(x, freqs_i, chunk_size=100, remaining_chunk_size=100):
n = x.size(1)
seq_len = freqs_i.size(0)
output_chunks = []
for start in range(0, seq_len, chunk_size):
end = min(start + chunk_size, seq_len)
x_chunk = x[start:end]
freqs_chunk = freqs_i[start:end]
x_chunk_complex = torch.view_as_complex(x_chunk.to(torch.float32).reshape(end - start, n, -1, 2))
x_chunk_embedded = torch.view_as_real(x_chunk_complex * freqs_chunk).flatten(2).to(torch.bfloat16)
output_chunks.append(x_chunk_embedded)
del x_chunk_complex, x_chunk_embedded
torch.cuda.empty_cache()
result = []
for chunk in output_chunks:
result.append(chunk)
del output_chunks
torch.cuda.empty_cache()
for start in range(seq_len, x.size(0), remaining_chunk_size):
end = min(start + remaining_chunk_size, x.size(0))
result.append(x[start:end])
x_i = torch.cat(result, dim=0)
del result
torch.cuda.empty_cache()
return x_i.to(torch.bfloat16)
def rope_params(max_seq_len, dim, theta=10000):
assert dim % 2 == 0
freqs = torch.outer(
......
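Note: `apply_rotary_emb_chunk` trades extra kernel launches and `empty_cache()` calls for smaller float32/complex temporaries, which matters for long sequences under tight VRAM. A hedged usage sketch with the (seq_len, num_heads, head_dim) layout implied by the reshape above; the freqs construction here is illustrative, not `compute_freqs`:

import torch

# assumes apply_rotary_emb_chunk from the hunk above is in scope
seq_len, num_heads, head_dim = 300, 12, 128
x = torch.randn(seq_len, num_heads, head_dim, dtype=torch.bfloat16)
# One complex rotation factor per position and per head_dim // 2 pair, broadcast across heads.
freqs = torch.polar(
    torch.ones(seq_len, 1, head_dim // 2),
    torch.randn(seq_len, 1, head_dim // 2),
)
# Rotate 100 positions at a time instead of materialising the full float32/complex copy of x at once.
x_rot = apply_rotary_emb_chunk(x, freqs, chunk_size=100)
assert x_rot.shape == x.shape and x_rot.dtype == torch.bfloat16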
......@@ -34,6 +34,7 @@ class WanModel:
def __init__(self, model_path, config, device):
self.model_path = model_path
self.config = config
self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
self.dit_quantized = self.config.mm_config.get("mm_type", "Default") != "Default"
self.dit_quantized_ckpt = self.config.get("dit_quantized_ckpt", None)
......@@ -133,22 +134,7 @@ class WanModel:
else:
pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)
safetensors_pattern = os.path.join(lazy_load_model_path, "block_*.safetensors")
safetensors_files = glob.glob(safetensors_pattern)
if not safetensors_files:
raise FileNotFoundError(f"No .safetensors files found in directory: {lazy_load_model_path}")
for file_path in safetensors_files:
with safe_open(file_path, framework="pt") as f:
for k in f.keys():
if "modulation" in k:
if f.get_tensor(k).dtype == torch.float:
if use_bf16 or all(s not in k for s in skip_bf16):
transformer_weight_dict[k] = f.get_tensor(k).pin_memory().to(torch.bfloat16).to(self.device)
else:
transformer_weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)
return pre_post_weight_dict, transformer_weight_dict
return pre_post_weight_dict
def _init_weights(self, weight_dict=None):
use_bf16 = GET_DTYPE() == "BF16"
......@@ -161,10 +147,7 @@ class WanModel:
if not self.config.get("lazy_load", False):
self.original_weight_dict = self._load_quant_ckpt(use_bf16, skip_bf16)
else:
(
self.original_weight_dict,
self.transformer_weight_dict,
) = self._load_quant_split_ckpt(use_bf16, skip_bf16)
self.original_weight_dict = self._load_quant_split_ckpt(use_bf16, skip_bf16)
else:
self.original_weight_dict = weight_dict
# init weights
......@@ -174,10 +157,7 @@ class WanModel:
# load weights
self.pre_weight.load(self.original_weight_dict)
self.post_weight.load(self.original_weight_dict)
if hasattr(self, "transformer_weight_dict"):
self.transformer_weights.load(self.transformer_weight_dict)
else:
self.transformer_weights.load(self.original_weight_dict)
self.transformer_weights.load(self.original_weight_dict)
def _init_infer(self):
self.pre_infer = self.pre_infer_class(self.config)
......@@ -212,13 +192,21 @@ class WanModel:
self.scheduler.noise_pred = noise_pred_cond
if self.clean_cuda_cache:
del x, embed, pre_infer_out, noise_pred_cond, grid_sizes
torch.cuda.empty_cache()
if self.config["enable_cfg"]:
embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
self.scheduler.noise_pred = noise_pred_uncond + self.config.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
self.scheduler.noise_pred = noise_pred_uncond + self.config.sample_guide_scale * (self.scheduler.noise_pred - noise_pred_uncond)
if self.config["cpu_offload"]:
self.pre_weight.to_cpu()
self.post_weight.to_cpu()
if self.clean_cuda_cache:
del x, embed, pre_infer_out, noise_pred_uncond, grid_sizes
torch.cuda.empty_cache()
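Note: the guidance line above now reads the conditional prediction back from `self.scheduler.noise_pred`, where it was stored earlier, because the local `noise_pred_cond` may already have been freed when `clean_cuda_cache` is set. The combination itself is standard classifier-free guidance, restated as a minimal sketch:

import torch

def cfg_combine(noise_cond: torch.Tensor, noise_uncond: torch.Tensor, guide_scale: float) -> torch.Tensor:
    # guided = uncond + s * (cond - uncond); s = 1 recovers the conditional prediction.
    return noise_uncond + guide_scale * (noise_cond - noise_uncond)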
......@@ -34,13 +34,8 @@ class WanTransformerAttentionBlock(WeightModule):
self.config = config
self.quant_method = config["mm_config"].get("quant_method", None)
self.sparge = config.get("sparge", False)
self.register_parameter(
"modulation",
TENSOR_REGISTER["Default"](f"blocks.{self.block_index}.modulation"),
)
self.lazy_load = self.config.get("lazy_load", False)
if self.lazy_load:
lazy_load_path = os.path.join(self.config.dit_quantized_ckpt, f"block_{block_index}.safetensors")
self.lazy_load_file = safe_open(lazy_load_path, framework="pt", device="cpu")
......@@ -49,6 +44,14 @@ class WanTransformerAttentionBlock(WeightModule):
self.compute_phases = WeightModuleList(
[
WanModulation(
block_index,
task,
mm_type,
config,
self.lazy_load,
self.lazy_load_file,
),
WanSelfAttention(
block_index,
task,
......@@ -79,6 +82,29 @@ class WanTransformerAttentionBlock(WeightModule):
self.add_module("compute_phases", self.compute_phases)
class WanModulation(WeightModule):
def __init__(self, block_index, task, mm_type, config, lazy_load, lazy_load_file):
super().__init__()
self.block_index = block_index
self.mm_type = mm_type
self.task = task
self.config = config
self.quant_method = config["mm_config"].get("quant_method", None)
self.sparge = config.get("sparge", False)
self.lazy_load = lazy_load
self.lazy_load_file = lazy_load_file
self.add_module(
"modulation",
TENSOR_REGISTER["Default"](
f"blocks.{self.block_index}.modulation",
self.lazy_load,
self.lazy_load_file,
),
)
class WanSelfAttention(WeightModule):
def __init__(self, block_index, task, mm_type, config, lazy_load, lazy_load_file):
super().__init__()
......@@ -92,7 +118,7 @@ class WanSelfAttention(WeightModule):
self.lazy_load = lazy_load
self.lazy_load_file = lazy_load_file
self.register_parameter(
self.add_module(
"norm1",
LN_WEIGHT_REGISTER["Default"](),
)
......@@ -160,7 +186,7 @@ class WanSelfAttention(WeightModule):
else:
self.add_module("self_attn_1", ATTN_WEIGHT_REGISTER[self.config["attention_type"]]())
if self.quant_method in ["smoothquant", "awq"]:
self.register_parameter(
self.add_module(
"smooth_norm1_weight",
TENSOR_REGISTER["Default"](
f"blocks.{self.block_index}.affine_norm1.weight",
......@@ -168,7 +194,7 @@ class WanSelfAttention(WeightModule):
self.lazy_load_file,
),
)
self.register_parameter(
self.add_module(
"smooth_norm1_bias",
TENSOR_REGISTER["Default"](
f"blocks.{self.block_index}.affine_norm1.bias",
......@@ -292,7 +318,7 @@ class WanFFN(WeightModule):
self.lazy_load = lazy_load
self.lazy_load_file = lazy_load_file
self.register_parameter(
self.add_module(
"norm2",
LN_WEIGHT_REGISTER["Default"](),
)
......@@ -317,7 +343,7 @@ class WanFFN(WeightModule):
)
if self.quant_method in ["smoothquant", "awq"]:
self.register_parameter(
self.add_module(
"smooth_norm2_weight",
TENSOR_REGISTER["Default"](
f"blocks.{self.block_index}.affine_norm3.weight",
......@@ -325,7 +351,7 @@ class WanFFN(WeightModule):
self.lazy_load_file,
),
)
self.register_parameter(
self.add_module(
"smooth_norm2_bias",
TENSOR_REGISTER["Default"](
f"blocks.{self.block_index}.affine_norm3.bias",
......