Unverified Commit 9b13cab2 authored by Yang Yong (雍洋), committed by GitHub

Update wan infer (#524)

parent d242358f
@@ -5,6 +5,7 @@ import torch
 from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer
 from lightx2v.utils.envs import *
+from .triton_ops import fuse_scale_shift_kernel
 from .utils import apply_wan_rope_with_chunk, apply_wan_rope_with_flashinfer, apply_wan_rope_with_torch
@@ -135,16 +136,15 @@ class WanTransformerInfer(BaseTransformerInfer):
         if hasattr(phase, "smooth_norm1_weight"):
             norm1_weight = (1 + scale_msa.squeeze()) * phase.smooth_norm1_weight.tensor
             norm1_bias = shift_msa.squeeze() * phase.smooth_norm1_bias.tensor
             norm1_out = phase.norm1.apply(x)
             if self.sensitive_layer_dtype != self.infer_dtype:
                 norm1_out = norm1_out.to(self.sensitive_layer_dtype)
             norm1_out.mul_(norm1_weight).add_(norm1_bias)
         else:
-            norm1_weight = 1 + scale_msa.squeeze()
-            norm1_bias = shift_msa.squeeze()
-            norm1_out = phase.norm1.apply(x)
-            if self.sensitive_layer_dtype != self.infer_dtype:
-                norm1_out = norm1_out.to(self.sensitive_layer_dtype)
-            norm1_out.mul_(norm1_weight).add_(norm1_bias)
+            norm1_out = phase.norm1.apply(x)
+            if self.sensitive_layer_dtype != self.infer_dtype:
+                norm1_out = norm1_out.to(self.sensitive_layer_dtype)
+            norm1_out = fuse_scale_shift_kernel(norm1_out, scale=scale_msa, shift=shift_msa).squeeze(0)
         if self.sensitive_layer_dtype != self.infer_dtype:
             norm1_out = norm1_out.to(self.infer_dtype)
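For reference, the non-smooth path previously built norm1_weight = 1 + scale_msa.squeeze() and norm1_bias = shift_msa.squeeze() and applied them with in-place mul_/add_; the updated code hands the normalized activations to the Triton kernel fuse_scale_shift_kernel instead. A minimal unfused sketch of that modulation is given below; the function name and the assumption that the kernel computes norm(x) * (1 + scale) + shift in a single launch are illustrative, not taken from the kernel source.

    import torch

    def scale_shift_reference(norm_out: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor) -> torch.Tensor:
        # Unfused equivalent of the modulation the removed lines performed with
        # norm1_out.mul_(1 + scale_msa.squeeze()).add_(shift_msa.squeeze());
        # scale/shift are assumed to broadcast over the sequence axis after squeezing.
        return norm_out * (1.0 + scale.squeeze()) + shift.squeeze()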
@@ -274,14 +274,16 @@ class WanTransformerInfer(BaseTransformerInfer):
         if hasattr(phase, "smooth_norm2_weight"):
             norm2_weight = (1 + c_scale_msa.squeeze()) * phase.smooth_norm2_weight.tensor
             norm2_bias = c_shift_msa.squeeze() * phase.smooth_norm2_bias.tensor
+            norm2_out = phase.norm2.apply(x)
+            if self.sensitive_layer_dtype != self.infer_dtype:
+                norm2_out = norm2_out.to(self.sensitive_layer_dtype)
+            norm2_out.mul_(norm2_weight).add_(norm2_bias)
         else:
-            norm2_weight = 1 + c_scale_msa.squeeze()
-            norm2_bias = c_shift_msa.squeeze()
+            norm2_out = phase.norm2.apply(x)
+            if self.sensitive_layer_dtype != self.infer_dtype:
+                norm2_out = norm2_out.to(self.sensitive_layer_dtype)
+            norm2_out = fuse_scale_shift_kernel(norm2_out, scale=c_scale_msa, shift=c_shift_msa).squeeze(0)
-        norm2_out = phase.norm2.apply(x)
-        if self.sensitive_layer_dtype != self.infer_dtype:
-            norm2_out = norm2_out.to(self.sensitive_layer_dtype)
-        norm2_out.mul_(norm2_weight).add_(norm2_bias)
         if self.sensitive_layer_dtype != self.infer_dtype:
             norm2_out = norm2_out.to(self.infer_dtype)
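The smooth-quant branch folds per-channel smoothing factors into the same scale and shift before the elementwise update, and both branches round-trip through sensitive_layer_dtype when it differs from infer_dtype. A compact sketch of that combined pattern follows; the helper name and the default dtypes are assumptions for illustration.

    import torch

    def modulate_with_optional_smoothing(norm_out, scale, shift, smooth_weight=None, smooth_bias=None,
                                         sensitive_dtype=torch.float32, infer_dtype=torch.bfloat16):
        # Fold optional smooth-quant factors into the scale/shift, mirroring the
        # smooth_norm2_weight / smooth_norm2_bias branch above.
        weight = 1 + scale.squeeze()
        bias = shift.squeeze()
        if smooth_weight is not None:
            weight = weight * smooth_weight
            bias = bias * smooth_bias
        # Up-cast for the numerically sensitive elementwise step, then cast back for inference.
        needs_cast = sensitive_dtype != infer_dtype
        out = norm_out.to(sensitive_dtype) if needs_cast else norm_out
        out = out * weight + bias
        return out.to(infer_dtype) if needs_cast else out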
This diff is collapsed.