Commit 77bef6e8 authored by gushiqiao

Fix audio offload bug

parent 348822d9
{
"infer_steps": 4,
"target_fps": 16,
"video_duration": 16,
"audio_sr": 16000,
"target_video_length": 81,
"target_height": 720,
"target_width": 1280,
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"seed": 42,
"sample_guide_scale": 1,
"sample_shift": 5,
"enable_cfg": false,
"cpu_offload": true,
"offload_granularity": "block",
"t5_cpu_offload": true,
"offload_ratio_val": 1,
"t5_offload_granularity": "block",
"use_tiling_vae": true
}
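(Side note, not part of the commit: a minimal sketch of how the offload flags in a config like the one above might be consumed. pick_devices and all the wiring around it are hypothetical; only the key names come from the config file.)

import json

import torch

def pick_devices(config_path: str):
    # Hypothetical helper: derive initial device placement from the
    # offload flags in the JSON config above.
    with open(config_path) as f:
        config = json.load(f)
    # "cpu_offload": true keeps the main model's weights on CPU between
    # uses and streams them to the GPU per block
    # ("offload_granularity": "block").
    main_device = torch.device("cpu" if config.get("cpu_offload", False) else "cuda")
    # The T5 text encoder has its own flag and granularity.
    t5_device = torch.device("cpu" if config.get("t5_cpu_offload", False) else "cuda")
    return main_device, t5_device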
@@ -2,6 +2,8 @@ try:
     import flash_attn
 except ModuleNotFoundError:
     flash_attn = None
+import os
+import safetensors
 import math
 import torch
 import torch.nn as nn
@@ -9,11 +11,6 @@ import torch.nn.functional as F
 from diffusers.models.embeddings import TimestepEmbedding, Timesteps
 from einops import rearrange
 from transformers import AutoModel
 from loguru import logger
-import os
-import safetensors
 from typing import List, Optional, Tuple, Union
 def load_safetensors(in_path: str):
@@ -370,13 +367,12 @@ class AudioAdapter(nn.Module):
 class AudioAdapterPipe:
     def __init__(
-        self, audio_adapter: AudioAdapter, audio_encoder_repo: str = "microsoft/wavlm-base-plus", dtype=torch.float32, device="cuda", generator=None, tgt_fps: int = 15, weight: float = 1.0
+        self, audio_adapter: AudioAdapter, audio_encoder_repo: str = "microsoft/wavlm-base-plus", dtype=torch.float32, device="cuda", tgt_fps: int = 15, weight: float = 1.0, cpu_offload: bool = False
     ) -> None:
         self.audio_adapter = audio_adapter
         self.dtype = dtype
         self.device = device
-        self.generator = generator
         self.audio_encoder_dtype = torch.float16
+        self.cpu_offload = cpu_offload
         ## Audio encoder
         self.audio_encoder = AutoModel.from_pretrained(audio_encoder_repo)
@@ -403,11 +399,14 @@ class AudioAdapterPipe:
         audio_length = int(50 / self.tgt_fps * video_frame)
         with torch.no_grad():
-            audio_input_feat = audio_input_feat.to(self.device, self.audio_encoder_dtype)
             try:
-                audio_feat = self.audio_encoder(audio_input_feat, return_dict=True).last_hidden_state
+                if self.cpu_offload:
+                    self.audio_encoder = self.audio_encoder.to("cuda")
+                audio_feat = self.audio_encoder(audio_input_feat.to(self.audio_encoder_dtype), return_dict=True).last_hidden_state
+                if self.cpu_offload:
+                    self.audio_encoder = self.audio_encoder.to("cpu")
             except Exception as err:
-                audio_feat = torch.rand(1, audio_length, self.audio_feature_dim).to(self.device)
+                audio_feat = torch.rand(1, audio_length, self.audio_feature_dim).to("cuda")
                 print(err)
             audio_feat = audio_feat.to(self.dtype)
             if dropout_cond is not None:
......
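The heart of the fix above is a move-run-move pattern: with cpu_offload enabled, the audio encoder lives on CPU between calls and visits the GPU only for the forward pass. A self-contained sketch of that pattern, assuming a generic nn.Module encoder (run_encoder_with_offload is a hypothetical name, not part of the commit):

import torch
import torch.nn as nn

def run_encoder_with_offload(encoder: nn.Module, feat: torch.Tensor, cpu_offload: bool) -> torch.Tensor:
    # Move the encoder onto the GPU only for the forward pass...
    if cpu_offload:
        encoder.to("cuda")
    with torch.no_grad():
        out = encoder(feat.to("cuda"))
    # ...then park it back on the CPU so it holds no GPU memory
    # between calls.
    if cpu_offload:
        encoder.to("cpu")
    return out

The trade-off is host-to-device transfer time on every call in exchange for freeing GPU memory for the diffusion model, which is why the behavior sits behind a config flag rather than being always on.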
@@ -2,32 +2,27 @@ import os
 import gc
 import numpy as np
 import torch
 import torchvision.transforms.functional as TF
+import subprocess
+import torchaudio as ta
 from PIL import Image
 from contextlib import contextmanager
-from typing import Optional, Tuple, Union, List, Dict, Any
+from typing import Optional, Tuple, List, Dict, Any
 from dataclasses import dataclass
+from loguru import logger
+from einops import rearrange
+from transformers import AutoFeatureExtractor
+from torchvision.transforms import InterpolationMode
+from torchvision.transforms.functional import resize
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
-from lightx2v.models.runners.wan.wan_runner import WanRunner
+from lightx2v.models.runners.wan.wan_runner import WanRunner, MultiModelStruct
 from lightx2v.utils.profiler import ProfilingContext4Debug, ProfilingContext
 from lightx2v.models.networks.wan.audio_model import WanAudioModel, Wan22MoeAudioModel
 from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
 from lightx2v.models.networks.wan.audio_adapter import AudioAdapter, AudioAdapterPipe, rank0_load_state_dict_from_path
 from lightx2v.utils.utils import save_to_video, vae_to_comfyui_image
 from lightx2v.models.schedulers.wan.audio.scheduler import ConsistencyModelScheduler
-from .wan_runner import MultiModelStruct
-from loguru import logger
-from einops import rearrange
-import torchaudio as ta
-from transformers import AutoFeatureExtractor
-from torchvision.transforms import InterpolationMode
-from torchvision.transforms.functional import resize
-import subprocess
-import warnings
 @contextmanager
@@ -424,9 +419,13 @@ class WanAudioRunner(WanRunner): # type:ignore
         audio_adapter = rank0_load_state_dict_from_path(audio_adapter, audio_adapter_path, strict=False)
         # Audio encoder
-        device = torch.device("cuda")
+        cpu_offload = self.config.get("cpu_offload", False)
+        if cpu_offload:
+            device = torch.device("cpu")
+        else:
+            device = torch.device("cuda")
         audio_encoder_repo = self.config["model_path"] + "/audio_encoder"
-        self._audio_adapter_pipe = AudioAdapterPipe(audio_adapter, audio_encoder_repo=audio_encoder_repo, dtype=torch.bfloat16, device=device, generator=torch.Generator(device), weight=1.0)
+        self._audio_adapter_pipe = AudioAdapterPipe(audio_adapter, audio_encoder_repo=audio_encoder_repo, dtype=torch.bfloat16, device=device, weight=1.0, cpu_offload=cpu_offload)
         return self._audio_adapter_pipe
@@ -622,7 +621,7 @@ class WanAudioRunner(WanRunner): # type:ignore
         ref_img = Image.open(config.image_path)
         ref_img = (np.array(ref_img).astype(np.float32) - 127.5) / 127.5
-        ref_img = torch.from_numpy(ref_img).to(vae_model.device)
+        ref_img = torch.from_numpy(ref_img).cuda()
         ref_img = rearrange(ref_img, "H W C -> 1 C H W")
         ref_img = ref_img[:, :3]
......
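For reference, the reference-image preprocessing in the last hunk as a standalone sketch (load_ref_image is a hypothetical helper name; the steps mirror the diff: scale uint8 pixels to [-1, 1], move to the GPU, reorder to 1xCxHxW, drop any alpha channel):

import numpy as np
import torch
from PIL import Image
from einops import rearrange

def load_ref_image(path: str) -> torch.Tensor:
    # (x - 127.5) / 127.5 maps [0, 255] to [-1, 1].
    img = (np.array(Image.open(path)).astype(np.float32) - 127.5) / 127.5
    ref = torch.from_numpy(img).cuda()
    ref = rearrange(ref, "H W C -> 1 C H W")
    return ref[:, :3]  # keep RGB, drop alpha if present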