"vscode:/vscode.git/clone" did not exist on "50ac47f97e4d663910c06aeaf744992c3bb44a57"
Commit 9d551b87 authored by gushiqiao, committed by GitHub

Merge pull request #82 from ModelTC/dev_FIX

Fix bugs
parents 8bc0da34 c8fce5b9
File mode changed from 100644 to 100755
@@ -11,7 +11,7 @@
     "cpu_offload": true,
     "offload_granularity": "block",
     "mm_config": {
-        "mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F",
+        "mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm",
         "weight_auto_quant": true
     }
 }
@@ -12,7 +12,7 @@
     "cpu_offload": true,
     "offload_granularity": "block",
     "mm_config": {
-        "mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F",
+        "mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm",
         "weight_auto_quant": true
     }
 }
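Both quantized example configs swap only the kernel-backend suffix of mm_type, "Q8F" to "Vllm"; the quantization scheme encoded by the rest of the string (per-channel symmetric int8 weights and activations, dynamic activation quantization) is unchanged. A minimal sketch of how such a string splits into scheme and backend; parse_mm_type is a hypothetical helper for illustration, not lightx2v API:

def parse_mm_type(mm_type: str):
    # "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm"
    #  weight scheme ------ activation scheme ------ backend
    parts = mm_type.split("-")
    return "-".join(parts[:-1]), parts[-1]

scheme, backend = parse_mm_type("W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm")
print(scheme)   # W-int8-channel-sym-A-int8-channel-sym-dynamic
print(backend)  # Vllm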
@@ -32,5 +32,7 @@
         [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
     ],
     "use_ret_steps": true,
-    "teacache_thresh": 0.26
+    "teacache_thresh": 0.26,
+    "rotary_chunk": true,
+    "clean_cuda_cache": true
 }
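Two new flags land in this config. clean_cuda_cache gates the del / torch.cuda.empty_cache() calls visible in the transformer_infer hunks further down; rotary_chunk is not exercised in this diff, and by its name presumably applies rotary embeddings in chunks to cap peak memory (an assumption, not confirmed here). A minimal sketch of how the cache flag is typically consumed, with a stand-in class rather than the real WanTransformerInfer:

import torch

class InferStub:
    def __init__(self, config: dict):
        self.clean_cuda_cache = config.get("clean_cuda_cache", False)

    def after_phase(self):
        if self.clean_cuda_cache:
            # Return cached blocks to the allocator between phases: lower
            # peak memory, extra allocation overhead later, hence opt-in.
            torch.cuda.empty_cache()

InferStub({"clean_cuda_cache": True}).after_phase()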
@@ -27,12 +27,12 @@ class WanCausVidModel(WanModel):
         self.post_infer_class = WanPostInfer
         self.transformer_infer_class = WanTransformerInferCausVid

-    def _load_ckpt(self):
+    def _load_ckpt(self, use_bf16, skip_bf16):
         use_bfloat16 = GET_DTYPE() == "BF16"
         ckpt_path = os.path.join(self.model_path, "causal_model.pt")
         if not os.path.exists(ckpt_path):
             # File does not exist; fall back to the parent class's _load_ckpt
-            return super()._load_ckpt()
+            return super()._load_ckpt(use_bf16, skip_bf16)
         weight_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
...
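The override's signature now mirrors the parent's: with the base _load_ckpt evidently taking (use_bf16, skip_bf16), the old zero-argument fallback call would raise a TypeError whenever causal_model.pt was absent. A sketch of the pattern with stand-in classes, not the real lightx2v ones:

class Base:
    def _load_ckpt(self, use_bf16, skip_bf16):
        print(f"base loader: use_bf16={use_bf16}, skip_bf16={skip_bf16}")

class CausVid(Base):
    def _load_ckpt(self, use_bf16, skip_bf16):
        ckpt_exists = False  # stands in for os.path.exists(ckpt_path)
        if not ckpt_exists:
            # Forward exactly what the caller passed.
            return super()._load_ckpt(use_bf16, skip_bf16)

CausVid()._load_ckpt(True, False)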
@@ -91,7 +91,7 @@ class WanTransformerInferCausVid(WanTransformerInfer):
         )
         return x

-    def _infer_self_attn(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, block_idx, kv_start, kv_end):
+    def infer_self_attn(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, block_idx, kv_start, kv_end):
         norm1_out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6)
         norm1_out = (norm1_out * (1 + embed0[1]) + embed0[0]).squeeze(0)
@@ -135,7 +135,7 @@ class WanTransformerInferCausVid(WanTransformerInfer):
         return x

-    def _infer_cross_attn(self, weights, x, context, block_idx):
+    def infer_cross_attn(self, weights, x, context, block_idx):
         norm3_out = weights.norm3.apply(x)
         if self.task == "i2v":
@@ -195,7 +195,7 @@ class WanTransformerInferCausVid(WanTransformerInfer):
         return x

-    def _infer_ffn(self, weights, x, embed0):
+    def infer_ffn(self, weights, x, embed0):
         norm2_out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6)
         y = weights.ffn_0.apply(norm2_out * (1 + embed0[4].squeeze(0)) + embed0[3].squeeze(0))
         y = torch.nn.functional.gelu(y, approximate="tanh")
...@@ -206,17 +206,15 @@ class WanTransformerInferCausVid(WanTransformerInfer): ...@@ -206,17 +206,15 @@ class WanTransformerInferCausVid(WanTransformerInfer):
def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, block_idx, kv_start, kv_end): def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, block_idx, kv_start, kv_end):
if embed0.dim() == 3: if embed0.dim() == 3:
modulation = weights.modulation.tensor.unsqueeze(2) # 1, 6, 1, dim modulation = weights.compute_phases[0].modulation.tensor.unsqueeze(2) # 1, 6, 1, dim
embed0 = embed0.unsqueeze(0) # embed0 = embed0.unsqueeze(0) #
embed0 = (modulation + embed0).chunk(6, dim=1) embed0 = (modulation + embed0).chunk(6, dim=1)
embed0 = [ei.squeeze(1) for ei in embed0] embed0 = [ei.squeeze(1) for ei in embed0]
elif embed0.dim() == 2: elif embed0.dim() == 2:
embed0 = (weights.modulation.tensor + embed0).chunk(6, dim=1) embed0 = (weights.compute_phases[0].modulation.tensor + embed0).chunk(6, dim=1)
x = self._infer_self_attn(weights.compute_phases[0], grid_sizes, embed, x, embed0, seq_lens, freqs, context, block_idx, kv_start, kv_end) x = self.infer_self_attn(weights.compute_phases[1], grid_sizes, embed, x, embed0, seq_lens, freqs, context, block_idx, kv_start, kv_end)
x = self.infer_cross_attn(weights.compute_phases[2], x, context, block_idx)
x = self._infer_cross_attn(weights.compute_phases[1], x, context, block_idx) x = self.infer_ffn(weights.compute_phases[3], x, embed0)
x = self._infer_ffn(weights.compute_phases[2], x, embed0)
return x return x
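This hunk tracks a re-layout of the per-block weights: modulation moved from an attribute on the block into compute_phases[0], shifting self-attn / cross-attn / FFN from indices 0/1/2 to 1/2/3 (the same 0..3 ordering the WanTransformerInfer dispatch below assumes). A stand-in sketch of the new layout, not lightx2v API:

class Phase:
    def __init__(self, name):
        self.name = name

class BlockWeights:
    def __init__(self):
        # New layout: [modulation, self_attn, cross_attn, ffn]
        self.compute_phases = [Phase("modulation"), Phase("self_attn"),
                               Phase("cross_attn"), Phase("ffn")]

w = BlockWeights()
assert w.compute_phases[0].name == "modulation"  # was w.modulation
assert w.compute_phases[1].name == "self_attn"   # was compute_phases[0]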
...@@ -229,7 +229,7 @@ class WanTransformerInferTaylorCaching(WanTransformerInfer, BaseTaylorCachingTra ...@@ -229,7 +229,7 @@ class WanTransformerInferTaylorCaching(WanTransformerInfer, BaseTaylorCachingTra
else: else:
self.derivative_approximation(self.blocks_cache_odd[block_idx], "self_attn_out", y_out) self.derivative_approximation(self.blocks_cache_odd[block_idx], "self_attn_out", y_out)
attn_out = self.infer_cross_attn(weights.blocks[block_idx].compute_phases[2], x, context, y_out, gate_msa) x, attn_out = self.infer_cross_attn(weights.blocks[block_idx].compute_phases[2], x, context, y_out, gate_msa)
if self.infer_conditional: if self.infer_conditional:
self.derivative_approximation(self.blocks_cache_even[block_idx], "cross_attn_out", attn_out) self.derivative_approximation(self.blocks_cache_even[block_idx], "cross_attn_out", attn_out)
else: else:
...@@ -369,7 +369,7 @@ class WanTransformerInferAdaCaching(WanTransformerInfer): ...@@ -369,7 +369,7 @@ class WanTransformerInferAdaCaching(WanTransformerInfer):
else: else:
self.args_odd.now_residual_tiny = y_out * gate_msa.squeeze(0) self.args_odd.now_residual_tiny = y_out * gate_msa.squeeze(0)
attn_out = self.infer_cross_attn(weights.blocks[block_idx].compute_phases[2], x, context, y_out, gate_msa) x, attn_out = self.infer_cross_attn(weights.blocks[block_idx].compute_phases[2], x, context, y_out, gate_msa)
y_out = self.infer_ffn(weights.blocks[block_idx].compute_phases[3], x, attn_out, c_shift_msa, c_scale_msa) y_out = self.infer_ffn(weights.blocks[block_idx].compute_phases[3], x, attn_out, c_shift_msa, c_scale_msa)
x = self.post_process(x, y_out, c_gate_msa) x = self.post_process(x, y_out, c_gate_msa)
...@@ -637,7 +637,7 @@ class WanTransformerInferCustomCaching(WanTransformerInfer, BaseTaylorCachingTra ...@@ -637,7 +637,7 @@ class WanTransformerInferCustomCaching(WanTransformerInfer, BaseTaylorCachingTra
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_modulation(weights.blocks[block_idx].compute_phases[0], embed0) shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_modulation(weights.blocks[block_idx].compute_phases[0], embed0)
y_out = self.infer_self_attn(weights.blocks[block_idx].compute_phases[1], grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa) y_out = self.infer_self_attn(weights.blocks[block_idx].compute_phases[1], grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa)
attn_out = self.infer_cross_attn(weights.blocks[block_idx].compute_phases[2], x, context, y_out, gate_msa) x, attn_out = self.infer_cross_attn(weights.blocks[block_idx].compute_phases[2], x, context, y_out, gate_msa)
y_out = self.infer_ffn(weights.blocks[block_idx].compute_phases[3], x, attn_out, c_shift_msa, c_scale_msa) y_out = self.infer_ffn(weights.blocks[block_idx].compute_phases[3], x, attn_out, c_shift_msa, c_scale_msa)
x = self.post_process(x, y_out, c_gate_msa) x = self.post_process(x, y_out, c_gate_msa)
......
 import math
 import torch
+import torch.cuda.amp as amp
 from lightx2v.utils.envs import *
...
...@@ -116,7 +116,7 @@ class WanTransformerInfer(BaseTransformerInfer): ...@@ -116,7 +116,7 @@ class WanTransformerInfer(BaseTransformerInfer):
scale_msa, scale_msa,
) )
elif cur_phase_idx == 2: elif cur_phase_idx == 2:
attn_out = self.infer_cross_attn(cur_phase, x, context, y_out, gate_msa) x, attn_out = self.infer_cross_attn(cur_phase, x, context, y_out, gate_msa)
elif cur_phase_idx == 3: elif cur_phase_idx == 3:
y = self.infer_ffn(cur_phase, x, attn_out, c_shift_msa, c_scale_msa) y = self.infer_ffn(cur_phase, x, attn_out, c_shift_msa, c_scale_msa)
x = self.post_process(x, y, c_gate_msa) x = self.post_process(x, y, c_gate_msa)
...@@ -169,7 +169,7 @@ class WanTransformerInfer(BaseTransformerInfer): ...@@ -169,7 +169,7 @@ class WanTransformerInfer(BaseTransformerInfer):
scale_msa, scale_msa,
) )
elif cur_phase_idx == 2: elif cur_phase_idx == 2:
attn_out = self.infer_cross_attn(cur_phase, x, context, y_out, gate_msa) x, attn_out = self.infer_cross_attn(cur_phase, x, context, y_out, gate_msa)
elif cur_phase_idx == 3: elif cur_phase_idx == 3:
y = self.infer_ffn(cur_phase, x, attn_out, c_shift_msa, c_scale_msa) y = self.infer_ffn(cur_phase, x, attn_out, c_shift_msa, c_scale_msa)
x = self.post_process(x, y, c_gate_msa) x = self.post_process(x, y, c_gate_msa)
...@@ -222,7 +222,7 @@ class WanTransformerInfer(BaseTransformerInfer): ...@@ -222,7 +222,7 @@ class WanTransformerInfer(BaseTransformerInfer):
shift_msa, shift_msa,
scale_msa, scale_msa,
) )
attn_out = self.infer_cross_attn(weights.compute_phases[2], x, context, y_out, gate_msa) x, attn_out = self.infer_cross_attn(weights.compute_phases[2], x, context, y_out, gate_msa)
y = self.infer_ffn(weights.compute_phases[3], x, attn_out, c_shift_msa, c_scale_msa) y = self.infer_ffn(weights.compute_phases[3], x, attn_out, c_shift_msa, c_scale_msa)
x = self.post_process(x, y, c_gate_msa) x = self.post_process(x, y, c_gate_msa)
return x return x
...@@ -371,7 +371,7 @@ class WanTransformerInfer(BaseTransformerInfer): ...@@ -371,7 +371,7 @@ class WanTransformerInfer(BaseTransformerInfer):
if self.clean_cuda_cache: if self.clean_cuda_cache:
del q, k, v, norm3_out, context, context_img del q, k, v, norm3_out, context, context_img
torch.cuda.empty_cache() torch.cuda.empty_cache()
return attn_out return x, attn_out
def infer_ffn(self, weights, x, attn_out, c_shift_msa, c_scale_msa): def infer_ffn(self, weights, x, attn_out, c_shift_msa, c_scale_msa):
x.add_(attn_out) x.add_(attn_out)
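infer_cross_attn now returns the residual stream together with the attention output, which is why every call site above switched from `attn_out = ...` to `x, attn_out = ...`; presumably x can be replaced inside the method (its clean_cuda_cache branch already frees intermediates), so handing it back keeps callers on the live tensor. A stand-in sketch of the contract change, not the real implementation:

def infer_cross_attn(x, context):
    attn_out = x + context  # stands in for the real cross-attention
    x = x * 1.0             # x may be rebuilt internally
    return x, attn_out      # was: return attn_out

x, attn_out = infer_cross_attn(1.0, 2.0)  # call sites unpack the pair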
...@@ -387,22 +387,23 @@ class WanTransformerInfer(BaseTransformerInfer): ...@@ -387,22 +387,23 @@ class WanTransformerInfer(BaseTransformerInfer):
norm2_weight = 1 + c_scale_msa.squeeze(0) norm2_weight = 1 + c_scale_msa.squeeze(0)
norm2_bias = c_shift_msa.squeeze(0) norm2_bias = c_shift_msa.squeeze(0)
x = weights.norm2.apply(x) norm2_out = weights.norm2.apply(x)
if GET_DTYPE() != "BF16": if GET_DTYPE() != "BF16":
x = x.float() norm2_out = norm2_out.float()
x.mul_(norm2_weight).add_(norm2_bias) norm2_out.mul_(norm2_weight).add_(norm2_bias)
if GET_DTYPE() != "BF16": if GET_DTYPE() != "BF16":
x = x.to(torch.bfloat16) norm2_out = norm2_out.to(torch.bfloat16)
x = weights.ffn_0.apply(x) y = weights.ffn_0.apply(norm2_out)
if self.clean_cuda_cache: if self.clean_cuda_cache:
del norm2_out, x, norm2_weight, norm2_bias
torch.cuda.empty_cache() torch.cuda.empty_cache()
x = torch.nn.functional.gelu(x, approximate="tanh") y = torch.nn.functional.gelu(y, approximate="tanh")
if self.clean_cuda_cache: if self.clean_cuda_cache:
torch.cuda.empty_cache() torch.cuda.empty_cache()
x = weights.ffn_2.apply(x) y = weights.ffn_2.apply(y)
return x return y
def post_process(self, x, y, c_gate_msa): def post_process(self, x, y, c_gate_msa):
if GET_DTYPE() != "BF16": if GET_DTYPE() != "BF16":
......
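The infer_ffn rewrite stops rebinding the name x: the old code overwrote it with the normalized tensor, so the incoming residual stream was unreachable for cleanup; with distinct names (norm2_out, y) the new clean_cuda_cache branch can del the intermediates explicitly before empty_cache(). A simplified sketch of the fix, not the real implementation:

import torch

def infer_ffn(x: torch.Tensor, clean_cuda_cache: bool = False) -> torch.Tensor:
    norm2_out = torch.nn.functional.layer_norm(x, x.shape[-1:])  # not `x = ...`
    y = norm2_out * 2.0  # stands in for weights.ffn_0.apply(norm2_out)
    if clean_cuda_cache:
        del norm2_out    # intermediates freed by name
        torch.cuda.empty_cache()
    y = torch.nn.functional.gelu(y, approximate="tanh")
    return y             # x itself is untouched for the caller's residual add

print(infer_ffn(torch.randn(4, 8)).shape)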