Commit 92f067f1 authored by gushiqiao

Fix audio model compile and offload bugs

parent 00962c67
@@ -113,7 +113,7 @@ class PerceiverAttentionCA(nn.Module):
        shift_scale_gate = torch.zeros((1, 3, inner_dim))
        shift_scale_gate[:, 2] = 1
        self.register_buffer("shift_scale_gate", shift_scale_gate, persistent=False)

    def forward(self, x, latents, t_emb, q_lens, k_lens):
        """x shape (batchsize, latent_frame, audio_tokens_per_latent, model_dim)
        latents (batchsize, length, model_dim)"""
...
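The `shift_scale_gate` buffer above is initialized to zero shift/scale and a unit gate, i.e. an identity modulation, and `persistent=False` keeps it out of the checkpoint so older weights still load. Below is a minimal sketch, not the actual `PerceiverAttentionCA` code, of how such a buffer is typically combined with a time embedding in AdaLN-style modulation; `ModulatedBlock`, the `t_emb` layout, and the residual form are illustrative assumptions.

```python
import torch
import torch.nn as nn

class ModulatedBlock(nn.Module):
    """Illustrative (shift, scale, gate) modulation: zeros for shift/scale and
    ones for gate make the default modulation an identity."""

    def __init__(self, inner_dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(inner_dim, elementwise_affine=False)
        shift_scale_gate = torch.zeros((1, 3, inner_dim))
        shift_scale_gate[:, 2] = 1  # gate channel starts at 1
        # persistent=False keeps the buffer out of the state_dict, so older
        # checkpoints without it still load cleanly.
        self.register_buffer("shift_scale_gate", shift_scale_gate, persistent=False)

    def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
        # Hypothetical usage: a per-sample time embedding offsets the
        # identity-initialized buffer, then is split into shift/scale/gate.
        shift, scale, gate = (self.shift_scale_gate + t_emb).chunk(3, dim=1)
        return x + gate * (self.norm(x) * (1 + scale) + shift)

block = ModulatedBlock(64)
x = torch.randn(2, 16, 64)
t_emb = torch.zeros(2, 3, 64)   # zero offsets -> identity-like modulation
print(block(x, t_emb).shape)    # torch.Size([2, 16, 64])
```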
@@ -83,14 +83,14 @@ class WanTransformerInfer(BaseTransformerInfer):
        cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32)
        cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32)
        return cu_seqlens_q, cu_seqlens_k

    def compute_freqs(self, q, grid_sizes, freqs):
        if "audio" in self.config.get("model_cls", ""):
            freqs_i = compute_freqs_audio(q.size(2) // 2, grid_sizes, freqs)
        else:
            freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
        return freqs_i

    @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
    def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
        return self.infer_func(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks)
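Two things in this hunk are worth spelling out: `infer` is compiled with `@torch.compile(disable=...)`, so compilation only kicks in when graph mode is enabled and eager runs are untouched, and the `cu_seqlens` helpers prefix-sum per-sample lengths into the cumulative int32 offsets that varlen attention kernels consume. A self-contained sketch of that prefix sum, with hypothetical lengths:

```python
import torch

def build_cu_seqlens(lens: torch.Tensor) -> torch.Tensor:
    """Prefix-sum per-sample lengths into cumulative offsets (int32,
    starting at 0), the layout varlen attention kernels expect."""
    return torch.cat([lens.new_zeros([1]), lens]).cumsum(0, dtype=torch.int32)

# Hypothetical lengths for a packed batch of three sequences.
q_lens = torch.tensor([7, 3, 5], dtype=torch.int32)
k_lens = torch.tensor([4, 4, 4], dtype=torch.int32)
print(build_cu_seqlens(q_lens))  # tensor([ 0,  7, 10, 15], dtype=torch.int32)
print(build_cu_seqlens(k_lens))  # tensor([ 0,  4,  8, 12], dtype=torch.int32)
```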
@@ -147,7 +147,7 @@ class WanTransformerInfer(BaseTransformerInfer):
            )
            if audio_dit_blocks:
                x = self._apply_audio_dit(x, block_idx, grid_sizes, audio_dit_blocks)
            self.weights_stream_mgr.swap_weights()
            if block_idx == self.blocks_num - 1:
@@ -155,7 +155,6 @@ class WanTransformerInfer(BaseTransformerInfer):
                self.weights_stream_mgr._async_prefetch_block(weights.blocks)
        if self.clean_cuda_cache:
            del grid_sizes, embed, embed0, seq_lens, freqs, context
            torch.cuda.empty_cache()
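The offload path above alternates `swap_weights()` with an async prefetch of the next block's weights and optionally frees cached CUDA memory at the end. The manager's internals are not part of this diff; the following is only a generic sketch of the double-buffered streaming pattern it suggests (`offloaded_forward` and the plain `nn.Linear` blocks are illustrative assumptions, not the project's `weights_stream_mgr`).

```python
import torch
import torch.nn as nn

def offloaded_forward(blocks, x, device="cuda"):
    """Generic sketch of block-wise weight streaming: while block i runs on the
    default stream, block i+1 is copied to the GPU on a side stream, then the
    two are swapped. Not the project's weights_stream_mgr."""
    if not torch.cuda.is_available():
        for block in blocks:        # plain CPU fallback
            x = block(x)
        return x

    copy_stream = torch.cuda.Stream()
    cur = blocks[0].to(device)
    for i in range(len(blocks)):
        nxt = None
        if i + 1 < len(blocks):
            with torch.cuda.stream(copy_stream):
                # Prefetch the next block while `cur` computes below.
                nxt = blocks[i + 1].to(device, non_blocking=True)
        x = cur(x)
        # A real manager would also move the finished block back to CPU here.
        torch.cuda.current_stream().wait_stream(copy_stream)
        cur = nxt
    return x

blocks = [nn.Linear(64, 64) for _ in range(4)]
x = torch.randn(2, 64)
if torch.cuda.is_available():
    x = x.cuda()
print(offloaded_forward(blocks, x).shape)  # torch.Size([2, 64])
```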
@@ -295,9 +294,7 @@ class WanTransformerInfer(BaseTransformerInfer):
        for ipa_out in audio_dit_blocks:
            if block_idx in ipa_out:
                cur_modify = ipa_out[block_idx]
-               x = cur_modify["modify_func"](x,
-                                             grid_sizes,
-                                             **cur_modify["kwargs"])
+               x = cur_modify["modify_func"](x, grid_sizes, **cur_modify["kwargs"])
        return x

    def _infer_without_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
...
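The call reformatted above shows the shape of each `audio_dit_blocks` entry: a mapping from transformer block index to a `modify_func` plus its `kwargs`, applied only at matching blocks. A minimal sketch of that dispatch with a hypothetical audio-residual hook (`apply_audio_hooks` and `add_audio_residual` are illustrative names, not the repository's API):

```python
import torch

def apply_audio_hooks(x, block_idx, grid_sizes, audio_dit_blocks):
    """Mirror of the dispatch above: each entry maps a block index to a
    callable plus extra kwargs, and only fires at its own block."""
    for hooks in audio_dit_blocks:
        if block_idx in hooks:
            entry = hooks[block_idx]
            x = entry["modify_func"](x, grid_sizes, **entry["kwargs"])
    return x

# Hypothetical hook: add an audio-conditioned residual at every 4th block.
def add_audio_residual(x, grid_sizes, audio_feat, scale=0.5):
    return x + scale * audio_feat

audio_feat = torch.randn(1, 16, 64)
hooks = [{i: {"modify_func": add_audio_residual,
              "kwargs": {"audio_feat": audio_feat, "scale": 0.5}}
          for i in range(0, 12, 4)}]

x = torch.randn(1, 16, 64)
for block_idx in range(12):
    x = apply_audio_hooks(x, block_idx, grid_sizes=None, audio_dit_blocks=hooks)
print(x.shape)  # torch.Size([1, 16, 64])
```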
@@ -3,7 +3,7 @@ import os
import subprocess
from contextlib import contextmanager
from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
@@ -738,7 +738,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
            prev_video=prev_video,
            prev_frame_length=5,
            segment_idx=0,
-           total_steps=1
+           total_steps=1,
        )
        # Final cleanup
        self.end_run()
...