Unverified commit 3a2d2555 authored by STwangyingrui, committed by GitHub
parent b05b91a4
@@ -4,7 +4,6 @@ import torch.distributed as dist
 from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
 from .template import AttnWeightTemplate
-from .utils.all2all import all2all_head2seq, all2all_seq2head

 @ATTN_WEIGHT_REGISTER("ulysses")
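The removed import is the crux of this commit: the collective `all2all_seq2head` / `all2all_head2seq` helpers are replaced below with explicit point-to-point `isend`/`irecv` exchanges. For reference, in the standard Ulysses pattern a seq2head helper turns each rank's tensor of shape (seq_len/N, num_heads, head_dim) into (seq_len, num_heads/N, head_dim): full sequence, local head shard. A minimal sketch of that semantics, assuming the removed lightx2v helpers follow this convention (the function name here is illustrative, not their actual implementation):

```python
import torch
import torch.distributed as dist

def all2all_seq2head_sketch(x, group=None):
    # Assumed semantics: x has shape (seq_len / N, num_heads, head_dim) on
    # each of the N ranks in `group`; the result has shape
    # (seq_len, num_heads / N, head_dim) -- full sequence, this rank's heads.
    world_size = dist.get_world_size(group)
    s, h, d = x.shape
    # (s, h, d) -> (N, s, h/N, d): chunk j along dim 0 holds the heads
    # destined for rank j
    inp = x.reshape(s, world_size, h // world_size, d).transpose(0, 1).contiguous()
    out = torch.empty_like(inp)
    # one collective all-to-all instead of N-1 point-to-point exchanges
    dist.all_to_all_single(out, inp, group=group)
    # stack the received per-rank sequence chunks back into the full sequence
    return out.reshape(world_size * s, h // world_size, d)
```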
@@ -49,11 +48,60 @@ class UlyssesAttnWeight(AttnWeightTemplate):
         img_q, img_k, img_v = q[:img_qkv_len, :, :].contiguous(), k[:img_qkv_len, :, :].contiguous(), v[:img_qkv_len, :, :].contiguous()
         txt_q, txt_k, txt_v = q[img_qkv_len:, :, :].contiguous(), k[img_qkv_len:, :, :].contiguous(), v[img_qkv_len:, :, :].contiguous()
-        # convert the image queries, keys, and values to head-sharded layout
-        img_q = all2all_seq2head(img_q, group=seq_p_group)
-        img_k = all2all_seq2head(img_k, group=seq_p_group)
-        img_v = all2all_seq2head(img_v, group=seq_p_group)
-        torch.cuda.synchronize()  # make sure the CUDA ops have finished
+        # compute the head shard each rank should hold
+        num_heads = img_q.shape[1]
+        shard_heads = num_heads // world_size
+
+        # split Q, K, V along the head dim into world_size shards of shard_heads heads each
+        q_shards = [img_q[:, i * shard_heads : (i + 1) * shard_heads, :].contiguous() for i in range(world_size)]
+        k_shards = [img_k[:, i * shard_heads : (i + 1) * shard_heads, :].contiguous() for i in range(world_size)]
+        v_shards = [img_v[:, i * shard_heads : (i + 1) * shard_heads, :].contiguous() for i in range(world_size)]
+
+        # prepare receive buffers; the local shard needs no communication
+        gathered_q_shards = [None] * world_size
+        gathered_k_shards = [None] * world_size
+        gathered_v_shards = [None] * world_size
+        for target_rank in range(world_size):
+            if target_rank != cur_rank:
+                gathered_q_shards[target_rank] = torch.empty_like(q_shards[target_rank])
+                gathered_k_shards[target_rank] = torch.empty_like(k_shards[target_rank])
+                gathered_v_shards[target_rank] = torch.empty_like(v_shards[target_rank])
+            else:
+                gathered_q_shards[cur_rank] = q_shards[cur_rank]
+                gathered_k_shards[cur_rank] = k_shards[cur_rank]
+                gathered_v_shards[cur_rank] = v_shards[cur_rank]
+
+        # launch the async point-to-point transfers, then wait on them
+        for target_rank in range(world_size):
+            if target_rank != cur_rank:
+                # avoid deadlock: both sides of a pair agree on send/recv order by rank
+                if cur_rank < target_rank:
+                    sendq_req = dist.isend(q_shards[target_rank], dst=target_rank, group=seq_p_group)
+                    sendk_req = dist.isend(k_shards[target_rank], dst=target_rank, group=seq_p_group)
+                    sendv_req = dist.isend(v_shards[target_rank], dst=target_rank, group=seq_p_group)
+                    recvq_req = dist.irecv(gathered_q_shards[target_rank], src=target_rank, group=seq_p_group)
+                    recvk_req = dist.irecv(gathered_k_shards[target_rank], src=target_rank, group=seq_p_group)
+                    recvv_req = dist.irecv(gathered_v_shards[target_rank], src=target_rank, group=seq_p_group)
+                else:
+                    recvq_req = dist.irecv(gathered_q_shards[target_rank], src=target_rank, group=seq_p_group)
+                    recvk_req = dist.irecv(gathered_k_shards[target_rank], src=target_rank, group=seq_p_group)
+                    recvv_req = dist.irecv(gathered_v_shards[target_rank], src=target_rank, group=seq_p_group)
+                    sendq_req = dist.isend(q_shards[target_rank], dst=target_rank, group=seq_p_group)
+                    sendk_req = dist.isend(k_shards[target_rank], dst=target_rank, group=seq_p_group)
+                    sendv_req = dist.isend(v_shards[target_rank], dst=target_rank, group=seq_p_group)
+                sendq_req.wait()
+                sendk_req.wait()
+                sendv_req.wait()
+                recvq_req.wait()
+                recvk_req.wait()
+                recvv_req.wait()
+
+        # concatenate all shards along the sequence dim
+        # each gathered_*_shards[i] has shape (seq_len / N, num_heads / N, head_dim)
+        # after concatenation the shape is (seq_len, num_heads / N, head_dim)
+        img_q = torch.cat(gathered_q_shards, dim=0)
+        img_k = torch.cat(gathered_k_shards, dim=0)
+        img_v = torch.cat(gathered_v_shards, dim=0)

         # for the text queries, keys, and values, select this rank's heads
         txt_q = txt_q[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]
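The rank-ordered send/recv in the hunk above is what keeps the exchange deadlock-free: for every pair of ranks, the lower rank posts its sends first and the higher rank its receives first, so the two sides always match. The same pattern, distilled into a hypothetical standalone helper (name and contract are illustrative, not part of the commit):

```python
import torch
import torch.distributed as dist

def pairwise_exchange(shards, group=None):
    """shards[j] holds the local tensor destined for rank j; returns a list
    where entry i is the tensor received from rank i (entry `rank` is the
    local shard, passed through untouched)."""
    rank = dist.get_rank(group)
    world_size = dist.get_world_size(group)
    result = [shards[i] if i == rank else torch.empty_like(shards[i])
              for i in range(world_size)]
    for peer in range(world_size):
        if peer == rank:
            continue
        # lower rank posts its send first, higher rank its recv first,
        # so both sides of every pair agree on the message order
        if rank < peer:
            send_req = dist.isend(shards[peer], dst=peer, group=group)
            recv_req = dist.irecv(result[peer], src=peer, group=group)
        else:
            recv_req = dist.irecv(result[peer], src=peer, group=group)
            send_req = dist.isend(shards[peer], dst=peer, group=group)
        send_req.wait()
        recv_req.wait()
    return result
```

With such a helper, the three per-tensor exchanges above would collapse to `gathered_q_shards = pairwise_exchange(q_shards, seq_p_group)` and likewise for K and V.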
@@ -97,8 +145,36 @@ class UlyssesAttnWeight(AttnWeightTemplate):
     @torch.compiler.disable
     def _reshape_img_attn(self, img_attn, world_size, shard_seqlen, shard_heads, hidden_dims, seq_p_group):
+        cur_rank = dist.get_rank(seq_p_group)
         img_attn = img_attn.reshape(world_size * shard_seqlen, shard_heads, hidden_dims)  # reshape the image attention output
-        img_attn = all2all_head2seq(img_attn, group=seq_p_group)  # convert head-sharded layout back to sequence-sharded layout
+
+        # split into world_size shards along the sequence dim
+        attn_shards = [img_attn[i * shard_seqlen : (i + 1) * shard_seqlen, :, :].contiguous() for i in range(world_size)]
+
+        # prepare receive buffers; the local shard needs no communication
+        gathered_attn_shards = [None] * world_size
+        for target_rank in range(world_size):
+            if target_rank != cur_rank:
+                gathered_attn_shards[target_rank] = torch.empty_like(attn_shards[target_rank])
+            else:
+                gathered_attn_shards[cur_rank] = attn_shards[cur_rank]
+
+        # launch the async point-to-point transfers, then wait on them
+        for target_rank in range(world_size):
+            if target_rank != cur_rank:
+                # avoid deadlock: both sides of a pair agree on send/recv order by rank
+                if cur_rank < target_rank:
+                    send_req = dist.isend(attn_shards[target_rank], dst=target_rank, group=seq_p_group)
+                    recv_req = dist.irecv(gathered_attn_shards[target_rank], src=target_rank, group=seq_p_group)
+                else:
+                    recv_req = dist.irecv(gathered_attn_shards[target_rank], src=target_rank, group=seq_p_group)
+                    send_req = dist.isend(attn_shards[target_rank], dst=target_rank, group=seq_p_group)
+                send_req.wait()
+                recv_req.wait()
+
+        # concatenate all shards along the head dim
+        img_attn = torch.cat(gathered_attn_shards, dim=1)
         img_attn = img_attn.reshape(shard_seqlen, -1)  # flatten back to [shard_seqlen, num_heads * head_dim]
         torch.cuda.synchronize()  # make sure the CUDA ops have finished
         return img_attn
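A note on the pattern: PyTorch also ships `torch.distributed.batch_isend_irecv`, which takes a list of `P2POp`s and schedules them together (on NCCL they are wrapped in one group call), so no manual send/recv ordering is needed. A hedged sketch of the same exchange using it, assuming a backend that supports batched P2P (e.g. NCCL); the helper name is illustrative:

```python
import torch
import torch.distributed as dist

def pairwise_exchange_batched(shards, group=None):
    # same contract as the manual pairwise_exchange sketch above
    rank = dist.get_rank(group)
    world_size = dist.get_world_size(group)
    result = [shards[i] if i == rank else torch.empty_like(shards[i])
              for i in range(world_size)]
    ops = []
    for peer in range(world_size):
        if peer == rank:
            continue
        ops.append(dist.P2POp(dist.isend, shards[peer], peer, group=group))
        ops.append(dist.P2POp(dist.irecv, result[peer], peer, group=group))
    # all P2P ops are posted as one batch, then waited on together
    for req in dist.batch_isend_irecv(ops):
        req.wait()
    return result
```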