Commit 820b4450 authored by gushiqiao, committed by GitHub

Fix fp32-related bug in audio model

parents a3d0f2d9 4389450a
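
Note on the change: the fix applies a mixed-precision split. The bulk of the audio model runs in the configured inference dtype (GET_DTYPE(), e.g. bf16), while numerically sensitive layers (time/text embeddings, output-head modulation) run in a higher-precision dtype (GET_SENSITIVE_DTYPE(), presumably fp32 given the commit title) and their results are cast back. A minimal standalone sketch of that pattern; run_sensitive, infer_dtype, and sensitive_dtype are illustrative names, not the lightx2v API:

import torch

infer_dtype = torch.bfloat16      # bulk of the model
sensitive_dtype = torch.float32   # numerically sensitive layers

def run_sensitive(layer: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: run a sensitive layer in high precision, then return
    # to the inference dtype. Assumes `layer`'s weights are kept in sensitive_dtype.
    if sensitive_dtype != infer_dtype:
        return layer(x.to(sensitive_dtype)).to(infer_dtype)
    return layer(x)

time_embedding = torch.nn.Linear(256, 1024)     # fp32 weights by default
t = torch.randn(1, 256, dtype=infer_dtype)      # bf16 activations
out = run_sensitive(time_embedding, t)          # out.dtype == infer_dtype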
@@ -121,7 +121,7 @@ class PerceiverAttentionCA(nn.Module):
        x = self.norm_kv(x)
        shift, scale, gate = (t_emb + self.shift_scale_gate).chunk(3, dim=1)
        latents = self.norm_q(latents) * (1 + scale) + shift
-        q = self.to_q(latents)
+        q = self.to_q(latents.to(GET_DTYPE()))
        k, v = self.to_kv(x).chunk(2, dim=-1)
        q = rearrange(q, "B L (H C) -> (B L) H C", H=self.heads)
        k = rearrange(k, "B T L (H C) -> (B T L) H C", H=self.heads)
......
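
In this hunk the query path gets an explicit cast: after the shift/scale modulation, latents can be promoted to fp32 (the t_emb-derived shift and scale presumably live in the sensitive dtype), and feeding an fp32 tensor into to_q, whose weights are in the inference dtype, raises a dtype mismatch. A rough standalone reproduction of the failure and of the .to(GET_DTYPE()) fix; shapes and dtypes below are assumptions, not taken from the model:

import torch

to_q = torch.nn.Linear(64, 64).to(torch.bfloat16)       # inference-dtype weights
latents = torch.randn(2, 16, 64, dtype=torch.bfloat16)
scale = torch.randn(2, 1, 64, dtype=torch.float32)       # e.g. derived from an fp32 time embedding
shift = torch.randn(2, 1, 64, dtype=torch.float32)

modulated = latents * (1 + scale) + shift                 # type promotion: result is float32
# to_q(modulated)                                         # would fail: input/weight dtypes differ
q = to_q(modulated.to(torch.bfloat16))                    # cast back to the inference dtype first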
@@ -3,16 +3,21 @@ import math
import torch

from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer
+from lightx2v.utils.envs import *


class WanAudioPostInfer(WanPostInfer):
    def __init__(self, config):
        self.out_dim = config["out_dim"]
        self.patch_size = (1, 2, 2)
+        self.clean_cuda_cache = config.get("clean_cuda_cache", False)
+        self.infer_dtype = GET_DTYPE()
+        self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()

    def set_scheduler(self, scheduler):
        self.scheduler = scheduler

+    @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
    def infer(self, weights, x, e, grid_sizes, valid_patch_length):
        if e.dim() == 2:
            modulation = weights.head_modulation.tensor  # 1, 2, dim
@@ -22,13 +27,23 @@ class WanAudioPostInfer(WanPostInfer):
            e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
            e = [ei.squeeze(1) for ei in e]
-        norm_out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6).type_as(x)
-        out = norm_out * (1 + e[1].squeeze(0)) + e[0].squeeze(0)
-        x = weights.head.apply(out)
-        x = x[:, :valid_patch_length]
+        x = weights.norm.apply(x)
+        if self.sensitive_layer_dtype != self.infer_dtype:
+            x = x.to(self.sensitive_layer_dtype)
+        x.mul_(1 + e[1].squeeze()).add_(e[0].squeeze())
+        if self.sensitive_layer_dtype != self.infer_dtype:
+            x = x.to(self.infer_dtype)
+        x = weights.head.apply(x)
+        x = x[:, :valid_patch_length]
        x = self.unpatchify(x, grid_sizes)
+        if self.clean_cuda_cache:
+            del e, grid_sizes
+            torch.cuda.empty_cache()
        return [u.float() for u in x]

    def unpatchify(self, x, grid_sizes):
......
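
Here the head modulation x * (1 + scale) + shift is now applied in the sensitive dtype, in place, and cast back before weights.head. A hedged standalone equivalent of that round trip; the function name and default dtypes are illustrative, not lightx2v code:

import torch

def modulate_head_input(x, shift, scale, infer_dtype=torch.bfloat16, sensitive_dtype=torch.float32):
    # Mirrors the pattern in WanAudioPostInfer.infer: modulate in high precision,
    # then return to the inference dtype for the head projection.
    if sensitive_dtype != infer_dtype:
        x = x.to(sensitive_dtype)        # fp32 copy, so the in-place ops stay out of bf16
    x.mul_(1 + scale).add_(shift)        # in-place, avoids extra temporaries
    if sensitive_dtype != infer_dtype:
        x = x.to(infer_dtype)
    return x

x = torch.randn(1, 1560, 1536, dtype=torch.bfloat16)
shift = torch.randn(1536, dtype=torch.float32)
scale = torch.randn(1536, dtype=torch.float32)
out = modulate_head_input(x, shift, scale)   # out.dtype == torch.bfloat16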
import torch
from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
+from lightx2v.utils.envs import *
from ..utils import rope_params, sinusoidal_embedding_1d
@@ -23,6 +24,8 @@ class WanAudioPreInfer(WanPreInfer):
        self.dim = config["dim"]
        self.text_len = config["text_len"]
        self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
+        self.infer_dtype = GET_DTYPE()
+        self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()

    def infer(self, weights, inputs, positive):
        prev_latents = inputs["previmg_encoder_output"]["prev_latents"].unsqueeze(0)
@@ -65,6 +68,7 @@ class WanAudioPreInfer(WanPreInfer):
            )
            ref_image_encoder = torch.concat([ref_image_encoder, zero_padding], dim=1)
        y = list(torch.unbind(ref_image_encoder, dim=0))  # the first (batch) dimension becomes a list
+
        # embeddings
        x = [weights.patch_embedding.apply(u.unsqueeze(0)) for u in x]
        x_grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
@@ -74,28 +78,40 @@ class WanAudioPreInfer(WanPreInfer):
        x = torch.cat([torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) for u in x])
        valid_patch_length = x[0].size(0)
        y = [weights.patch_embedding.apply(u.unsqueeze(0)) for u in y]
-        y_grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in y])
+        # y_grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in y])
        y = [u.flatten(2).transpose(1, 2).squeeze(0) for u in y]
        x = [torch.cat([a, b], dim=0) for a, b in zip(x, y)]
        x = torch.stack(x, dim=0)

        embed = sinusoidal_embedding_1d(self.freq_dim, t.flatten())
-        embed = weights.time_embedding_0.apply(embed)
+        # embed = weights.time_embedding_0.apply(embed)
+        if self.sensitive_layer_dtype != self.infer_dtype:
+            embed = weights.time_embedding_0.apply(embed.to(self.sensitive_layer_dtype))
+        else:
+            embed = weights.time_embedding_0.apply(embed)
        embed = torch.nn.functional.silu(embed)
        embed = weights.time_embedding_2.apply(embed)
        embed0 = torch.nn.functional.silu(embed)
        embed0 = weights.time_projection_1.apply(embed0).unflatten(1, (6, self.dim))

        # text embeddings
        stacked = torch.stack([torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) for u in context])
-        out = weights.text_embedding_0.apply(stacked.squeeze(0))
+        if self.sensitive_layer_dtype != self.infer_dtype:
+            out = weights.text_embedding_0.apply(stacked.squeeze(0).to(self.sensitive_layer_dtype))
+        else:
+            out = weights.text_embedding_0.apply(stacked.squeeze(0))
        out = torch.nn.functional.gelu(out, approximate="tanh")
        context = weights.text_embedding_2.apply(out)
+        if self.clean_cuda_cache:
+            del out, stacked
+            torch.cuda.empty_cache()

        if self.task == "i2v" and self.config.get("use_image_encoder", True):
            context_clip = weights.proj_0.apply(clip_fea)
+            if self.clean_cuda_cache:
+                del clip_fea
+                torch.cuda.empty_cache()
            context_clip = weights.proj_1.apply(context_clip)
            context_clip = torch.nn.functional.gelu(context_clip, approximate="none")
            context_clip = weights.proj_3.apply(context_clip)
......
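
Besides the conditional casts into the sensitive dtype around time_embedding_0 and text_embedding_0 (the same pattern sketched after the commit message above), the recurring addition in this hunk is the optional cache cleanup: when clean_cuda_cache is set, intermediates are deleted and the CUDA caching allocator is flushed to lower peak memory, at the cost of slower later allocations. A small illustration; tensor names and sizes are made up:

import torch

clean_cuda_cache = True  # in lightx2v this comes from the model config

if torch.cuda.is_available():
    stacked = torch.randn(1, 512, 4096, device="cuda", dtype=torch.bfloat16)
    out = torch.nn.functional.gelu(stacked, approximate="tanh")
    context = out * 2

    if clean_cuda_cache:
        # `del` must run in the scope that owns the references, otherwise
        # empty_cache() has nothing to reclaim.
        del out, stacked
        torch.cuda.empty_cache()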