"vscode:/vscode.git/clone" did not exist on "bcecf27e7ca265ec851746d84f5889e22b4b5813"
Unverified commit b19d2ca2, authored by STwangyingrui and committed by GitHub
parent 1e42d3d3
@@ -4,9 +4,9 @@
"video_duration": 5,
"audio_sr": 16000,
"target_video_length": 81,
"self_attn_1_type": "sage_attn3",
"cross_attn_1_type": "sage_attn3",
"cross_attn_2_type": "sage_attn3",
"self_attn_1_type": "sage_attn2",
"cross_attn_1_type": "sage_attn2",
"cross_attn_2_type": "sage_attn2",
"sample_guide_scale": 1,
"sample_shift": 5,
"enable_cfg": false,
@@ -14,16 +14,16 @@
"cpu_offload": false,
"offload_granularity": "block",
"offload_ratio": 1,
"t5_cpu_offload": false,
"t5_cpu_offload": true,
"t5_quantized": true,
"t5_quant_scheme": "int8-q8f",
"clip_cpu_offload": false,
"clip_cpu_offload": true,
"clip_quantized": false,
"audio_encoder_cpu_offload": false,
"audio_adapter_cpu_offload": false,
"audio_encoder_cpu_offload": true,
"audio_adapter_cpu_offload": true,
"adapter_quantized": true,
"adapter_quant_scheme": "int8-q8f",
"vae_cpu_offload": false,
"vae_cpu_offload": true,
"use_tiling_vae": false,
"dit_quantized": true,
"dit_quant_scheme": "int8-q8f",
@@ -35,6 +35,6 @@
"parallel": {
"seq_p_size": 8,
"seq_p_fp8_comm": true,
"seq_p_attn_type": "ulysses"
"seq_p_attn_type": "ulysses-4090"
}
}
import torch
import torch.distributed as dist
from loguru import logger
from lightx2v.utils.quant_utils import dequant_fp8_vllm, quant_fp8_vllm
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
@@ -140,6 +141,90 @@ class UlyssesAttnWeight(AttnWeightTemplate):
class Ulysses4090AttnWeight(AttnWeightTemplate):
def __init__(self):
self.config = {}
self.rounds = []
def generate_round_robin_pairs(self, seq_p_group=None):
"""
Build the round-robin pairing schedule, ensuring the first element of each pair is smaller than the second,
so that a simple rule can determine the communication order.
"""
cur_rank = dist.get_rank(seq_p_group)
world_size = dist.get_world_size(seq_p_group)
if world_size % 2 != 0:
raise ValueError("world_size必须是偶数,奇数情况需要特殊处理")
teams = list(range(world_size))
for _ in range(world_size - 1):
round_schedule = {}
for i in range(world_size // 2):
team1, team2 = teams[i], teams[world_size - 1 - i]
smaller, larger = min(team1, team2), max(team1, team2)
round_schedule[smaller] = (larger, True)
round_schedule[larger] = (smaller, False)
self.rounds.append(round_schedule)
# Rotate the list (keep the first element fixed)
teams = [teams[0]] + [teams[-1]] + teams[1:-1]
# if cur_rank == 0:
# self.print_pairing_schedule(seq_p_group)
def print_pairing_schedule(self, seq_p_group):
"""打印通信调度表"""
world_size = dist.get_world_size(seq_p_group)
logger.info("循环赛通信调度表:")
logger.info("=" * 50)
for i, round_schedule in enumerate(self.rounds):
logger.info(f"第 {i + 1} 轮:")
for cur_rank in range(world_size):
partner, is_smaller_in_pair = round_schedule[cur_rank]
logger.info(f" 进程 {cur_rank} ←→ 进程 {partner}")
logger.info("=" * 50)
def load_balanced_all_to_all(self, shards, seq_p_group=None):
"""
Load-balanced all-to-all communication.
"""
world_size = dist.get_world_size(seq_p_group)
cur_rank = dist.get_rank(seq_p_group)
global_rank = dist.get_global_rank(seq_p_group, cur_rank)
cfg_p_group_index = global_rank // world_size
# Prepare receive buffers
gathered_shards = [None] * world_size
for target_rank in range(world_size):
if target_rank != cur_rank:
gathered_shards[target_rank] = torch.empty_like(shards[target_rank])
else:
gathered_shards[cur_rank] = shards[cur_rank]
for i, round_schedule in enumerate(self.rounds):
# Look up this rank's partner for the current round
partner = None
is_smaller_in_pair = False
if cur_rank in round_schedule:
partner, is_smaller_in_pair = round_schedule[cur_rank]
# If no partner was found, this rank is idle in this round
if partner is None:
continue
# Compute the partner's global rank
partner_global_rank = cfg_p_group_index * world_size + partner
if is_smaller_in_pair:
# This rank is the smaller one in the pair: send first, then receive
send_req = dist.isend(shards[partner], dst=partner_global_rank, group=seq_p_group)
recv_req = dist.irecv(gathered_shards[partner], src=partner_global_rank, group=seq_p_group)
send_req.wait()
recv_req.wait()
else:
# This rank is the larger one in the pair: receive first, then send
recv_req = dist.irecv(gathered_shards[partner], src=partner_global_rank, group=seq_p_group)
send_req = dist.isend(shards[partner], dst=partner_global_rank, group=seq_p_group)
recv_req.wait()
send_req.wait()
return gathered_shards
def apply(self, q, k, v, img_qkv_len, cu_seqlens_qkv, attention_module=None, seq_p_group=None, model_cls=None, use_fp8_comm=False):
"""
@@ -156,6 +241,9 @@ class Ulysses4090AttnWeight(AttnWeightTemplate):
Returns:
    torch.Tensor: the computed attention result
"""
if len(self.rounds) == 0:
self.generate_round_robin_pairs(seq_p_group)
if len(q.shape) == 4:
q = q.reshape(-1, q.shape[-2], q.shape[-1])
k = k.reshape(-1, k.shape[-2], k.shape[-1])
@@ -189,150 +277,50 @@ class Ulysses4090AttnWeight(AttnWeightTemplate):
num_heads = img_q.shape[1]
shard_heads = num_heads // world_size
# Split Q/K/V along the head dimension into N shards, each of size D/N
q_shards = [img_q[:, i * shard_heads : (i + 1) * shard_heads, :].contiguous() for i in range(world_size)]
k_shards = [img_k[:, i * shard_heads : (i + 1) * shard_heads, :].contiguous() for i in range(world_size)]
v_shards = [img_v[:, i * shard_heads : (i + 1) * shard_heads, :].contiguous() for i in range(world_size)]
# Stack the image Q/K/V, then split along the head dimension into N shards, each of size D/N
img_qkv = torch.stack([img_q, img_k, img_v], dim=0)
qkv_shards = [img_qkv[:, :, i * shard_heads : (i + 1) * shard_heads, :].contiguous() for i in range(world_size)]
qkv_dtype = img_qkv.dtype
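# Shape note: img_qkv is (3, seq_len, num_heads, hidden_dims), so each qkv_shards[i]
# carries the num_heads // world_size heads destined for rank i.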
if use_fp8_comm:
original_dtype = img_q.dtype
original_shape = img_q.shape
# Quantize all shards
q_shards_fp8 = []
q_scales = []
k_shards_fp8 = []
k_scales = []
v_shards_fp8 = []
v_scales = []
qkv_fp8_byte_tensors = []
qkv_fp8_bytes = 0
qkv_fp8_dtype = None
qkv_scale_dtype = None
for i in range(world_size):
q_fp8, q_scale = quant_fp8_vllm(q_shards[i].reshape(-1, original_shape[-1]))
q_shards_fp8.append(q_fp8)
q_scales.append(q_scale)
k_fp8, k_scale = quant_fp8_vllm(k_shards[i].reshape(-1, original_shape[-1]))
k_shards_fp8.append(k_fp8)
k_scales.append(k_scale)
v_fp8, v_scale = quant_fp8_vllm(v_shards[i].reshape(-1, original_shape[-1]))
v_shards_fp8.append(v_fp8)
v_scales.append(v_scale)
# Prepare receive buffers (FP8 data + scales)
gathered_q_fp8 = [None] * world_size
gathered_q_scales = [None] * world_size
gathered_k_fp8 = [None] * world_size
gathered_k_scales = [None] * world_size
gathered_v_fp8 = [None] * world_size
gathered_v_scales = [None] * world_size
for target_rank in range(world_size):
if target_rank != cur_rank:
gathered_q_fp8[target_rank] = torch.empty_like(q_shards_fp8[target_rank])
gathered_q_scales[target_rank] = torch.empty_like(q_scales[target_rank])
gathered_k_fp8[target_rank] = torch.empty_like(k_shards_fp8[target_rank])
gathered_k_scales[target_rank] = torch.empty_like(k_scales[target_rank])
gathered_v_fp8[target_rank] = torch.empty_like(v_shards_fp8[target_rank])
gathered_v_scales[target_rank] = torch.empty_like(v_scales[target_rank])
else:
gathered_q_fp8[cur_rank] = q_shards_fp8[cur_rank]
gathered_q_scales[cur_rank] = q_scales[cur_rank]
gathered_k_fp8[cur_rank] = k_shards_fp8[cur_rank]
gathered_k_scales[cur_rank] = k_scales[cur_rank]
gathered_v_fp8[cur_rank] = v_shards_fp8[cur_rank]
gathered_v_scales[cur_rank] = v_scales[cur_rank]
# Exchange FP8 data and scales
for target_rank in range(world_size):
target_global_rank = cfg_p_group_index * world_size + target_rank
if target_rank != cur_rank:
if cur_rank < target_rank:
sendq_fp8_req = dist.isend(q_shards_fp8[target_rank], dst=target_global_rank, group=seq_p_group)
sendk_fp8_req = dist.isend(k_shards_fp8[target_rank], dst=target_global_rank, group=seq_p_group)
sendv_fp8_req = dist.isend(v_shards_fp8[target_rank], dst=target_global_rank, group=seq_p_group)
sendq_scale_req = dist.isend(q_scales[target_rank], dst=target_global_rank, group=seq_p_group)
sendk_scale_req = dist.isend(k_scales[target_rank], dst=target_global_rank, group=seq_p_group)
sendv_scale_req = dist.isend(v_scales[target_rank], dst=target_global_rank, group=seq_p_group)
recvq_fp8_req = dist.irecv(gathered_q_fp8[target_rank], src=target_global_rank, group=seq_p_group)
recvk_fp8_req = dist.irecv(gathered_k_fp8[target_rank], src=target_global_rank, group=seq_p_group)
recvv_fp8_req = dist.irecv(gathered_v_fp8[target_rank], src=target_global_rank, group=seq_p_group)
recvq_scale_req = dist.irecv(gathered_q_scales[target_rank], src=target_global_rank, group=seq_p_group)
recvk_scale_req = dist.irecv(gathered_k_scales[target_rank], src=target_global_rank, group=seq_p_group)
recvv_scale_req = dist.irecv(gathered_v_scales[target_rank], src=target_global_rank, group=seq_p_group)
else:
recvq_fp8_req = dist.irecv(gathered_q_fp8[target_rank], src=target_global_rank, group=seq_p_group)
recvk_fp8_req = dist.irecv(gathered_k_fp8[target_rank], src=target_global_rank, group=seq_p_group)
recvv_fp8_req = dist.irecv(gathered_v_fp8[target_rank], src=target_global_rank, group=seq_p_group)
recvq_scale_req = dist.irecv(gathered_q_scales[target_rank], src=target_global_rank, group=seq_p_group)
recvk_scale_req = dist.irecv(gathered_k_scales[target_rank], src=target_global_rank, group=seq_p_group)
recvv_scale_req = dist.irecv(gathered_v_scales[target_rank], src=target_global_rank, group=seq_p_group)
sendq_fp8_req = dist.isend(q_shards_fp8[target_rank], dst=target_global_rank, group=seq_p_group)
sendk_fp8_req = dist.isend(k_shards_fp8[target_rank], dst=target_global_rank, group=seq_p_group)
sendv_fp8_req = dist.isend(v_shards_fp8[target_rank], dst=target_global_rank, group=seq_p_group)
sendq_scale_req = dist.isend(q_scales[target_rank], dst=target_global_rank, group=seq_p_group)
sendk_scale_req = dist.isend(k_scales[target_rank], dst=target_global_rank, group=seq_p_group)
sendv_scale_req = dist.isend(v_scales[target_rank], dst=target_global_rank, group=seq_p_group)
sendq_fp8_req.wait()
sendk_fp8_req.wait()
sendv_fp8_req.wait()
sendq_scale_req.wait()
sendk_scale_req.wait()
sendv_scale_req.wait()
recvq_fp8_req.wait()
recvk_fp8_req.wait()
recvv_fp8_req.wait()
recvq_scale_req.wait()
recvk_scale_req.wait()
recvv_scale_req.wait()
# Dequantize
qkv_fp8, qkv_scale = quant_fp8_vllm(qkv_shards[i].reshape(-1, hidden_dims))
if i == 0:
qkv_fp8_bytes = qkv_fp8.numel() * qkv_fp8.element_size()
qkv_fp8_dtype = qkv_fp8.dtype
qkv_scale_dtype = qkv_scale.dtype
qkv_fp8_byte_tensors.append(torch.cat([qkv_fp8.contiguous().reshape(-1).view(torch.uint8), qkv_scale.contiguous().reshape(-1).view(torch.uint8)], dim=0))
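# Each packed buffer is a flat uint8 tensor laid out as [FP8 payload | scales];
# qkv_fp8_bytes marks where the scale bytes start when the buffer is unpacked after
# the exchange, so each pair of ranks moves a single tensor per round.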
gathered_qkv_fp8_byte_tensors = self.load_balanced_all_to_all(qkv_fp8_byte_tensors, seq_p_group)
gathered_q_shards = []
gathered_k_shards = []
gathered_v_shards = []
for i in range(world_size):
q_shards_new = dequant_fp8_vllm(gathered_q_fp8[i], gathered_q_scales[i], original_dtype).reshape(-1, shard_heads, hidden_dims)
k_shards_new = dequant_fp8_vllm(gathered_k_fp8[i], gathered_k_scales[i], original_dtype).reshape(-1, shard_heads, hidden_dims)
v_shards_new = dequant_fp8_vllm(gathered_v_fp8[i], gathered_v_scales[i], original_dtype).reshape(-1, shard_heads, hidden_dims)
qkv_fp8_byte_tensor = gathered_qkv_fp8_byte_tensors[i]
qkv_fp8 = qkv_fp8_byte_tensor[:qkv_fp8_bytes].view(qkv_fp8_dtype).reshape(3, -1, hidden_dims)
qkv_scale = qkv_fp8_byte_tensor[qkv_fp8_bytes:].view(qkv_scale_dtype).reshape(3, -1, 1)
q_shards_new = dequant_fp8_vllm(qkv_fp8[0], qkv_scale[0], qkv_dtype).reshape(-1, shard_heads, hidden_dims)
k_shards_new = dequant_fp8_vllm(qkv_fp8[1], qkv_scale[1], qkv_dtype).reshape(-1, shard_heads, hidden_dims)
v_shards_new = dequant_fp8_vllm(qkv_fp8[2], qkv_scale[2], qkv_dtype).reshape(-1, shard_heads, hidden_dims)
gathered_q_shards.append(q_shards_new)
gathered_k_shards.append(k_shards_new)
gathered_v_shards.append(v_shards_new)
else:
# Prepare receive buffers
gathered_q_shards = [None] * world_size
gathered_k_shards = [None] * world_size
gathered_v_shards = [None] * world_size
for target_rank in range(world_size):
if target_rank != cur_rank:
gathered_q_shards[target_rank] = torch.empty_like(q_shards[target_rank])
gathered_k_shards[target_rank] = torch.empty_like(k_shards[target_rank])
gathered_v_shards[target_rank] = torch.empty_like(v_shards[target_rank])
else:
gathered_q_shards[cur_rank] = q_shards[cur_rank]
gathered_k_shards[cur_rank] = k_shards[cur_rank]
gathered_v_shards[cur_rank] = v_shards[cur_rank]
# Communicate
for target_rank in range(world_size):
target_global_rank = cfg_p_group_index * world_size + target_rank
if target_rank != cur_rank:
if cur_rank < target_rank:
sendq_req = dist.isend(q_shards[target_rank], dst=target_global_rank, group=seq_p_group)
sendk_req = dist.isend(k_shards[target_rank], dst=target_global_rank, group=seq_p_group)
sendv_req = dist.isend(v_shards[target_rank], dst=target_global_rank, group=seq_p_group)
recvq_req = dist.irecv(gathered_q_shards[target_rank], src=target_global_rank, group=seq_p_group)
recvk_req = dist.irecv(gathered_k_shards[target_rank], src=target_global_rank, group=seq_p_group)
recvv_req = dist.irecv(gathered_v_shards[target_rank], src=target_global_rank, group=seq_p_group)
else:
recvq_req = dist.irecv(gathered_q_shards[target_rank], src=target_global_rank, group=seq_p_group)
recvk_req = dist.irecv(gathered_k_shards[target_rank], src=target_global_rank, group=seq_p_group)
recvv_req = dist.irecv(gathered_v_shards[target_rank], src=target_global_rank, group=seq_p_group)
sendq_req = dist.isend(q_shards[target_rank], dst=target_global_rank, group=seq_p_group)
sendk_req = dist.isend(k_shards[target_rank], dst=target_global_rank, group=seq_p_group)
sendv_req = dist.isend(v_shards[target_rank], dst=target_global_rank, group=seq_p_group)
sendq_req.wait()
sendk_req.wait()
sendv_req.wait()
recvq_req.wait()
recvk_req.wait()
recvv_req.wait()
gathered_qkv_byte_tensors = self.load_balanced_all_to_all(qkv_shards, seq_p_group)
gathered_q_shards = []
gathered_k_shards = []
gathered_v_shards = []
for i in range(world_size):
qkv_tensor = gathered_qkv_byte_tensors[i].view(qkv_dtype).reshape(3, -1, shard_heads, hidden_dims)
gathered_q_shards.append(qkv_tensor[0])
gathered_k_shards.append(qkv_tensor[1])
gathered_v_shards.append(qkv_tensor[2])
# Concatenate all shards (along the sequence dimension)
# Each gathered_*_shards[i] has shape (seq_len/N, num_heads/N, head_dim)
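# After that concatenation each rank holds the full image sequence for its own subset
# of heads (the usual Ulysses layout for the attention step).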
@@ -389,78 +377,36 @@ class Ulysses4090AttnWeight(AttnWeightTemplate):
cfg_p_group_index = global_rank // world_size
img_attn = img_attn.reshape(world_size * shard_seqlen, shard_heads, hidden_dims)  # Reshape the image attention result
attn_dtype = img_attn.dtype
# Split into N shards along the sequence dimension
attn_shards = [img_attn[i * shard_seqlen : (i + 1) * shard_seqlen, :, :].contiguous() for i in range(world_size)]
if use_fp8_comm:
original_dtype = img_attn.dtype
original_shape = img_attn.shape
# Quantize all shards
attn_shards_fp8 = []
attn_scales = []
attn_fp8_byte_tensors = []
attn_fp8_bytes = 0
attn_fp8_dtype = None
attn_scale_dtype = None
for i in range(world_size):
attn_fp8, attn_scale = quant_fp8_vllm(attn_shards[i].reshape(-1, original_shape[-1]))
attn_shards_fp8.append(attn_fp8)
attn_scales.append(attn_scale)
# Prepare receive buffers (FP8 data + scales)
gathered_attn_fp8 = [None] * world_size
gathered_attn_scales = [None] * world_size
for target_rank in range(world_size):
if target_rank != cur_rank:
gathered_attn_fp8[target_rank] = torch.empty_like(attn_shards_fp8[target_rank])
gathered_attn_scales[target_rank] = torch.empty_like(attn_scales[target_rank])
else:
gathered_attn_fp8[cur_rank] = attn_shards_fp8[cur_rank]
gathered_attn_scales[cur_rank] = attn_scales[cur_rank]
# Exchange FP8 data and scales
for target_rank in range(world_size):
target_global_rank = cfg_p_group_index * world_size + target_rank
if target_rank != cur_rank:
if cur_rank < target_rank:
send_attn_fp8_req = dist.isend(attn_shards_fp8[target_rank], dst=target_global_rank, group=seq_p_group)
send_attn_scale_req = dist.isend(attn_scales[target_rank], dst=target_global_rank, group=seq_p_group)
recv_attn_fp8_req = dist.irecv(gathered_attn_fp8[target_rank], src=target_global_rank, group=seq_p_group)
recv_attn_scale_req = dist.irecv(gathered_attn_scales[target_rank], src=target_global_rank, group=seq_p_group)
else:
recv_attn_fp8_req = dist.irecv(gathered_attn_fp8[target_rank], src=target_global_rank, group=seq_p_group)
recv_attn_scale_req = dist.irecv(gathered_attn_scales[target_rank], src=target_global_rank, group=seq_p_group)
send_attn_fp8_req = dist.isend(attn_shards_fp8[target_rank], dst=target_global_rank, group=seq_p_group)
send_attn_scale_req = dist.isend(attn_scales[target_rank], dst=target_global_rank, group=seq_p_group)
send_attn_fp8_req.wait()
send_attn_scale_req.wait()
recv_attn_fp8_req.wait()
recv_attn_scale_req.wait()
# Dequantize
attn_fp8, attn_scale = quant_fp8_vllm(attn_shards[i].reshape(-1, hidden_dims))
if i == 0:
attn_fp8_bytes = attn_fp8.numel() * attn_fp8.element_size()
attn_fp8_dtype = attn_fp8.dtype
attn_scale_dtype = attn_scale.dtype
attn_fp8_byte_tensors.append(torch.cat([attn_fp8.contiguous().reshape(-1).view(torch.uint8), attn_scale.contiguous().reshape(-1).view(torch.uint8)], dim=0))
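# Same packing trick as on the QKV path: FP8 payload and scales travel as one flat
# uint8 buffer per destination rank and are split apart again after the exchange.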
gathered_attn_fp8_byte_tensors = self.load_balanced_all_to_all(attn_fp8_byte_tensors, seq_p_group)
gathered_attn_shards = []
for i in range(world_size):
attn_shards_new = dequant_fp8_vllm(gathered_attn_fp8[i], gathered_attn_scales[i], original_dtype).reshape(-1, shard_heads, hidden_dims)
attn_fp8_byte_tensor = gathered_attn_fp8_byte_tensors[i]
attn_fp8 = attn_fp8_byte_tensor[:attn_fp8_bytes].view(attn_fp8_dtype).reshape(-1, hidden_dims)
attn_scale = attn_fp8_byte_tensor[attn_fp8_bytes:].view(attn_scale_dtype).reshape(-1, 1)
attn_shards_new = dequant_fp8_vllm(attn_fp8, attn_scale, attn_dtype).reshape(-1, shard_heads, hidden_dims)
gathered_attn_shards.append(attn_shards_new)
else:
# Prepare receive buffers
gathered_attn_shards = [None] * world_size
for target_rank in range(world_size):
if target_rank != cur_rank:
gathered_attn_shards[target_rank] = torch.empty_like(attn_shards[target_rank])
else:
gathered_attn_shards[cur_rank] = attn_shards[cur_rank]
# Communicate
for target_rank in range(world_size):
target_global_rank = cfg_p_group_index * world_size + target_rank
if target_rank != cur_rank:
if cur_rank < target_rank:
send_req = dist.isend(attn_shards[target_rank], dst=target_global_rank, group=seq_p_group)
recv_req = dist.irecv(gathered_attn_shards[target_rank], src=target_global_rank, group=seq_p_group)
else:
recv_req = dist.irecv(gathered_attn_shards[target_rank], src=target_global_rank, group=seq_p_group)
send_req = dist.isend(attn_shards[target_rank], dst=target_global_rank, group=seq_p_group)
send_req.wait()
recv_req.wait()
gathered_attn_shards = self.load_balanced_all_to_all(attn_shards, seq_p_group)
# Concatenate all shards (along the head dimension)
img_attn = torch.cat(gathered_attn_shards, dim=1)
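# Each gathered shard holds this rank's sequence chunk computed with a different head
# group, so concatenating along the head dimension restores all heads for the local chunk.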