Commit c7eb4631 authored by wangshankun

Bug Fix: Fix incomplete parallel loading of audio model

parent dd958c79
@@ -1,7 +1,7 @@
 {
     "infer_steps": 4,
     "target_fps": 16,
-    "video_duration": 16,
+    "video_duration": 12,
     "audio_sr": 16000,
     "target_video_length": 81,
     "target_height": 720,
@@ -14,5 +14,6 @@
     "sample_shift": 5,
     "enable_cfg": false,
     "cpu_offload": false,
+    "adaptive_resize": true,
     "use_31_block": false
 }
configs/audio_driven/wan_i2v_audio_dist.json (new file):

{
"infer_steps": 4,
"target_fps": 16,
"video_duration": 12,
"audio_sr": 16000,
"target_video_length": 81,
"target_height": 720,
"target_width": 1280,
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"seed": 42,
"sample_guide_scale": 1.0,
"sample_shift": 5,
"enable_cfg": false,
"cpu_offload": false,
"use_31_block": false,
"adaptive_resize": true,
"parallel": {
"seq_p_size": 4,
"seq_p_attn_type": "ulysses"
}
}
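
For reference, the new "parallel" block is what drives the fix: with "seq_p_size": 4, the four ranks launched by the script at the bottom form a single sequence-parallel ("ulysses") group, and that group is what later reaches the weight loader as seq_p_group. Below is a minimal sketch of how such a group could be built with torch.distributed, assuming a torchrun launch and a world size divisible by seq_p_size; init_seq_p_group is a hypothetical helper for illustration, not lightx2v's actual setup code:

import torch.distributed as dist

def init_seq_p_group(seq_p_size: int = 4):
    # Partition the world into consecutive groups of seq_p_size ranks.
    dist.init_process_group(backend="nccl")
    rank, world = dist.get_rank(), dist.get_world_size()
    assert world % seq_p_size == 0, "world size must be divisible by seq_p_size"
    seq_p_group = None
    for start in range(0, world, seq_p_size):
        ranks = list(range(start, start + seq_p_size))
        # new_group() is collective: every rank must call it for every group,
        # including the groups it does not belong to.
        group = dist.new_group(ranks)
        if rank in ranks:
            seq_p_group = group
    return seq_p_group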
@@ -7,14 +7,12 @@ import os
 import safetensors
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
 from diffusers.models.embeddings import TimestepEmbedding, Timesteps
 from einops import rearrange
 from transformers import AutoModel
-import torch.distributed as dist
-from loguru import logger

 from lightx2v.utils.envs import *
@@ -54,15 +52,40 @@ def load_pt_safetensors(in_path: str):
     return state_dict


-def rank0_load_state_dict_from_path(model, in_path: str, strict: bool = True):
-    import torch.distributed as dist
+def rank0_load_state_dict_from_path(model, in_path: str, strict: bool = True, seq_p_group=None):
+    model = model.to("cuda")
+
+    # Determine whether the current process is the leader (responsible for loading the weights)
+    is_leader = False
+    if seq_p_group is not None and dist.is_initialized():
+        group_rank = dist.get_rank(group=seq_p_group)
+        if group_rank == 0:
+            is_leader = True
+    elif not dist.is_initialized() or dist.get_rank() == 0:
+        is_leader = True

-    if (dist.is_initialized() and dist.get_rank() == 0) or (not dist.is_initialized()):
+    if is_leader:
         state_dict = load_pt_safetensors(in_path)
         model.load_state_dict(state_dict, strict=strict)

-    if dist.is_initialized():
+    # Broadcast the model state from the leader to every other process in the group
+    if seq_p_group is not None and dist.is_initialized():
+        dist.barrier(group=seq_p_group)
+        src_global_rank = dist.get_process_group_ranks(seq_p_group)[0]
+        for param in model.parameters():
+            dist.broadcast(param.data, src=src_global_rank, group=seq_p_group)
+        for buffer in model.buffers():
+            dist.broadcast(buffer.data, src=src_global_rank, group=seq_p_group)
+    elif dist.is_initialized():
         dist.barrier()
+        for param in model.parameters():
+            dist.broadcast(param.data, src=0)
+        for buffer in model.buffers():
+            dist.broadcast(buffer.data, src=0)

-    return model.to(dtype=GET_DTYPE(), device="cuda")
+    return model.to(dtype=GET_DTYPE())


 def linear_interpolation(features, output_len: int):
...
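
The hunk above is the heart of the commit: before it, only the global rank 0 read the checkpoint and, when a sequence-parallel group was active, no broadcast followed, so the remaining ranks kept their randomly initialized adapter weights. The sketch below reproduces the same leader-load-then-broadcast pattern on a toy module; it is a self-contained illustration (toy nn.Linear, one group spanning all ranks), not repo code, and would be launched with torchrun --nproc-per-node 4:

import torch
import torch.distributed as dist
import torch.nn as nn

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())
group = dist.new_group(list(range(dist.get_world_size())))  # one group over all ranks

model = nn.Linear(8, 8).cuda()  # stand-in for the audio adapter
if dist.get_rank(group=group) == 0:
    # Leader: pretend we loaded a checkpoint by filling in known values.
    with torch.no_grad():
        model.weight.fill_(1.0)
        model.bias.fill_(2.0)

dist.barrier(group=group)
src = dist.get_process_group_ranks(group)[0]  # leader's global rank
for p in model.parameters():
    dist.broadcast(p.data, src=src, group=group)

# Without the broadcast, this assertion would fail on every rank but the leader.
assert torch.all(model.weight == 1.0)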
@@ -13,8 +13,6 @@ from lightx2v.models.networks.wan.weights.transformer_weights import (
     WanTransformerWeights,
 )

-from loguru import logger
-

 class WanAudioModel(WanModel):
     pre_weight_class = WanPreWeights
...
@@ -11,6 +11,7 @@ from lightx2v.utils.envs import *
 from .utils import apply_rotary_emb, apply_rotary_emb_chunk, compute_freqs, compute_freqs_audio


 class WanTransformerInfer(BaseTransformerInfer):
     def __init__(self, config):
         self.config = config
...
@@ -27,9 +27,9 @@ def compute_freqs_audio(c, grid_sizes, freqs):
         seq_len = f * h * w
         freqs_i = torch.cat(
             [
-                freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
-                freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
-                freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
+                freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),  # temporal (frame) encoding
+                freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),  # spatial (height) encoding
+                freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),  # spatial (width) encoding
             ],
             dim=-1,
         ).reshape(seq_len, 1, -1)
...
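
To make the new comments concrete, the toy snippet below reproduces the same three-axis rotary-frequency layout with invented shapes (f frames, an h-by-w spatial grid, d frequency channels per axis); the values are random and only the shapes matter:

import torch

f, h, w, d = 2, 3, 4, 8  # toy sizes: frames, height, width, per-axis freq channels
freqs = [torch.randn(f, d), torch.randn(h, d), torch.randn(w, d)]
freqs_i = torch.cat(
    [
        freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),  # temporal (frame) encoding
        freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),  # spatial (height) encoding
        freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),  # spatial (width) encoding
    ],
    dim=-1,
).reshape(f * h * w, 1, -1)
assert freqs_i.shape == (f * h * w, 1, 3 * d)  # one fused rotary table per token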
@@ -417,7 +417,6 @@ class WanAudioRunner(WanRunner):  # type:ignore
             time_freq_dim=256,
             projection_transformer_layers=4,
         )
-        audio_adapter = rank0_load_state_dict_from_path(audio_adapter, audio_adapter_path, strict=False)

         # Audio encoder
         cpu_offload = self.config.get("cpu_offload", False)
@@ -432,6 +431,8 @@ class WanAudioRunner(WanRunner):  # type:ignore
         else:
             seq_p_group = None

+        audio_adapter = rank0_load_state_dict_from_path(audio_adapter, audio_adapter_path, strict=False, seq_p_group=seq_p_group)
+
         self._audio_adapter_pipe = AudioAdapterPipe(
             audio_adapter, audio_encoder_repo=audio_encoder_repo, dtype=GET_DTYPE(), device=device, weight=1.0, cpu_offload=cpu_offload, seq_p_group=seq_p_group
         )
...
#!/bin/bash
# set these paths first
lightx2v_path=
model_path=
# set environment variables
source ${lightx2v_path}/scripts/base/base.sh
export CUDA_VISIBLE_DEVICES=0,1,2,3
export TORCH_CUDA_ARCH_LIST="9.0"
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export ENABLE_GRAPH_MODE=false
# for debugging
#export TORCH_NCCL_BLOCKING_WAIT=1        # enable NCCL blocking-wait mode (otherwise the watchdog kills stalled processes)
#export NCCL_BLOCKING_WAIT_TIMEOUT=1800   # set the watchdog wait timeout
torchrun --nproc-per-node 4 -m lightx2v.infer \
--model_cls wan2.1_audio \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/audio_driven/wan_i2v_audio_dist.json \
--prompt "The video features an old lady saying something while knitting a sweater." \
--negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
--image_path ${lightx2v_path}/assets/inputs/audio/15.png \
--audio_path ${lightx2v_path}/assets/inputs/audio/15.wav \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v_audio.mp4
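
When debugging a multi-rank run like this one (for instance with the NCCL blocking-wait variables above uncommented), one cheap way to confirm the fix took effect is to fingerprint the adapter weights on every rank after loading and compare them; assert_weights_in_sync below is a hypothetical debugging aid, not part of the repo:

import torch
import torch.distributed as dist

def assert_weights_in_sync(model, group=None):
    # Sum of all parameters as a cheap per-rank fingerprint.
    checksum = torch.stack([p.detach().double().sum() for p in model.parameters()]).sum()
    gathered = [torch.zeros_like(checksum) for _ in range(dist.get_world_size(group=group))]
    dist.all_gather(gathered, checksum, group=group)
    if not all(torch.allclose(g, gathered[0]) for g in gathered):
        raise RuntimeError(f"rank {dist.get_rank()}: adapter weights diverge across ranks")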