Commit bff9bd05 authored by helloyongyang

update wan infer code

parent 1c065c06
@@ -27,6 +27,8 @@ class LNWeightTemplate(metaclass=ABCMeta):
         self.bias = None

     def _calculate_size(self):
+        if self.weight is None:
+            return 0
         if self.bias is not None:
             return self.weight.numel() * self.weight.element_size() + self.bias.numel() * self.bias.element_size()
         return self.weight.numel() * self.weight.element_size()
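
Worth noting for reviewers: the new guard makes `_calculate_size` safe to call before any weights have been loaded, where `self.weight` is still `None` and `numel()` would raise. A minimal standalone sketch of the same accounting (the `tensor_nbytes` helper is hypothetical, not part of the repo):

```python
import torch

def tensor_nbytes(weight, bias=None):
    # Bytes held by a (weight, bias) pair; 0 if the weight was never loaded.
    if weight is None:
        return 0
    size = weight.numel() * weight.element_size()
    if bias is not None:
        size += bias.numel() * bias.element_size()
    return size

w = torch.empty(4096, dtype=torch.bfloat16)  # bf16 -> 2 bytes per element
b = torch.empty(4096, dtype=torch.bfloat16)
assert tensor_nbytes(w, b) == 4096 * 2 * 2   # weight + bias
assert tensor_nbytes(None) == 0              # unloaded layer contributes nothing
```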
@@ -104,4 +106,4 @@ class LNWeight(LNWeightTemplate):
             ).to(torch.bfloat16)
         else:
             input_tensor = torch.nn.functional.layer_norm(input_tensor, (input_tensor.shape[-1],), self.weight, self.bias, self.eps)
-        return input_tensor
\ No newline at end of file
+        return input_tensor
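
Context for the unchanged branch above: when no quantized path applies, the template falls back to plain `torch.nn.functional.layer_norm` over the last dimension, and both `weight` and `bias` may legitimately be `None`, matching `LNWeightTemplate`'s defaults. A small sketch of those semantics:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 16, 128)
# Normalize over the last dim; weight=None/bias=None means no affine transform.
y = F.layer_norm(x, (x.shape[-1],), weight=None, bias=None, eps=1e-6)
print(y.mean(dim=-1).abs().max())  # ~0: each position is normalized
```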
@@ -91,13 +91,7 @@ class WanTransformerInfer:
         for block_idx in range(weights.blocks_num):
            weights.blocks[block_idx].modulation.to_cuda()
-            if embed0.dim() == 3:
-                modulation = weights.blocks[block_idx].modulation.tensor.unsqueeze(2)
-                current_embed0 = (modulation + embed0).chunk(6, dim=1)
-                shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = [ei.squeeze(1) for ei in current_embed0]
-            elif embed0.dim() == 2:
-                shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (weights.blocks[block_idx].modulation.tensor + embed0).chunk(6, dim=1)
+            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_phase_1(weights.blocks[block_idx], grid_sizes, embed, x, embed0, seq_lens, freqs, context)

             for phase_idx in range(self.phases_num):
                 if block_idx == 0 and phase_idx == 0:
@@ -108,22 +102,12 @@ class WanTransformerInfer:
                 with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
                     cur_phase_idx, cur_phase = self.weights_stream_mgr.active_weights[0]
                     if cur_phase_idx == 0:
-                        x = self._infer_self_attn(
-                            cur_phase,
-                            x,
-                            shift_msa,
-                            scale_msa,
-                            gate_msa,
-                            grid_sizes,
-                            freqs,
-                            seq_lens,
-                        )
+                        y_out = self.infer_phase_2(cur_phase, grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa)
                     elif cur_phase_idx == 1:
-                        x = self._infer_cross_attn(cur_phase, x, context)
+                        attn_out = self.infer_phase_3(cur_phase, x, context, y_out, gate_msa)
                     elif cur_phase_idx == 2:
-                        x = self._infer_ffn(cur_phase, x, c_shift_msa, c_scale_msa, c_gate_msa)
+                        y = self.infer_phase_4(cur_phase, x, attn_out, c_shift_msa, c_scale_msa)
+                        x = self.infer_phase_5(x, y, c_gate_msa)

                 is_last_phase = block_idx == weights.blocks_num - 1 and phase_idx == self.phases_num - 1
                 if not is_last_phase:
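
The restructured phase dispatch keeps the offload pattern intact: one phase's weights compute on `compute_stream` while the next phase's weights are prefetched on a side stream. A minimal sketch of that double-buffering idea (the `TwoSlotStreamManager` class is hypothetical, simplified from whatever `weights_stream_mgr` actually does):

```python
import torch

class TwoSlotStreamManager:
    """Double buffer: slot 0 holds weights being computed on, slot 1 the prefetch."""

    def __init__(self):
        self.compute_stream = torch.cuda.current_stream()
        self.load_stream = torch.cuda.Stream()
        self.active_weights = [None, None]

    def prefetch_phase(self, phase_idx, cpu_phase):
        # Async H2D copy on the side stream; needs pinned host memory to overlap.
        with torch.cuda.stream(self.load_stream):
            gpu_phase = cpu_phase.to("cuda", non_blocking=True)
        self.active_weights[1] = (phase_idx, gpu_phase)

    def swap_phases(self):
        # Compute must not touch the prefetched slot before its copy has landed.
        self.compute_stream.wait_stream(self.load_stream)
        self.active_weights[0] = self.active_weights[1]
        self.active_weights[1] = None
```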
@@ -146,12 +130,7 @@ class WanTransformerInfer:
             with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
                 weights.blocks[block_idx].modulation.to_cuda()
-                if embed0.dim() == 3:
-                    modulation = weights.blocks[block_idx].modulation.tensor.unsqueeze(2)
-                    current_embed0 = (modulation + embed0).chunk(6, dim=1)
-                    shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = [ei.squeeze(1) for ei in current_embed0]
-                elif embed0.dim() == 2:
-                    shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (weights.blocks[block_idx].modulation.tensor + embed0).chunk(6, dim=1)
+                shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_phase_1(weights.blocks[block_idx], grid_sizes, embed, x, embed0, seq_lens, freqs, context)

             for phase_idx in range(self.weights_stream_mgr.phases_num):
                 if block_idx == 0 and phase_idx == 0:
@@ -170,20 +149,12 @@ class WanTransformerInfer:
                     ) = self.weights_stream_mgr.active_weights[0]

                     if cur_phase_idx == 0:
-                        x = self._infer_self_attn(
-                            cur_phase,
-                            x,
-                            shift_msa,
-                            scale_msa,
-                            gate_msa,
-                            grid_sizes,
-                            freqs,
-                            seq_lens,
-                        )
+                        y_out = self.infer_phase_2(cur_phase, grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa)
                     elif cur_phase_idx == 1:
-                        x = self._infer_cross_attn(cur_phase, x, context)
+                        attn_out = self.infer_phase_3(cur_phase, x, context, y_out, gate_msa)
                     elif cur_phase_idx == 2:
-                        x = self._infer_ffn(cur_phase, x, c_shift_msa, c_scale_msa, c_gate_msa)
+                        y = self.infer_phase_4(cur_phase, x, attn_out, c_shift_msa, c_scale_msa)
+                        x = self.infer_phase_5(x, y, c_gate_msa)

                 if not (block_idx == weights.blocks_num - 1 and phase_idx == self.phases_num - 1):
                     next_block_idx = block_idx + 1 if phase_idx == self.phases_num - 1 else block_idx
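
The prefetch target advances phase-first, then block, exactly as `next_block_idx` is computed above. A tiny sketch of that round-robin:

```python
def next_indices(block_idx, phase_idx, phases_num=3):
    # Phase-first traversal: finish a block's phases, then move to the next block.
    if phase_idx == phases_num - 1:
        return block_idx + 1, 0
    return block_idx, phase_idx + 1

assert next_indices(0, 0) == (0, 1)
assert next_indices(0, 2) == (1, 0)  # rolls over to the next block
```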
@@ -213,7 +184,24 @@ class WanTransformerInfer:
             )
         return x

-    def _infer_self_attn(self, weights, x, shift_msa, scale_msa, gate_msa, grid_sizes, freqs, seq_lens):
+    def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
+        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_phase_1(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
+        y_out = self.infer_phase_2(weights.compute_phases[0], grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa)
+        attn_out = self.infer_phase_3(weights.compute_phases[1], x, context, y_out, gate_msa)
+        y = self.infer_phase_4(weights.compute_phases[2], x, attn_out, c_shift_msa, c_scale_msa)
+        x = self.infer_phase_5(x, y, c_gate_msa)
+        return x
+
+    def infer_phase_1(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
+        if embed0.dim() == 3:
+            modulation = weights.modulation.tensor.unsqueeze(2)
+            embed0 = (modulation + embed0).chunk(6, dim=1)
+            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = [ei.squeeze(1) for ei in embed0]
+        elif embed0.dim() == 2:
+            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (weights.modulation.tensor + embed0).chunk(6, dim=1)
+        return shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa
+
+    def infer_phase_2(self, weights, grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa):
         if hasattr(weights, "smooth_norm1_weight"):
             norm1_weight = (1 + scale_msa) * weights.smooth_norm1_weight.tensor
             norm1_bias = shift_msa * weights.smooth_norm1_bias.tensor
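
`infer_phase_1` centralizes the adaLN-style modulation split that was previously duplicated at three call sites: the per-block modulation table plus the timestep embedding is chunked into six tensors (shift/scale/gate for self-attention, then for the FFN). A shape-level sketch of both branches; the shapes here are assumptions inferred from the broadcasting, not taken from the repo:

```python
import torch

dim, frames = 1536, 21               # assumed sizes, for illustration only
modulation = torch.randn(1, 6, dim)  # assumed per-block table shape

# 2-D branch: one embedding for the whole sequence.
embed0 = torch.randn(6, dim)
parts = (modulation + embed0).chunk(6, dim=1)  # six (1, 1, dim) tensors
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = parts

# 3-D branch: one embedding per latent frame, broadcast against the table.
embed0 = torch.randn(6, frames, dim)
parts = (modulation.unsqueeze(2) + embed0).chunk(6, dim=1)  # six (1, 1, frames, dim)
shift_msa = parts[0].squeeze(1)                # (1, frames, dim)
```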
@@ -269,14 +257,14 @@ class WanTransformerInfer:
             )

         y = weights.self_attn_o.apply(attn_out)
+        return y
+
+    def infer_phase_3(self, weights, x, context, y_out, gate_msa):
         if GET_DTYPE() != "BF16":
-            x = x.float() + y.float() * gate_msa.squeeze(0)
+            x = x.float() + y_out.float() * gate_msa.squeeze(0)
         else:
-            x.add_(y * gate_msa.squeeze(0))
-        return x
-
-    def _infer_cross_attn(self, weights, x, context):
+            x.add_(y_out * gate_msa.squeeze(0))
         norm3_out = weights.norm3.apply(x)
         if self.task == "i2v":
             context_img = context[:257]
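
`infer_phase_3` now starts by folding the self-attention output into the residual stream before cross-attention, with an upcast path for non-BF16 runs and an in-place add otherwise. A minimal sketch of that gated residual (the `gated_residual` helper is hypothetical):

```python
import torch

def gated_residual(x, y_out, gate, bf16=True):
    # Fold an attention output back into the residual stream, scaled by its gate.
    if not bf16:
        return x.float() + y_out.float() * gate.squeeze(0)  # upcast path
    x.add_(y_out * gate.squeeze(0))  # in-place, keeps activations in bf16
    return x

x = torch.randn(128, 1536, dtype=torch.bfloat16)
y_out = torch.randn(128, 1536, dtype=torch.bfloat16)
gate = torch.randn(1, 1536, dtype=torch.bfloat16)
x = gated_residual(x, y_out, gate)
```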
@@ -331,10 +319,10 @@ class WanTransformerInfer:
             attn_out = attn_out + img_attn_out

         attn_out = weights.cross_attn_o.apply(attn_out)
-        x.add_(attn_out)
-        return x
+        return attn_out

-    def _infer_ffn(self, weights, x, c_shift_msa, c_scale_msa, c_gate_msa):
+    def infer_phase_4(self, weights, x, attn_out, c_shift_msa, c_scale_msa):
+        x.add_(attn_out)
         if hasattr(weights, "smooth_norm2_weight"):
             norm2_weight = (1 + c_scale_msa.squeeze(0)) * weights.smooth_norm2_weight.tensor
             norm2_bias = c_shift_msa.squeeze(0) * weights.smooth_norm2_bias.tensor
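
The `smooth_norm2_*` branch folds the FFN modulation into per-channel smoothing factors, so the normalization's affine step carries both the adaLN scale/shift and the quantization smoothing in one multiply-add. A sketch contrasting the plain and folded forms, with assumed shapes:

```python
import torch
import torch.nn.functional as F

dim = 1536  # assumed width
x = torch.randn(128, dim)
c_shift, c_scale = torch.randn(1, dim), torch.randn(1, dim)
smooth_w, smooth_b = torch.rand(dim), torch.rand(dim)  # smoothing factors

# Plain adaLN: weight-free layer norm, then modulation scale/shift.
plain = F.layer_norm(x, (dim,)) * (1 + c_scale) + c_shift

# Folded form: modulation baked into the smoothing factors, one multiply-add.
norm2_weight = (1 + c_scale.squeeze(0)) * smooth_w
norm2_bias = c_shift.squeeze(0) * smooth_b
folded = F.layer_norm(x, (dim,)) * norm2_weight + norm2_bias
```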
@@ -352,31 +340,11 @@ class WanTransformerInfer:
         y = weights.ffn_0.apply(norm2_out)
         y = torch.nn.functional.gelu(y, approximate="tanh")
         y = weights.ffn_2.apply(y)
+        return y
+
+    def infer_phase_5(self, x, y, c_gate_msa):
         if GET_DTYPE() != "BF16":
             x = x.float() + y.float() * c_gate_msa.squeeze(0)
         else:
             x.add_(y * c_gate_msa.squeeze(0))
         return x
-
-    def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
-        if embed0.dim() == 3:
-            modulation = weights.modulation.tensor.unsqueeze(2)
-            embed0 = (modulation + embed0).chunk(6, dim=1)
-            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = [ei.squeeze(1) for ei in embed0]
-        elif embed0.dim() == 2:
-            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (weights.modulation.tensor + embed0).chunk(6, dim=1)
-        x = self._infer_self_attn(
-            weights.compute_phases[0],
-            x,
-            shift_msa,
-            scale_msa,
-            gate_msa,
-            grid_sizes,
-            freqs,
-            seq_lens,
-        )
-        x = self._infer_cross_attn(weights.compute_phases[1], x, context)
-        x = self._infer_ffn(weights.compute_phases[2], x, c_shift_msa, c_scale_msa, c_gate_msa)
-        return x
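
Net effect of the refactor: `infer_block` becomes a straight-line composition of five phases whose boundaries match the offload manager's prefetch units. An annotated restatement of the new flow (comments are editorial; signatures copied from the diff):

```python
def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
    # Phase 1: split the timestep modulation into six shift/scale/gate tensors.
    shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_phase_1(
        weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context
    )
    # Phase 2: self-attention over the modulated norm of x.
    y_out = self.infer_phase_2(weights.compute_phases[0], grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa)
    # Phase 3: gated residual add, then cross-attention against the context.
    attn_out = self.infer_phase_3(weights.compute_phases[1], x, context, y_out, gate_msa)
    # Phase 4: residual add of cross-attention output, then the modulated FFN.
    y = self.infer_phase_4(weights.compute_phases[2], x, attn_out, c_shift_msa, c_scale_msa)
    # Phase 5: gated residual add of the FFN output.
    return self.infer_phase_5(x, y, c_gate_msa)
```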