Commit c8fce5b9 authored by gushiqiao

Fix bugs

parent 8bc0da34
File mode changed from 100644 to 100755
@@ -11,7 +11,7 @@
     "cpu_offload": true,
     "offload_granularity": "block",
     "mm_config": {
-        "mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F",
+        "mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm",
         "weight_auto_quant": true
     }
 }
@@ -12,7 +12,7 @@
     "cpu_offload": true,
     "offload_granularity": "block",
     "mm_config": {
-        "mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F",
+        "mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm",
         "weight_auto_quant": true
     }
 }
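Both configs switch only the mm_type suffix from Q8F to Vllm while keeping the same quantization recipe: int8 per-channel symmetric weights and int8 symmetric activations quantized dynamically at runtime; the trailing token appears to select the GEMM kernel backend. For orientation, a minimal, illustrative sketch of what dynamic symmetric int8 quantization does to a tensor; this is not lightx2v's or vLLM's kernel code:

import torch

# Illustrative only: dynamic per-channel symmetric int8 quantization, the scheme
# named by "A-int8-channel-sym-dynamic". Not lightx2v's or vLLM's implementation.
def int8_sym_quant(x: torch.Tensor, dim: int = -1):
    scale = x.abs().amax(dim=dim, keepdim=True).clamp(min=1e-8) / 127.0  # one scale per row
    q = torch.clamp(torch.round(x / scale), -127, 127).to(torch.int8)
    return q, scale

x = torch.randn(16, 64)
q, s = int8_sym_quant(x)   # scales computed on the fly ("dynamic")
x_hat = q.float() * s      # dequantized approximation of x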
@@ -32,5 +32,7 @@
         [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
     ],
     "use_ret_steps": true,
-    "teacache_thresh": 0.26
+    "teacache_thresh": 0.26,
+    "rotary_chunk": true,
+    "clean_cuda_cache": true
 }
@@ -27,12 +27,12 @@ class WanCausVidModel(WanModel):
         self.post_infer_class = WanPostInfer
         self.transformer_infer_class = WanTransformerInferCausVid
 
-    def _load_ckpt(self):
+    def _load_ckpt(self, use_bf16, skip_bf16):
         use_bfloat16 = GET_DTYPE() == "BF16"
         ckpt_path = os.path.join(self.model_path, "causal_model.pt")
         if not os.path.exists(ckpt_path):
            # The file does not exist; call the parent class's _load_ckpt method
-            return super()._load_ckpt()
+            return super()._load_ckpt(use_bf16, skip_bf16)
         weight_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
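Read without the diff markers, the updated loader threads the two new flags straight through to the parent implementation when the CausVid checkpoint is absent. A sketch of the method as it stands after this hunk; only the signature, the fallback call, and the lines shown above come from the diff, the rest of the body is omitted:

    def _load_ckpt(self, use_bf16, skip_bf16):
        use_bfloat16 = GET_DTYPE() == "BF16"
        ckpt_path = os.path.join(self.model_path, "causal_model.pt")
        if not os.path.exists(ckpt_path):
            # No CausVid checkpoint: fall back to WanModel._load_ckpt with the same flags.
            return super()._load_ckpt(use_bf16, skip_bf16)
        weight_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
        # ... remainder of the method is outside this hunk ...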
@@ -91,7 +91,7 @@ class WanTransformerInferCausVid(WanTransformerInfer):
         )
         return x
 
-    def _infer_self_attn(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, block_idx, kv_start, kv_end):
+    def infer_self_attn(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, block_idx, kv_start, kv_end):
         norm1_out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6)
         norm1_out = (norm1_out * (1 + embed0[1]) + embed0[0]).squeeze(0)
@@ -135,7 +135,7 @@ class WanTransformerInferCausVid(WanTransformerInfer):
         return x
 
-    def _infer_cross_attn(self, weights, x, context, block_idx):
+    def infer_cross_attn(self, weights, x, context, block_idx):
         norm3_out = weights.norm3.apply(x)
         if self.task == "i2v":
@@ -195,7 +195,7 @@ class WanTransformerInferCausVid(WanTransformerInfer):
         return x
 
-    def _infer_ffn(self, weights, x, embed0):
+    def infer_ffn(self, weights, x, embed0):
         norm2_out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6)
         y = weights.ffn_0.apply(norm2_out * (1 + embed0[4].squeeze(0)) + embed0[3].squeeze(0))
         y = torch.nn.functional.gelu(y, approximate="tanh")
@@ -206,17 +206,15 @@ class WanTransformerInferCausVid(WanTransformerInfer):
     def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, block_idx, kv_start, kv_end):
         if embed0.dim() == 3:
-            modulation = weights.modulation.tensor.unsqueeze(2)  # 1, 6, 1, dim
+            modulation = weights.compute_phases[0].modulation.tensor.unsqueeze(2)  # 1, 6, 1, dim
             embed0 = embed0.unsqueeze(0)  #
             embed0 = (modulation + embed0).chunk(6, dim=1)
             embed0 = [ei.squeeze(1) for ei in embed0]
         elif embed0.dim() == 2:
-            embed0 = (weights.modulation.tensor + embed0).chunk(6, dim=1)
+            embed0 = (weights.compute_phases[0].modulation.tensor + embed0).chunk(6, dim=1)
 
-        x = self._infer_self_attn(weights.compute_phases[0], grid_sizes, embed, x, embed0, seq_lens, freqs, context, block_idx, kv_start, kv_end)
-        x = self._infer_cross_attn(weights.compute_phases[1], x, context, block_idx)
-        x = self._infer_ffn(weights.compute_phases[2], x, embed0)
+        x = self.infer_self_attn(weights.compute_phases[1], grid_sizes, embed, x, embed0, seq_lens, freqs, context, block_idx, kv_start, kv_end)
+        x = self.infer_cross_attn(weights.compute_phases[2], x, context, block_idx)
+        x = self.infer_ffn(weights.compute_phases[3], x, embed0)
         return x
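After this hunk the CausVid block addresses its weights with the same phase layout as the base class: compute_phases[0] carries the modulation table and indices 1 to 3 carry the self-attention, cross-attention, and feed-forward phases, while the leading underscores are dropped, presumably to align the method names with the base-class naming. A small orientation sketch; the constants below are illustrative, not identifiers from the codebase:

# Phase indices implied by this hunk; the names are illustrative only.
PHASE_MODULATION, PHASE_SELF_ATTN, PHASE_CROSS_ATTN, PHASE_FFN = 0, 1, 2, 3

# e.g. the modulation table is now read as
#   weights.compute_phases[PHASE_MODULATION].modulation.tensor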
@@ -229,7 +229,7 @@ class WanTransformerInferTaylorCaching(WanTransformerInfer, BaseTaylorCachingTra
         else:
             self.derivative_approximation(self.blocks_cache_odd[block_idx], "self_attn_out", y_out)
 
-        attn_out = self.infer_cross_attn(weights.blocks[block_idx].compute_phases[2], x, context, y_out, gate_msa)
+        x, attn_out = self.infer_cross_attn(weights.blocks[block_idx].compute_phases[2], x, context, y_out, gate_msa)
         if self.infer_conditional:
             self.derivative_approximation(self.blocks_cache_even[block_idx], "cross_attn_out", attn_out)
         else:
@@ -369,7 +369,7 @@ class WanTransformerInferAdaCaching(WanTransformerInfer):
         else:
             self.args_odd.now_residual_tiny = y_out * gate_msa.squeeze(0)
 
-        attn_out = self.infer_cross_attn(weights.blocks[block_idx].compute_phases[2], x, context, y_out, gate_msa)
+        x, attn_out = self.infer_cross_attn(weights.blocks[block_idx].compute_phases[2], x, context, y_out, gate_msa)
 
         y_out = self.infer_ffn(weights.blocks[block_idx].compute_phases[3], x, attn_out, c_shift_msa, c_scale_msa)
         x = self.post_process(x, y_out, c_gate_msa)
@@ -637,7 +637,7 @@ class WanTransformerInferCustomCaching(WanTransformerInfer, BaseTaylorCachingTra
         shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_modulation(weights.blocks[block_idx].compute_phases[0], embed0)
         y_out = self.infer_self_attn(weights.blocks[block_idx].compute_phases[1], grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa)
 
-        attn_out = self.infer_cross_attn(weights.blocks[block_idx].compute_phases[2], x, context, y_out, gate_msa)
+        x, attn_out = self.infer_cross_attn(weights.blocks[block_idx].compute_phases[2], x, context, y_out, gate_msa)
 
         y_out = self.infer_ffn(weights.blocks[block_idx].compute_phases[3], x, attn_out, c_shift_msa, c_scale_msa)
         x = self.post_process(x, y_out, c_gate_msa)
 import math
 import torch
 import torch.cuda.amp as amp
 from lightx2v.utils.envs import *
@@ -116,7 +116,7 @@ class WanTransformerInfer(BaseTransformerInfer):
                     scale_msa,
                 )
             elif cur_phase_idx == 2:
-                attn_out = self.infer_cross_attn(cur_phase, x, context, y_out, gate_msa)
+                x, attn_out = self.infer_cross_attn(cur_phase, x, context, y_out, gate_msa)
             elif cur_phase_idx == 3:
                 y = self.infer_ffn(cur_phase, x, attn_out, c_shift_msa, c_scale_msa)
                 x = self.post_process(x, y, c_gate_msa)
@@ -169,7 +169,7 @@ class WanTransformerInfer(BaseTransformerInfer):
                     scale_msa,
                 )
             elif cur_phase_idx == 2:
-                attn_out = self.infer_cross_attn(cur_phase, x, context, y_out, gate_msa)
+                x, attn_out = self.infer_cross_attn(cur_phase, x, context, y_out, gate_msa)
             elif cur_phase_idx == 3:
                 y = self.infer_ffn(cur_phase, x, attn_out, c_shift_msa, c_scale_msa)
                 x = self.post_process(x, y, c_gate_msa)
@@ -222,7 +222,7 @@ class WanTransformerInfer(BaseTransformerInfer):
             shift_msa,
             scale_msa,
         )
-        attn_out = self.infer_cross_attn(weights.compute_phases[2], x, context, y_out, gate_msa)
+        x, attn_out = self.infer_cross_attn(weights.compute_phases[2], x, context, y_out, gate_msa)
         y = self.infer_ffn(weights.compute_phases[3], x, attn_out, c_shift_msa, c_scale_msa)
         x = self.post_process(x, y, c_gate_msa)
         return x
@@ -371,7 +371,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         if self.clean_cuda_cache:
             del q, k, v, norm3_out, context, context_img
             torch.cuda.empty_cache()
-        return attn_out
+        return x, attn_out
 
     def infer_ffn(self, weights, x, attn_out, c_shift_msa, c_scale_msa):
         x.add_(attn_out)
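The base-class infer_cross_attn now returns the hidden states x together with the cross-attention output, so the call sites in the base class and in the caching variants (TaylorCaching, AdaCaching, CustomCaching, and the phase-offload paths) switch to tuple unpacking; the CausVid subclass overrides these methods with its own signatures and is unaffected. The resulting call pattern, copied out of the hunks for clarity:

# Post-change call pattern (identifiers as they appear in the hunks above).
x, attn_out = self.infer_cross_attn(weights.compute_phases[2], x, context, y_out, gate_msa)
y = self.infer_ffn(weights.compute_phases[3], x, attn_out, c_shift_msa, c_scale_msa)
x = self.post_process(x, y, c_gate_msa)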
@@ -387,22 +387,23 @@
         norm2_weight = 1 + c_scale_msa.squeeze(0)
         norm2_bias = c_shift_msa.squeeze(0)
 
-        x = weights.norm2.apply(x)
+        norm2_out = weights.norm2.apply(x)
         if GET_DTYPE() != "BF16":
-            x = x.float()
-        x.mul_(norm2_weight).add_(norm2_bias)
+            norm2_out = norm2_out.float()
+        norm2_out.mul_(norm2_weight).add_(norm2_bias)
         if GET_DTYPE() != "BF16":
-            x = x.to(torch.bfloat16)
+            norm2_out = norm2_out.to(torch.bfloat16)
 
-        x = weights.ffn_0.apply(x)
+        y = weights.ffn_0.apply(norm2_out)
         if self.clean_cuda_cache:
+            del norm2_out, x, norm2_weight, norm2_bias
             torch.cuda.empty_cache()
-        x = torch.nn.functional.gelu(x, approximate="tanh")
+        y = torch.nn.functional.gelu(y, approximate="tanh")
         if self.clean_cuda_cache:
             torch.cuda.empty_cache()
-        x = weights.ffn_2.apply(x)
+        y = weights.ffn_2.apply(y)
-        return x
+        return y
 
     def post_process(self, x, y, c_gate_msa):
         if GET_DTYPE() != "BF16":
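The net effect of this last hunk: infer_ffn no longer reuses the residual-stream variable x while computing the feed-forward branch; the normalized tensor and the FFN activations live in separate locals (norm2_out, y) and the method returns y, matching the callers above that pass x and y separately into post_process. A straight-line sketch of the rewritten body, reconstructed from the '+' lines; anything not shown in the hunk is an assumption:

    def infer_ffn(self, weights, x, attn_out, c_shift_msa, c_scale_msa):
        x.add_(attn_out)                      # residual add happens in place on x
        norm2_weight = 1 + c_scale_msa.squeeze(0)
        norm2_bias = c_shift_msa.squeeze(0)

        norm2_out = weights.norm2.apply(x)    # x itself is no longer overwritten
        if GET_DTYPE() != "BF16":
            norm2_out = norm2_out.float()
        norm2_out.mul_(norm2_weight).add_(norm2_bias)
        if GET_DTYPE() != "BF16":
            norm2_out = norm2_out.to(torch.bfloat16)

        y = weights.ffn_0.apply(norm2_out)
        if self.clean_cuda_cache:
            # Drops only local references; the caller's x is unaffected.
            del norm2_out, x, norm2_weight, norm2_bias
            torch.cuda.empty_cache()
        y = torch.nn.functional.gelu(y, approximate="tanh")
        if self.clean_cuda_cache:
            torch.cuda.empty_cache()
        y = weights.ffn_2.apply(y)
        return y                              # caller combines x and y in post_process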