Commit 60f0b6c1 authored by helloyongyang

Support cfg parallel & hybrid parallel (cfg + seq)

parent a395cc0a
{
"infer_steps": 40,
"target_video_length": 81,
"target_height": 480,
"target_width": 832,
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"seed": 42,
"sample_guide_scale": 5,
"sample_shift": 5,
"enable_cfg": true,
"cpu_offload": false,
"parallel": {
"seq_p_size": 4,
"seq_p_attn_type": "ulysses",
"cfg_p_size": 2
}
}
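With "seq_p_size": 4 and "cfg_p_size": 2, this hybrid config expects cfg_p_size * seq_p_size = 8 GPUs. A minimal sketch of how the new init_runner code further down in this commit turns the block into a 2D device mesh; the print is only for illustration:

import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

# assumes a launch such as `torchrun --nproc_per_node=8 ...`
dist.init_process_group(backend="nccl")
cfg_p_size, seq_p_size = 2, 4
assert cfg_p_size * seq_p_size == dist.get_world_size()

mesh = init_device_mesh("cuda", (cfg_p_size, seq_p_size), mesh_dim_names=("cfg_p", "seq_p"))
cfg_group = mesh.get_group(mesh_dim="cfg_p")  # size 2: conditional vs. unconditional branch
seq_group = mesh.get_group(mesh_dim="seq_p")  # size 4: Ulysses sequence shards
print(dist.get_rank(), dist.get_rank(cfg_group), dist.get_rank(seq_group))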
......@@ -11,6 +11,8 @@
"sample_shift": 5,
"enable_cfg": true,
"cpu_offload": false,
"parallel_attn_type": "ring",
"parallel_vae": true
"parallel": {
"seq_p_size": 4,
"seq_p_attn_type": "ring"
}
}
......@@ -11,6 +11,8 @@
"sample_shift": 5,
"enable_cfg": true,
"cpu_offload": false,
"parallel_attn_type": "ulysses",
"parallel_vae": true
"parallel": {
"seq_p_size": 4,
"seq_p_attn_type": "ulysses"
}
}
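These two config diffs drop the old top-level "parallel_attn_type" and "parallel_vae" keys and move sequence parallelism into the nested "parallel" block, where "seq_p_attn_type" selects "ring" or "ulysses" and "seq_p_size" sets the group size.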
{
"infer_steps": 50,
"target_video_length": 81,
"text_len": 512,
"target_height": 480,
"target_width": 832,
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"seed": 42,
"sample_guide_scale": 6,
"sample_shift": 8,
"enable_cfg": true,
"cpu_offload": false,
"parallel": {
"cfg_p_size": 2
}
}
{
"infer_steps": 50,
"target_video_length": 81,
"text_len": 512,
"target_height": 480,
"target_width": 832,
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"seed": 42,
"sample_guide_scale": 6,
"sample_shift": 8,
"enable_cfg": true,
"cpu_offload": false,
"parallel": {
"seq_p_size": 4,
"seq_p_attn_type": "ring",
"cfg_p_size": 2
}
}
{
"infer_steps": 50,
"target_video_length": 81,
"text_len": 512,
"target_height": 480,
"target_width": 832,
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"seed": 42,
"sample_guide_scale": 6,
"sample_shift": 8,
"enable_cfg": true,
"cpu_offload": false,
"parallel": {
"seq_p_size": 4,
"seq_p_attn_type": "ulysses",
"cfg_p_size": 2
}
}
......@@ -12,6 +12,8 @@
"sample_shift": 8,
"enable_cfg": true,
"cpu_offload": false,
"parallel_attn_type": "ring",
"parallel_vae": true
"parallel": {
"seq_p_size": 4,
"seq_p_attn_type": "ring"
}
}
......@@ -12,6 +12,8 @@
"sample_shift": 8,
"enable_cfg": true,
"cpu_offload": false,
"parallel_attn_type": "ulysses",
"parallel_vae": true
"parallel": {
"seq_p_size": 4,
"seq_p_attn_type": "ulysses"
}
}
......@@ -38,7 +38,7 @@ class RingAttnWeight(AttnWeightTemplate):
def __init__(self):
self.config = {}
def apply(self, q, k, v, img_qkv_len, cu_seqlens_qkv, attention_module=None):
def apply(self, q, k, v, img_qkv_len, cu_seqlens_qkv, attention_module=None, seq_p_group=None):
"""
Run ring attention over the combined image and text queries, keys, and values.
......@@ -54,8 +54,8 @@ class RingAttnWeight(AttnWeightTemplate):
torch.Tensor: the computed attention output
"""
# Get the rank of the current process and the group's world size
cur_rank = dist.get_rank()
world_size = dist.get_world_size()
cur_rank = dist.get_rank(seq_p_group)
world_size = dist.get_world_size(seq_p_group)
if len(cu_seqlens_qkv) == 3:
txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len # length of the text queries, keys, and values
......@@ -67,7 +67,7 @@ class RingAttnWeight(AttnWeightTemplate):
# if RING_COMM is None:
# init_ring_comm()
RING_COMM = RingComm()
RING_COMM = RingComm(seq_p_group)
# if len(cu_seqlens_qkv) == 3:
# txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len # length of the text queries, keys, and values
......
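For reference, ring attention keeps each rank's Q shard in place and rotates K/V shards around the sequence-parallel group. A rough, hypothetical sketch of one rotation step such a RingComm-style helper performs on the new seq_p_group (this is not the repo's actual RingComm API):

import torch
import torch.distributed as dist

def ring_rotate(block: torch.Tensor, seq_p_group) -> torch.Tensor:
    # Send this rank's K/V block to the next rank in the group and receive
    # the previous rank's block; group-local neighbours are mapped to global
    # ranks because the P2P ops address ranks of the default process group.
    rank = dist.get_rank(seq_p_group)
    size = dist.get_world_size(seq_p_group)
    send_to = dist.get_global_rank(seq_p_group, (rank + 1) % size)
    recv_from = dist.get_global_rank(seq_p_group, (rank - 1) % size)
    recv_buf = torch.empty_like(block)
    reqs = dist.batch_isend_irecv([
        dist.P2POp(dist.isend, block.contiguous(), send_to),
        dist.P2POp(dist.irecv, recv_buf, recv_from),
    ])
    for req in reqs:
        req.wait()
    return recv_buf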
......@@ -10,7 +10,7 @@ class UlyssesAttnWeight(AttnWeightTemplate):
def __init__(self):
self.config = {}
def apply(self, q, k, v, img_qkv_len, cu_seqlens_qkv, attention_module=None):
def apply(self, q, k, v, img_qkv_len, cu_seqlens_qkv, attention_module=None, seq_p_group=None):
"""
Run Ulysses attention over the combined image and text queries, keys, and values.
......@@ -26,8 +26,8 @@ class UlyssesAttnWeight(AttnWeightTemplate):
torch.Tensor: the computed attention output
"""
# Get the rank of the current process and the group's world size
cur_rank = dist.get_rank()
world_size = dist.get_world_size()
world_size = dist.get_world_size(seq_p_group)
cur_rank = dist.get_rank(seq_p_group)
# Get the sequence length and the text-related lengths
seq_len = q.shape[0]
......@@ -48,9 +48,9 @@ class UlyssesAttnWeight(AttnWeightTemplate):
txt_q, txt_k, txt_v = q[img_qkv_len:, :, :].contiguous(), k[img_qkv_len:, :, :].contiguous(), v[img_qkv_len:, :, :].contiguous()
# Convert the image queries, keys, and values to the head-sharded layout
img_q = all2all_seq2head(img_q)
img_k = all2all_seq2head(img_k)
img_v = all2all_seq2head(img_v)
img_q = all2all_seq2head(img_q, group=seq_p_group)
img_k = all2all_seq2head(img_k, group=seq_p_group)
img_v = all2all_seq2head(img_v, group=seq_p_group)
torch.cuda.synchronize() # make sure the CUDA ops have finished
# Process the text queries, keys, and values, selecting this process's heads
......@@ -82,11 +82,11 @@ class UlyssesAttnWeight(AttnWeightTemplate):
# Gather the text attention results from all processes
gathered_txt_attn = [torch.empty_like(txt_attn) for _ in range(world_size)]
dist.all_gather(gathered_txt_attn, txt_attn)
dist.all_gather(gathered_txt_attn, txt_attn, group=seq_p_group)
# Process the image attention results
img_attn = img_attn.reshape(world_size * shard_seqlen, shard_heads, hidden_dims) # reshape the image attention result
img_attn = all2all_head2seq(img_attn) # convert the head layout back to the sequence layout
img_attn = all2all_head2seq(img_attn, group=seq_p_group) # convert the head layout back to the sequence layout
img_attn = img_attn.reshape(shard_seqlen, -1) # reshape to [shard_seqlen, -1]
torch.cuda.synchronize() # make sure the CUDA ops have finished
......
......@@ -4,7 +4,7 @@ import torch.distributed as dist
@dynamo.disable
def all2all_seq2head(input):
def all2all_seq2head(input, group=None):
"""
Convert the input tensor from the [seq_len/N, heads, hidden_dims] layout to [seq_len, heads/N, hidden_dims].
......@@ -18,7 +18,7 @@ def all2all_seq2head(input):
assert input.dim() == 3, f"input must be 3D tensor"
# Get the world size of the process group
world_size = dist.get_world_size()
world_size = dist.get_world_size(group=group)
# Get the shape of the input tensor
shard_seq_len, heads, hidden_dims = input.shape
......@@ -36,7 +36,7 @@ def all2all_seq2head(input):
output = torch.empty_like(input_t)
# Run the all-to-all op to redistribute the input tensor across all processes
dist.all_to_all_single(output, input_t)
dist.all_to_all_single(output, input_t, group=group)
# Reshape the output tensor to [seq_len, heads/N, hidden_dims]
output = output.reshape(seq_len, shard_heads, hidden_dims).contiguous()
......@@ -45,7 +45,7 @@ def all2all_seq2head(input):
@dynamo.disable
def all2all_head2seq(input):
def all2all_head2seq(input, group=None):
"""
Convert the input tensor from the [seq_len, heads/N, hidden_dims] layout to [seq_len/N, heads, hidden_dims].
......@@ -59,7 +59,7 @@ def all2all_head2seq(input):
assert input.dim() == 3, f"input must be 3D tensor"
# Get the world size of the process group
world_size = dist.get_world_size()
world_size = dist.get_world_size(group=group)
# Get the shape of the input tensor
seq_len, shard_heads, hidden_dims = input.shape
......@@ -78,7 +78,7 @@ def all2all_head2seq(input):
output = torch.empty_like(input_t)
# Run the all-to-all op to redistribute the input tensor across all processes
dist.all_to_all_single(output, input_t)
dist.all_to_all_single(output, input_t, group=group)
# Reshape the output tensor to [heads, shard_seq_len, hidden_dims]
output = output.reshape(heads, shard_seq_len, hidden_dims)
......
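The group argument threaded through both helpers restricts the all-to-all to the sequence-parallel ranks. A minimal sketch of the seq-to-head direction, assuming the head count divides evenly by the group size:

import torch
import torch.distributed as dist

def seq2head_sketch(x: torch.Tensor, group) -> torch.Tensor:
    # x: [seq_len/N, heads, dim] on every rank of `group`
    n = dist.get_world_size(group=group)
    shard_seq, heads, dim = x.shape
    shard_heads = heads // n
    # split the head dimension into N chunks and move the chunk index to dim 0,
    # because all_to_all_single scatters/gathers along the first dimension
    x = x.reshape(shard_seq, n, shard_heads, dim).permute(1, 0, 2, 3).contiguous()
    out = torch.empty_like(x)
    dist.all_to_all_single(out, x, group=group)  # rank r ends up with head chunk r of every rank
    # [N, seq_len/N, heads/N, dim] -> [seq_len, heads/N, dim]
    return out.reshape(n * shard_seq, shard_heads, dim).contiguous()

all2all_head2seq is the inverse trip: shard along the sequence, all-to-all, and reassemble the full head dimension.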
import argparse
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
import json
from lightx2v.utils.envs import *
......@@ -25,10 +26,15 @@ from loguru import logger
def init_runner(config):
seed_all(config.seed)
if config.parallel_attn_type:
if config.parallel:
if not dist.is_initialized():
dist.init_process_group(backend="nccl")
cfg_p_size = config.parallel.get("cfg_p_size", 1)
seq_p_size = config.parallel.get("seq_p_size", 1)
assert cfg_p_size * seq_p_size == dist.get_world_size(), f"cfg_p_size ({cfg_p_size}) * seq_p_size ({seq_p_size}) must equal the world size ({dist.get_world_size()})"
config["device_mesh"] = init_device_mesh("cuda", (cfg_p_size, seq_p_size), mesh_dim_names=("cfg_p", "seq_p"))
if CHECK_ENABLE_GRAPH_MODE():
default_runner = RUNNER_REGISTER[config.model_cls](config)
runner = GraphRunner(default_runner)
......
import torch
import math
from ..utils import compute_freqs, compute_freqs_causvid, compute_freqs_dist, apply_rotary_emb
from ..utils import compute_freqs, compute_freqs_causvid, apply_rotary_emb
from lightx2v.utils.envs import *
from ..transformer_infer import WanTransformerInfer
......
......@@ -2,12 +2,13 @@ import torch
from lightx2v.models.networks.wan.infer.transformer_infer import WanTransformerInfer
import torch.distributed as dist
import torch.nn.functional as F
from lightx2v.models.networks.wan.infer.utils import compute_freqs_dist, compute_freqs_audio_dist
from lightx2v.models.networks.wan.infer.utils import pad_freqs
class WanTransformerDistInfer(WanTransformerInfer):
def __init__(self, config):
super().__init__(config)
self.seq_p_group = self.config["device_mesh"].get_group(mesh_dim="seq_p")
def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
x = self.dist_pre_process(x)
......@@ -17,14 +18,14 @@ class WanTransformerDistInfer(WanTransformerInfer):
def compute_freqs(self, q, grid_sizes, freqs):
if "audio" in self.config.get("model_cls", ""):
freqs_i = compute_freqs_audio_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs)
freqs_i = self.compute_freqs_audio_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs)
else:
freqs_i = compute_freqs_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs)
freqs_i = self.compute_freqs_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs)
return freqs_i
def dist_pre_process(self, x):
world_size = dist.get_world_size()
cur_rank = dist.get_rank()
world_size = dist.get_world_size(self.seq_p_group)
cur_rank = dist.get_rank(self.seq_p_group)
padding_size = (world_size - (x.shape[0] % world_size)) % world_size
......@@ -36,16 +37,56 @@ class WanTransformerDistInfer(WanTransformerInfer):
return x
def dist_post_process(self, x):
# Get the world size of the sequence-parallel group
world_size = dist.get_world_size()
world_size = dist.get_world_size(self.seq_p_group)
# Create a list to hold the output from every process
gathered_x = [torch.empty_like(x) for _ in range(world_size)]
# Gather the outputs from all processes
dist.all_gather(gathered_x, x)
dist.all_gather(gathered_x, x, group=self.seq_p_group)
# Concatenate the gathered outputs along the given dimension
combined_output = torch.cat(gathered_x, dim=0)
return combined_output # return the merged output
def compute_freqs_dist(self, s, c, grid_sizes, freqs):
world_size = dist.get_world_size(self.seq_p_group)
cur_rank = dist.get_rank(self.seq_p_group)
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
f, h, w = grid_sizes[0].tolist()
seq_len = f * h * w
freqs_i = torch.cat(
[
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
],
dim=-1,
).reshape(seq_len, 1, -1)
freqs_i = pad_freqs(freqs_i, s * world_size)
s_per_rank = s
freqs_i_rank = freqs_i[(cur_rank * s_per_rank) : ((cur_rank + 1) * s_per_rank), :, :]
return freqs_i_rank
def compute_freqs_audio_dist(self, s, c, grid_sizes, freqs):
world_size = dist.get_world_size(self.seq_p_group)
cur_rank = dist.get_rank(self.seq_p_group)
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
f, h, w = grid_sizes[0].tolist()
f = f + 1
seq_len = f * h * w
freqs_i = torch.cat(
[
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
],
dim=-1,
).reshape(seq_len, 1, -1)
freqs_i = pad_freqs(freqs_i, s * world_size)
s_per_rank = s
freqs_i_rank = freqs_i[(cur_rank * s_per_rank) : ((cur_rank + 1) * s_per_rank), :, :]
return freqs_i_rank
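Both compute_freqs_dist and compute_freqs_audio_dist build the full rotary table, pad it to a multiple of the shard size, and keep only the rows that match this rank's sequence shard. A small numeric illustration with made-up sizes:

# Hypothetical numbers, only to show the slicing arithmetic.
seq_len = 10                                    # full token count of the video
world_size = 4                                  # seq_p group size
s = (seq_len + world_size - 1) // world_size    # per-rank shard after dist_pre_process padding -> 3
padded_len = s * world_size                     # pad_freqs extends the table to 12 rows
rows_per_rank = [(r * s, (r + 1) * s) for r in range(world_size)]
# -> [(0, 3), (3, 6), (6, 9), (9, 12)]; the last rank's slice is mostly padding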
import torch
from .utils import compute_freqs, compute_freqs_dist, compute_freqs_audio, compute_freqs_audio_dist, apply_rotary_emb, apply_rotary_emb_chunk
from .utils import compute_freqs, compute_freqs_audio, apply_rotary_emb, apply_rotary_emb_chunk
from lightx2v.common.offload.manager import (
WeightAsyncStreamManager,
LazyWeightAsyncStreamManager,
......@@ -367,7 +367,7 @@ class WanTransformerInfer(BaseTransformerInfer):
del freqs_i, norm1_out, norm1_weight, norm1_bias
torch.cuda.empty_cache()
if self.config.get("parallel_attn_type", None):
if self.config.parallel and self.config.parallel.get("seq_p_size", False) and self.config.parallel.seq_p_size > 1:
attn_out = weights.self_attn_1_parallel.apply(
q=q,
k=k,
......@@ -375,6 +375,7 @@ class WanTransformerInfer(BaseTransformerInfer):
img_qkv_len=q.shape[0],
cu_seqlens_qkv=cu_seqlens_q,
attention_module=weights.self_attn_1,
seq_p_group=self.seq_p_group,
)
else:
attn_out = weights.self_attn_1.apply(
......
......@@ -37,28 +37,6 @@ def compute_freqs_audio(c, grid_sizes, freqs):
return freqs_i
def compute_freqs_audio_dist(s, c, grid_sizes, freqs):
world_size = dist.get_world_size()
cur_rank = dist.get_rank()
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
f, h, w = grid_sizes[0].tolist()
f = f + 1
seq_len = f * h * w
freqs_i = torch.cat(
[
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
],
dim=-1,
).reshape(seq_len, 1, -1)
freqs_i = pad_freqs(freqs_i, s * world_size)
s_per_rank = s
freqs_i_rank = freqs_i[(cur_rank * s_per_rank) : ((cur_rank + 1) * s_per_rank), :, :]
return freqs_i_rank
def compute_freqs_causvid(c, grid_sizes, freqs, start_frame=0):
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
f, h, w = grid_sizes[0].tolist()
......@@ -83,27 +61,6 @@ def pad_freqs(original_tensor, target_len):
return padded_tensor
def compute_freqs_dist(s, c, grid_sizes, freqs):
world_size = dist.get_world_size()
cur_rank = dist.get_rank()
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
f, h, w = grid_sizes[0].tolist()
seq_len = f * h * w
freqs_i = torch.cat(
[
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
],
dim=-1,
).reshape(seq_len, 1, -1)
freqs_i = pad_freqs(freqs_i, s * world_size)
s_per_rank = s
freqs_i_rank = freqs_i[(cur_rank * s_per_rank) : ((cur_rank + 1) * s_per_rank), :, :]
return freqs_i_rank
def apply_rotary_emb(x, freqs_i):
n = x.size(1)
seq_len = freqs_i.size(0)
......
import os
import torch
import torch.distributed as dist
import glob
import json
from lightx2v.common.ops.attn import MaskMap
......@@ -69,7 +70,7 @@ class WanModel:
def _init_infer_class(self):
self.pre_infer_class = WanPreInfer
self.post_infer_class = WanPostInfer
if self.config.get("parallel_attn_type", None):
if self.config.parallel and self.config.parallel.get("seq_p_size", False) and self.config.parallel.seq_p_size > 1:
self.transformer_infer_class = WanTransformerDistInfer
else:
if self.config["feature_caching"] == "NoCaching":
......@@ -186,6 +187,10 @@ class WanModel:
self.pre_infer = self.pre_infer_class(self.config)
self.post_infer = self.post_infer_class(self.config)
self.transformer_infer = self.transformer_infer_class(self.config)
if self.config.parallel and self.config.parallel.get("cfg_p_size", False) and self.config.parallel.cfg_p_size > 1:
self.infer = self.infer_with_cfg_parallel
else:
self.infer = self.infer_wo_cfg_parallel
def set_scheduler(self, scheduler):
self.scheduler = scheduler
......@@ -204,7 +209,7 @@ class WanModel:
self.transformer_weights.to_cuda()
@torch.no_grad()
def infer(self, inputs):
def infer_wo_cfg_parallel(self, inputs):
if self.cpu_offload:
if self.offload_granularity == "model" and self.scheduler.step_index == 0:
self.to_cuda()
......@@ -245,6 +250,29 @@ class WanModel:
self.pre_weight.to_cpu()
self.post_weight.to_cpu()
@torch.no_grad()
def infer_with_cfg_parallel(self, inputs):
assert self.config["enable_cfg"], "enable_cfg must be True"
cfg_p_group = self.config["device_mesh"].get_group(mesh_dim="cfg_p")
assert dist.get_world_size(cfg_p_group) == 2, f"cfg_p_world_size must be equal to 2"
cfg_p_rank = dist.get_rank(cfg_p_group)
if cfg_p_rank == 0:
embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
noise_pred = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
else:
embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
noise_pred = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
noise_pred_list = [torch.zeros_like(noise_pred) for _ in range(2)]
dist.all_gather(noise_pred_list, noise_pred, group=cfg_p_group)
noise_pred_cond = noise_pred_list[0] # cfg_p_rank == 0
noise_pred_uncond = noise_pred_list[1] # cfg_p_rank == 1
self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
class Wan22MoeModel(WanModel):
def _load_ckpt(self, use_bf16, skip_bf16):
......
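One note on infer_with_cfg_parallel above: rank 0 of the cfg_p group runs the conditional branch and rank 1 the unconditional one, so after the all_gather every cfg rank holds the same pair and applies the same guidance update noise_uncond + sample_guide_scale * (noise_cond - noise_uncond); the following scheduler step therefore stays in sync across ranks without extra communication.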
......@@ -191,8 +191,8 @@ class WanSelfAttention(WeightModule):
else:
self.add_module("self_attn_1", ATTN_WEIGHT_REGISTER[self.config["self_attn_1_type"]]())
if self.config.get("parallel_attn_type", None):
self.add_module("self_attn_1_parallel", ATTN_WEIGHT_REGISTER[self.config["parallel_attn_type"]]())
if self.config.parallel and self.config.parallel.get("seq_p_size", False) and self.config.parallel.seq_p_size > 1:
self.add_module("self_attn_1_parallel", ATTN_WEIGHT_REGISTER[self.config.parallel.get("seq_p_attn_type", "ulysses")]())
if self.quant_method in ["advanced_ptq"]:
self.add_module(
......
......@@ -43,7 +43,7 @@ class DefaultRunner(BaseRunner):
self.run_input_encoder = self._run_input_encoder_local_t2v
def set_init_device(self):
if self.config.parallel_attn_type:
if self.config.parallel:
cur_rank = dist.get_rank()
torch.cuda.set_device(cur_rank)
if self.config.cpu_offload:
......@@ -237,7 +237,7 @@ class DefaultRunner(BaseRunner):
else:
fps = self.config.get("fps", 16)
if not self.config.get("parallel_attn_type", None) or dist.get_rank() == 0:
if not dist.is_initialized() or dist.get_rank() == 0:
logger.info(f"Saving video to {self.config.save_video_path}")
if self.config["model_cls"] != "wan2.2":
......
......@@ -124,7 +124,7 @@ class WanRunner(DefaultRunner):
vae_config = {
"vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.1_VAE.pth"),
"device": self.init_device,
"parallel": self.config.parallel_vae,
"parallel": self.config.parallel and self.config.parallel.get("vae_p_size", False) and self.config.parallel.vae_p_size > 1,
"use_tiling": self.config.get("use_tiling_vae", False),
}
if self.config.task != "i2v":
......@@ -136,7 +136,7 @@ class WanRunner(DefaultRunner):
vae_config = {
"vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.1_VAE.pth"),
"device": self.init_device,
"parallel": self.config.parallel_vae,
"parallel": self.config.parallel and self.config.parallel.get("vae_p_size", False) and self.config.parallel.vae_p_size > 1,
"use_tiling": self.config.get("use_tiling_vae", False),
}
if self.config.get("use_tiny_vae", False):
......
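The VAE hook in WanRunner now keys parallel decoding off parallel.vae_p_size rather than the removed parallel_vae flag; none of the configs touched in this commit set vae_p_size, so parallel VAE stays disabled until that key is added to the parallel block.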