Commit 984cd6c9 authored by gushiqiao, committed by GitHub

Fix audio offload bug

parents 348822d9 77bef6e8
{
"infer_steps": 4,
"target_fps": 16,
"video_duration": 16,
"audio_sr": 16000,
"target_video_length": 81,
"target_height": 720,
"target_width": 1280,
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"seed": 42,
"sample_guide_scale": 1,
"sample_shift": 5,
"enable_cfg": false,
"cpu_offload": true,
"offload_granularity": "block",
"t5_cpu_offload": true,
"offload_ratio_val": 1,
"t5_offload_granularity": "block",
"use_tiling_vae": true
}
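
The config above combines several memory savers: block-granular CPU offload for the transformer (`cpu_offload` plus `offload_granularity`), the same for the T5 text encoder, and tiled VAE decoding. Below is a minimal sketch of how such a config can drive device placement; it assumes a plain JSON dict and a hypothetical filename, and is not lightx2v's actual loader:

```python
import json

import torch


def resting_device(config: dict) -> torch.device:
    # With cpu_offload enabled, weights idle on the CPU and only visit the
    # GPU around each forward pass (see the audio_adapter diff below).
    return torch.device("cpu" if config.get("cpu_offload", False) else "cuda")


# "wan_audio_offload.json" is a hypothetical name for the config shown above.
with open("wan_audio_offload.json") as f:
    config = json.load(f)

print(resting_device(config))  # -> cpu, since cpu_offload is true
```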
@@ -2,6 +2,8 @@ try:
     import flash_attn
 except ModuleNotFoundError:
     flash_attn = None
+import os
+import safetensors
 import math
 import torch
 import torch.nn as nn
@@ -9,11 +11,6 @@ import torch.nn.functional as F
 from diffusers.models.embeddings import TimestepEmbedding, Timesteps
 from einops import rearrange
 from transformers import AutoModel
-from loguru import logger
-import os
-import safetensors
-from typing import List, Optional, Tuple, Union

 def load_safetensors(in_path: str):
@@ -370,13 +367,12 @@ class AudioAdapter(nn.Module):
 class AudioAdapterPipe:
     def __init__(
-        self, audio_adapter: AudioAdapter, audio_encoder_repo: str = "microsoft/wavlm-base-plus", dtype=torch.float32, device="cuda", generator=None, tgt_fps: int = 15, weight: float = 1.0
+        self, audio_adapter: AudioAdapter, audio_encoder_repo: str = "microsoft/wavlm-base-plus", dtype=torch.float32, device="cuda", tgt_fps: int = 15, weight: float = 1.0, cpu_offload: bool = False
     ) -> None:
         self.audio_adapter = audio_adapter
         self.dtype = dtype
-        self.device = device
-        self.generator = generator
         self.audio_encoder_dtype = torch.float16
+        self.cpu_offload = cpu_offload
         ## audio encoder
         self.audio_encoder = AutoModel.from_pretrained(audio_encoder_repo)
@@ -403,11 +399,14 @@ class AudioAdapterPipe:
         audio_length = int(50 / self.tgt_fps * video_frame)
         with torch.no_grad():
-            audio_input_feat = audio_input_feat.to(self.device, self.audio_encoder_dtype)
             try:
-                audio_feat = self.audio_encoder(audio_input_feat, return_dict=True).last_hidden_state
+                if self.cpu_offload:
+                    self.audio_encoder = self.audio_encoder.to("cuda")
+                audio_feat = self.audio_encoder(audio_input_feat.to(self.audio_encoder_dtype), return_dict=True).last_hidden_state
+                if self.cpu_offload:
+                    self.audio_encoder = self.audio_encoder.to("cpu")
             except Exception as err:
-                audio_feat = torch.rand(1, audio_length, self.audio_feature_dim).to(self.device)
+                audio_feat = torch.rand(1, audio_length, self.audio_feature_dim).to("cuda")
                 print(err)
         audio_feat = audio_feat.to(self.dtype)
         if dropout_cond is not None:
...
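
The hunk above is the heart of the fix: rather than pinning the audio encoder to a stored `self.device`, the pipe moves it to the GPU just before the forward pass and back to the CPU right after, and the input tensor is cast to the encoder dtype at the call site. Below is a generic sketch of that shuttle pattern, with one deliberate difference: a `finally` block returns the weights to the CPU even when the forward raises, which the patch does only on success. `module` and `inputs` are placeholders:

```python
import torch


def offloaded_forward(module: torch.nn.Module, inputs: torch.Tensor, cpu_offload: bool = True) -> torch.Tensor:
    """Run one forward pass while keeping the module on CPU between calls."""
    if cpu_offload:
        module.to("cuda")  # upload weights just-in-time
    try:
        p = next(module.parameters())
        with torch.no_grad():
            # Match the parameters' device and dtype, analogous to the patch's
            # audio_input_feat.to(self.audio_encoder_dtype).
            out = module(inputs.to(device=p.device, dtype=p.dtype))
    finally:
        if cpu_offload:
            module.to("cpu")  # release GPU memory even if the forward failed
    return out
```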
@@ -2,32 +2,27 @@ import os
 import gc
 import numpy as np
 import torch
-import torchvision.transforms.functional as TF
+import subprocess
+import torchaudio as ta
 from PIL import Image
 from contextlib import contextmanager
-from typing import Optional, Tuple, Union, List, Dict, Any
+from typing import Optional, Tuple, List, Dict, Any
 from dataclasses import dataclass
+from loguru import logger
+from einops import rearrange
+from transformers import AutoFeatureExtractor
+from torchvision.transforms import InterpolationMode
+from torchvision.transforms.functional import resize
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
-from lightx2v.models.runners.wan.wan_runner import WanRunner
+from lightx2v.models.runners.wan.wan_runner import WanRunner, MultiModelStruct
 from lightx2v.utils.profiler import ProfilingContext4Debug, ProfilingContext
 from lightx2v.models.networks.wan.audio_model import WanAudioModel, Wan22MoeAudioModel
 from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
 from lightx2v.models.networks.wan.audio_adapter import AudioAdapter, AudioAdapterPipe, rank0_load_state_dict_from_path
 from lightx2v.utils.utils import save_to_video, vae_to_comfyui_image
 from lightx2v.models.schedulers.wan.audio.scheduler import ConsistencyModelScheduler
-from .wan_runner import MultiModelStruct
-from loguru import logger
-from einops import rearrange
-import torchaudio as ta
-from transformers import AutoFeatureExtractor
-from torchvision.transforms import InterpolationMode
-from torchvision.transforms.functional import resize
-import subprocess
-import warnings

 @contextmanager
@@ -424,9 +419,13 @@ class WanAudioRunner(WanRunner):  # type:ignore
         audio_adapter = rank0_load_state_dict_from_path(audio_adapter, audio_adapter_path, strict=False)

         # Audio encoder
-        device = torch.device("cuda")
+        cpu_offload = self.config.get("cpu_offload", False)
+        if cpu_offload:
+            device = torch.device("cpu")
+        else:
+            device = torch.device("cuda")
         audio_encoder_repo = self.config["model_path"] + "/audio_encoder"
-        self._audio_adapter_pipe = AudioAdapterPipe(audio_adapter, audio_encoder_repo=audio_encoder_repo, dtype=torch.bfloat16, device=device, generator=torch.Generator(device), weight=1.0)
+        self._audio_adapter_pipe = AudioAdapterPipe(audio_adapter, audio_encoder_repo=audio_encoder_repo, dtype=torch.bfloat16, device=device, weight=1.0, cpu_offload=cpu_offload)
         return self._audio_adapter_pipe
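
The runner side mirrors the new constructor: `device` now only tells the pipe where the encoder weights rest between calls, and the `generator=torch.Generator(device)` argument is gone entirely, matching the removal of `self.device`/`self.generator` from `__init__`. A hedged usage sketch follows; `audio_adapter`, `model_path`, and `config` stand in for the runner's real state:

```python
import torch

cpu_offload = config.get("cpu_offload", False)           # mirrors the JSON config above
device = torch.device("cpu" if cpu_offload else "cuda")  # resting place for encoder weights

pipe = AudioAdapterPipe(
    audio_adapter,                                       # AudioAdapter with weights already loaded
    audio_encoder_repo=model_path + "/audio_encoder",
    dtype=torch.bfloat16,
    device=device,
    weight=1.0,
    cpu_offload=cpu_offload,                             # encoder hops CPU<->GPU per forward
)
```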
@@ -622,7 +621,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
         ref_img = Image.open(config.image_path)
         ref_img = (np.array(ref_img).astype(np.float32) - 127.5) / 127.5
-        ref_img = torch.from_numpy(ref_img).to(vae_model.device)
+        ref_img = torch.from_numpy(ref_img).cuda()
         ref_img = rearrange(ref_img, "H W C -> 1 C H W")
         ref_img = ref_img[:, :3]
...