Commit c7eb4631 authored by wangshankun's avatar wangshankun

Bug Fix: Fix incomplete parallel loading of audio model

parent dd958c79
{
  "infer_steps": 4,
  "target_fps": 16,
- "video_duration": 16,
+ "video_duration": 12,
  "audio_sr": 16000,
  "target_video_length": 81,
  "target_height": 720,
@@ -14,5 +14,6 @@
  "sample_shift": 5,
  "enable_cfg": false,
  "cpu_offload": false,
+ "adaptive_resize": true,
  "use_31_block": false
}
{
"infer_steps": 4,
"target_fps": 16,
"video_duration": 12,
"audio_sr": 16000,
"target_video_length": 81,
"target_height": 720,
"target_width": 1280,
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"seed": 42,
"sample_guide_scale": 1.0,
"sample_shift": 5,
"enable_cfg": false,
"cpu_offload": false,
"use_31_block": false,
"adaptive_resize": true,
"parallel": {
"seq_p_size": 4,
"seq_p_attn_type": "ulysses"
}
}
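For context on the loading fix below: the new "parallel" block in this config is where the sequence-parallel (Ulysses) group size comes from, and it must match the process count launched by torchrun in the script at the end of this commit (4 processes, "seq_p_size": 4). A minimal sketch of how such a group can be derived from this config — hypothetical helper code, not lightx2v's actual initialization:

import torch.distributed as dist

def init_seq_p_group(seq_p_size: int):
    # Partition the world into contiguous groups of seq_p_size ranks.
    # new_group() is collective: every rank must call it for every group.
    world_size = dist.get_world_size()
    assert world_size % seq_p_size == 0, "world size must be divisible by seq_p_size"
    rank = dist.get_rank()
    seq_p_group = None
    for start in range(0, world_size, seq_p_size):
        ranks = list(range(start, start + seq_p_size))
        group = dist.new_group(ranks=ranks)
        if rank in ranks:
            seq_p_group = group
    return seq_p_group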
@@ -7,14 +7,12 @@ import os
import safetensors
import torch
+import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from einops import rearrange
from transformers import AutoModel
-import torch.distributed as dist
from loguru import logger
from lightx2v.utils.envs import *
@@ -54,15 +52,40 @@ def load_pt_safetensors(in_path: str):
return state_dict
-def rank0_load_state_dict_from_path(model, in_path: str, strict: bool = True):
-    import torch.distributed as dist
-
+def rank0_load_state_dict_from_path(model, in_path: str, strict: bool = True, seq_p_group=None):
     model = model.to("cuda")
-    if (dist.is_initialized() and dist.get_rank() == 0) or (not dist.is_initialized()):
+    # Determine whether the current process is the leader (responsible for loading the weights)
+    is_leader = False
+    if seq_p_group is not None and dist.is_initialized():
+        group_rank = dist.get_rank(group=seq_p_group)
+        if group_rank == 0:
+            is_leader = True
+    elif not dist.is_initialized() or dist.get_rank() == 0:
+        is_leader = True
+
+    if is_leader:
         state_dict = load_pt_safetensors(in_path)
         model.load_state_dict(state_dict, strict=strict)
-    if dist.is_initialized():
+
+    # Sync the model state from the leader to all other processes in the group
+    if seq_p_group is not None and dist.is_initialized():
+        dist.barrier(group=seq_p_group)
+        src_global_rank = dist.get_process_group_ranks(seq_p_group)[0]
+        for param in model.parameters():
+            dist.broadcast(param.data, src=src_global_rank, group=seq_p_group)
+        for buffer in model.buffers():
+            dist.broadcast(buffer.data, src=src_global_rank, group=seq_p_group)
+    elif dist.is_initialized():
         dist.barrier()
         for param in model.parameters():
             dist.broadcast(param.data, src=0)
         for buffer in model.buffers():
             dist.broadcast(buffer.data, src=0)
-    return model.to(dtype=GET_DTYPE())
+    return model.to(dtype=GET_DTYPE(), device="cuda")
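The pattern above is "leader loads, group broadcasts": only one rank per sequence-parallel group reads the checkpoint from disk, then pushes every tensor to its peers. The subtlety the fix gets right is that dist.broadcast(src=...) expects a global rank, so src=0 is only correct when the leader happens to be global rank 0; each group must resolve its own leader via dist.get_process_group_ranks(...). A self-contained toy version of the same pattern (hypothetical module and checkpoint path, for illustration only):

import torch
import torch.distributed as dist
import torch.nn as nn

def demo_group_load(seq_p_group):
    model = nn.Linear(8, 8).cuda()
    if dist.get_rank(group=seq_p_group) == 0:
        # Only the group leader touches the disk.
        state_dict = torch.load("adapter.pt", map_location="cuda")  # hypothetical checkpoint
        model.load_state_dict(state_dict)
    dist.barrier(group=seq_p_group)
    src = dist.get_process_group_ranks(seq_p_group)[0]  # leader's *global* rank
    for p in model.parameters():
        dist.broadcast(p.data, src=src, group=seq_p_group)
    for b in model.buffers():
        dist.broadcast(b.data, src=src, group=seq_p_group)
    return model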
def linear_interpolation(features, output_len: int):
......
@@ -13,8 +13,6 @@ from lightx2v.models.networks.wan.weights.transformer_weights import (
    WanTransformerWeights,
)
-from loguru import logger
-

class WanAudioModel(WanModel):
    pre_weight_class = WanPreWeights
......
@@ -11,6 +11,7 @@ from lightx2v.utils.envs import *
from .utils import apply_rotary_emb, apply_rotary_emb_chunk, compute_freqs, compute_freqs_audio
class WanTransformerInfer(BaseTransformerInfer):
def __init__(self, config):
self.config = config
......
@@ -27,9 +27,9 @@ def compute_freqs_audio(c, grid_sizes, freqs):
        seq_len = f * h * w
        freqs_i = torch.cat(
            [
-                freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
-                freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
-                freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
+                freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),  # temporal (frame) encoding
+                freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),  # spatial (height) encoding
+                freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),  # spatial (width) encoding
            ],
            dim=-1,
        ).reshape(seq_len, 1, -1)
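The newly commented lines make the RoPE layout explicit: each token's rotary table is the concatenation of a per-frame, per-row, and per-column frequency bank, so position is factorized across time, height, and width. A shape-only toy illustration (random tensors standing in for the real frequency tables):

import torch

f, h, w, d = 2, 3, 4, 6  # frames, height, width, per-axis feature dim (illustrative values)
freqs = [torch.randn(f, d), torch.randn(h, d), torch.randn(w, d)]
freqs_i = torch.cat(
    [
        freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),  # varies along time only
        freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),  # varies along height only
        freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),  # varies along width only
    ],
    dim=-1,
).reshape(f * h * w, 1, -1)
print(freqs_i.shape)  # torch.Size([24, 1, 18]): one fused per-token rotary table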
......
@@ -417,7 +417,6 @@ class WanAudioRunner(WanRunner):  # type:ignore
            time_freq_dim=256,
            projection_transformer_layers=4,
        )
-        audio_adapter = rank0_load_state_dict_from_path(audio_adapter, audio_adapter_path, strict=False)

        # Audio encoder
        cpu_offload = self.config.get("cpu_offload", False)
@@ -432,6 +431,8 @@
        else:
            seq_p_group = None

+        audio_adapter = rank0_load_state_dict_from_path(audio_adapter, audio_adapter_path, strict=False, seq_p_group=seq_p_group)
+
        self._audio_adapter_pipe = AudioAdapterPipe(
            audio_adapter, audio_encoder_repo=audio_encoder_repo, dtype=GET_DTYPE(), device=device, weight=1.0, cpu_offload=cpu_offload, seq_p_group=seq_p_group
        )
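This hunk is the heart of the fix: the adapter load is moved to after seq_p_group is resolved and the group is passed in, so weight distribution goes through each sequence-parallel group's own leader instead of a single global-rank-0 broadcast over the default group. Schematically (paraphrased, not the verbatim runner code):

# Order matters: the group must exist before the load so the leader
# can broadcast to its peers within that group.
seq_p_group = init_seq_p_group(config["parallel"]["seq_p_size"])  # hypothetical helper from the sketch above
audio_adapter = rank0_load_state_dict_from_path(
    audio_adapter, audio_adapter_path, strict=False, seq_p_group=seq_p_group
)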
......
#!/bin/bash
# set these paths first
lightx2v_path=
model_path=
# set environment variables
source ${lightx2v_path}/scripts/base/base.sh
export CUDA_VISIBLE_DEVICES=0,1,2,3
export TORCH_CUDA_ARCH_LIST="9.0"
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export ENABLE_GRAPH_MODE=false
# for debugging
#export TORCH_NCCL_BLOCKING_WAIT=1        # enable NCCL blocking-wait mode (otherwise the watchdog kills a stalled process)
#export NCCL_BLOCKING_WAIT_TIMEOUT=1800   # watchdog wait timeout, in seconds
torchrun --nproc-per-node 4 -m lightx2v.infer \
--model_cls wan2.1_audio \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/audio_driven/wan_i2v_audio_dist.json \
--prompt "The video features an old lady saying something and knitting a sweater." \
--negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
--image_path ${lightx2v_path}/assets/inputs/audio/15.png \
--audio_path ${lightx2v_path}/assets/inputs/audio/15.wav \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v_audio.mp4