Commit 1a881d63 authored by helloyongyang

Refactor the parallel module

parent 18e2b23a
import os
import torch
from lightx2v.attentions.common.radial_attn import MaskMap
from lightx2v.common.ops.attn.radial_attn import MaskMap
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
......
import torch
from lightx2v.models.networks.wan.infer.transformer_infer import WanTransformerInfer
import torch.distributed as dist
import torch.nn.functional as F
from lightx2v.models.networks.wan.infer.utils import compute_freqs_dist, compute_freqs_audio_dist
class WanTransformerDistInfer(WanTransformerInfer):
    def __init__(self, config):
        super().__init__(config)

    def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
        x = self.dist_pre_process(x)
        x = super().infer(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks)
        x = self.dist_post_process(x)
        return x

    def compute_freqs(self, q, grid_sizes, freqs):
        if "audio" in self.config.get("model_cls", ""):
            freqs_i = compute_freqs_audio_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs)
        else:
            freqs_i = compute_freqs_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs)
        return freqs_i

    def dist_pre_process(self, x):
        world_size = dist.get_world_size()
        cur_rank = dist.get_rank()
        padding_size = (world_size - (x.shape[0] % world_size)) % world_size
        if padding_size > 0:
            # Use F.pad to pad the end of dim 0; the pad tuple is ordered (last-dim padding, ..., first-dim padding)
            x = F.pad(x, (0, 0, 0, padding_size))
        x = torch.chunk(x, world_size, dim=0)[cur_rank]
        return x

    def dist_post_process(self, x):
        # Get the world size of the current process group
        world_size = dist.get_world_size()
        # Allocate one buffer per rank to hold the gathered outputs
        gathered_x = [torch.empty_like(x) for _ in range(world_size)]
        # Gather the output shard from every rank
        dist.all_gather(gathered_x, x)
        # Concatenate the per-rank shards back along dim 0
        combined_output = torch.cat(gathered_x, dim=0)
        return combined_output  # return the combined output
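As a sanity check on the sharding logic above, here is a self-contained, single-process sketch (hypothetical tensor shapes, no process group; a plain list stands in for dist.all_gather) of what dist_pre_process and dist_post_process do to dim 0:

import torch
import torch.nn.functional as F

world_size = 4
x = torch.randn(10, 8)  # hypothetical (seq_len, dim); seq_len not divisible by world_size

# dist_pre_process: pad dim 0 up to a multiple of world_size, then take this rank's chunk
padding_size = (world_size - (x.shape[0] % world_size)) % world_size
x_padded = F.pad(x, (0, 0, 0, padding_size))       # pads only the end of dim 0
shards = torch.chunk(x_padded, world_size, dim=0)  # rank i would keep shards[i]
assert all(s.shape[0] == x_padded.shape[0] // world_size for s in shards)

# dist_post_process: all_gather + cat reassembles the padded tensor on every rank
combined = torch.cat(shards, dim=0)                # stand-in for dist.all_gather followed by torch.cat
assert torch.equal(combined, x_padded)
# The padded rows survive the gather; the diff does not show where they are trimmed,
# so downstream code presumably drops or ignores them.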
......@@ -318,6 +318,13 @@ class WanTransformerInfer(BaseTransformerInfer):
return shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa
def compute_freqs(self, q, grid_sizes, freqs):
if "audio" in self.config.get("model_cls", ""):
freqs_i = compute_freqs_audio(q.size(2) // 2, grid_sizes, freqs)
else:
freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
return freqs_i
def infer_self_attn(self, weights, grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa):
if hasattr(weights, "smooth_norm1_weight"):
norm1_weight = (1 + scale_msa.squeeze(0)) * weights.smooth_norm1_weight.tensor
......@@ -342,16 +349,7 @@ class WanTransformerInfer(BaseTransformerInfer):
k = weights.self_attn_norm_k.apply(weights.self_attn_k.apply(norm1_out)).view(s, n, d)
v = weights.self_attn_v.apply(norm1_out).view(s, n, d)
if not self.parallel_attention:
if "audio" in self.config.get("model_cls", ""):
freqs_i = compute_freqs_audio(q.size(2) // 2, grid_sizes, freqs)
else:
freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
else:
if "audio" in self.config.get("model_cls", ""):
freqs_i = compute_freqs_audio_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs)
else:
freqs_i = compute_freqs_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs)
freqs_i = self.compute_freqs(q, grid_sizes, freqs)
freqs_i = self.zero_temporal_component_in_3DRoPE(seq_lens, freqs_i)
......@@ -365,7 +363,16 @@ class WanTransformerInfer(BaseTransformerInfer):
del freqs_i, norm1_out, norm1_weight, norm1_bias
torch.cuda.empty_cache()
if not self.parallel_attention:
if self.config.get("parallel_attn_type", None):
attn_out = weights.self_attn_1_parallel.apply(
q=q,
k=k,
v=v,
img_qkv_len=q.shape[0],
cu_seqlens_qkv=cu_seqlens_q,
attention_module=weights.self_attn_1,
)
else:
attn_out = weights.self_attn_1.apply(
q=q,
k=k,
......@@ -377,15 +384,6 @@ class WanTransformerInfer(BaseTransformerInfer):
model_cls=self.config["model_cls"],
mask_map=self.mask_map,
)
else:
attn_out = self.parallel_attention(
attention_type=self.attention_type,
q=q,
k=k,
v=v,
img_qkv_len=q.shape[0],
cu_seqlens_qkv=cu_seqlens_q,
)
y = weights.self_attn_o.apply(attn_out)
......
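The hunk above is essentially a template-method move: infer_self_attn now has a single self.compute_freqs call site, and the distributed subclass overrides only that hook instead of branching on self.parallel_attention inline. A minimal sketch of the pattern (placeholder bodies and shapes, not the repository's actual math):

import torch

class SerialInfer:
    def compute_freqs(self, q, grid_sizes, freqs):
        # placeholder for compute_freqs / compute_freqs_audio over the full sequence
        return freqs[: q.size(0)]

    def infer_self_attn(self, q, grid_sizes, freqs):
        freqs_i = self.compute_freqs(q, grid_sizes, freqs)  # single call site, no parallel branch
        return freqs_i

class DistInfer(SerialInfer):
    rank, world_size = 1, 4  # hypothetical; the real code queries torch.distributed

    def compute_freqs(self, q, grid_sizes, freqs):
        # placeholder for compute_freqs_dist / compute_freqs_audio_dist:
        # each rank keeps only its slice of the RoPE table
        per_rank = q.size(0)
        return freqs[self.rank * per_rank : (self.rank + 1) * per_rank]

q = torch.randn(3, 12, 64)   # hypothetical per-rank (seq_len, heads, head_dim)
freqs = torch.randn(12, 32)  # hypothetical RoPE table for the full sequence
print(DistInfer().infer_self_attn(q, None, freqs).shape)  # torch.Size([3, 32])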
......@@ -2,7 +2,7 @@ import os
import torch
import glob
import json
from lightx2v.attentions.common.radial_attn import MaskMap
from lightx2v.common.ops.attn import MaskMap
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
from lightx2v.models.networks.wan.weights.transformer_weights import (
......@@ -22,9 +22,8 @@ from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import
WanTransformerInferDualBlock,
WanTransformerInferDynamicBlock,
)
from lightx2v.models.networks.wan.infer.dist_infer.transformer_infer import WanTransformerDistInfer
from safetensors import safe_open
import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap
import lightx2v.attentions.distributed.ring.wrap as ring_dist_wrap
from lightx2v.utils.envs import *
from lightx2v.utils.utils import *
from loguru import logger
......@@ -58,35 +57,30 @@ class WanModel:
self._init_weights()
self._init_infer()
if config["parallel_attn_type"]:
if config["parallel_attn_type"] == "ulysses":
ulysses_dist_wrap.parallelize_wan(self)
elif config["parallel_attn_type"] == "ring":
ring_dist_wrap.parallelize_wan(self)
else:
raise Exception(f"Unsuppotred parallel_attn_type")
def _init_infer_class(self):
self.pre_infer_class = WanPreInfer
self.post_infer_class = WanPostInfer
if self.config["feature_caching"] == "NoCaching":
self.transformer_infer_class = WanTransformerInfer
elif self.config["feature_caching"] == "Tea":
self.transformer_infer_class = WanTransformerInferTeaCaching
elif self.config["feature_caching"] == "TaylorSeer":
self.transformer_infer_class = WanTransformerInferTaylorCaching
elif self.config["feature_caching"] == "Ada":
self.transformer_infer_class = WanTransformerInferAdaCaching
elif self.config["feature_caching"] == "Custom":
self.transformer_infer_class = WanTransformerInferCustomCaching
elif self.config["feature_caching"] == "FirstBlock":
self.transformer_infer_class = WanTransformerInferFirstBlock
elif self.config["feature_caching"] == "DualBlock":
self.transformer_infer_class = WanTransformerInferDualBlock
elif self.config["feature_caching"] == "DynamicBlock":
self.transformer_infer_class = WanTransformerInferDynamicBlock
if self.config.get("parallel_attn_type", None):
self.transformer_infer_class = WanTransformerDistInfer
else:
raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")
if self.config["feature_caching"] == "NoCaching":
self.transformer_infer_class = WanTransformerInfer
elif self.config["feature_caching"] == "Tea":
self.transformer_infer_class = WanTransformerInferTeaCaching
elif self.config["feature_caching"] == "TaylorSeer":
self.transformer_infer_class = WanTransformerInferTaylorCaching
elif self.config["feature_caching"] == "Ada":
self.transformer_infer_class = WanTransformerInferAdaCaching
elif self.config["feature_caching"] == "Custom":
self.transformer_infer_class = WanTransformerInferCustomCaching
elif self.config["feature_caching"] == "FirstBlock":
self.transformer_infer_class = WanTransformerInferFirstBlock
elif self.config["feature_caching"] == "DualBlock":
self.transformer_infer_class = WanTransformerInferDualBlock
elif self.config["feature_caching"] == "DynamicBlock":
self.transformer_infer_class = WanTransformerInferDynamicBlock
else:
raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")
def _load_safetensor_to_dict(self, file_path, use_bf16, skip_bf16):
with safe_open(file_path, framework="pt") as f:
......
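The new _init_infer_class gives parallel_attn_type precedence over feature_caching, replacing the earlier ulysses/ring wrapping of the whole model with the WanTransformerDistInfer subclass. A small, hypothetical restatement of that selection logic (class names as strings, config values invented for illustration):

CACHING_TABLE = {
    "NoCaching": "WanTransformerInfer",
    "Tea": "WanTransformerInferTeaCaching",
    "TaylorSeer": "WanTransformerInferTaylorCaching",
    # ... remaining feature_caching entries as listed in model.py
}

def pick_infer_class(config):
    # Mirrors the diff: any parallel_attn_type selects the distributed infer class,
    # regardless of feature_caching; otherwise fall back to the caching table.
    if config.get("parallel_attn_type", None):
        return "WanTransformerDistInfer"
    if config["feature_caching"] not in CACHING_TABLE:
        raise NotImplementedError(f"Unsupported feature_caching type: {config['feature_caching']}")
    return CACHING_TABLE[config["feature_caching"]]

assert pick_infer_class({"parallel_attn_type": "ulysses", "feature_caching": "Tea"}) == "WanTransformerDistInfer"
assert pick_infer_class({"parallel_attn_type": None, "feature_caching": "Tea"}) == "WanTransformerInferTeaCaching"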
......@@ -190,6 +190,10 @@ class WanSelfAttention(WeightModule):
self.self_attn_1.load(sparge_ckpt)
else:
self.add_module("self_attn_1", ATTN_WEIGHT_REGISTER[self.config["self_attn_1_type"]]())
if self.config.get("parallel_attn_type", None):
self.add_module("self_attn_1_parallel", ATTN_WEIGHT_REGISTER[self.config["parallel_attn_type"]]())
if self.quant_method in ["advanced_ptq"]:
self.add_module(
"smooth_norm1_weight",
......
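WanSelfAttention now instantiates an extra weight module keyed by parallel_attn_type from the same ATTN_WEIGHT_REGISTER used for self_attn_1. A stand-in sketch of that registry pattern (the decorator, class names, and keys below are illustrative, not the project's actual implementation):

ATTN_WEIGHT_REGISTER = {}

def register_attn_weight(name):
    def deco(cls):
        ATTN_WEIGHT_REGISTER[name] = cls
        return cls
    return deco

@register_attn_weight("flash_attn2")  # hypothetical key for the plain local attention
class FlashAttn2Weight:
    def apply(self, q, k, v, **kwargs):
        ...  # local attention kernel

@register_attn_weight("ulysses")  # hypothetical key for the parallel wrapper
class UlyssesAttnWeight:
    def apply(self, q, k, v, img_qkv_len=None, cu_seqlens_qkv=None, attention_module=None, **kwargs):
        # redistribute q/k/v across ranks, delegate the local compute to
        # attention_module (the plain self_attn_1 weight), then gather back
        ...

# Mirrors the diff: self_attn_1 is always built, self_attn_1_parallel only when
# parallel_attn_type is set, and the parallel call receives self_attn_1 as attention_module.
self_attn_1 = ATTN_WEIGHT_REGISTER["flash_attn2"]()
self_attn_1_parallel = ATTN_WEIGHT_REGISTER["ulysses"]()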
......@@ -235,8 +235,10 @@ class DefaultRunner(BaseRunner):
fps = self.config["video_frame_interpolation"]["target_fps"]
else:
fps = self.config.get("fps", 16)
logger.info(f"Saving video to {self.config.save_video_path}")
save_to_video(images, self.config.save_video_path, fps=fps, method="ffmpeg") # type: ignore
if not self.config.get("parallel_attn_type", None) or dist.get_rank() == 0:
logger.info(f"Saving video to {self.config.save_video_path}")
save_to_video(images, self.config.save_video_path, fps=fps, method="ffmpeg") # type: ignore
del latents, generator
torch.cuda.empty_cache()
......
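The runner change is the usual rank-0 guard for side effects under torch.distributed: after dist_post_process every rank holds the full gathered output, so only one rank should write the video file. A minimal restatement (config keys mirror the diff; the helper name is invented):

import torch.distributed as dist

def should_write_outputs(config):
    # Single-process run (no parallel_attn_type): always write.
    if not config.get("parallel_attn_type", None):
        return True
    # Distributed run: only rank 0 writes, so N ranks don't race on the same file.
    return dist.get_rank() == 0

# if should_write_outputs(self.config):
#     logger.info(f"Saving video to {self.config.save_video_path}")
#     save_to_video(images, self.config.save_video_path, fps=fps, method="ffmpeg")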