Commit 1a881d63 authored by helloyongyang's avatar helloyongyang

Refactor the parallel modules

parent 18e2b23a
from lightx2v.attentions.common.torch_sdpa import torch_sdpa
from lightx2v.attentions.common.flash_attn2 import flash_attn2
from lightx2v.attentions.common.flash_attn3 import flash_attn3
from lightx2v.attentions.common.sage_attn2 import sage_attn2
from lightx2v.attentions.common.radial_attn import radial_attn
def attention(attention_type="flash_attn2", *args, **kwargs):
if attention_type == "torch_sdpa":
return torch_sdpa(*args, **kwargs)
elif attention_type == "flash_attn2":
return flash_attn2(*args, **kwargs)
elif attention_type == "flash_attn3":
return flash_attn3(*args, **kwargs)
elif attention_type == "sage_attn2":
return sage_attn2(*args, **kwargs)
elif attention_type == "radial_attn":
return radial_attn(*args, **kwargs)
else:
raise NotImplementedError(f"Unsupported attention mode: {attention_type}")
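A minimal usage sketch of the dispatcher above (assuming a CUDA device and flash-attn 2 installed; sizes are illustrative). The tensors follow the varlen-packed [total_tokens, heads, head_dim] layout used by the wrappers below, and the wrapper flattens heads and head_dim in its output:

import torch
from lightx2v.attentions import attention

q = torch.randn(256, 8, 64, dtype=torch.bfloat16, device="cuda")
k = torch.randn(256, 8, 64, dtype=torch.bfloat16, device="cuda")
v = torch.randn(256, 8, 64, dtype=torch.bfloat16, device="cuda")
cu_seqlens = torch.tensor([0, 256], dtype=torch.int32, device="cuda")  # one packed sequence of 256 tokens

out = attention(
    attention_type="flash_attn2",
    q=q,
    k=k,
    v=v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_kv=cu_seqlens,
    max_seqlen_q=256,
    max_seqlen_kv=256,
)
print(out.shape)  # torch.Size([256, 512]): heads and head_dim flattened by the wrapper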
try:
from flash_attn.flash_attn_interface import flash_attn_varlen_func
except ImportError:
flash_attn_varlen_func = None
def flash_attn2(q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None):
x = flash_attn_varlen_func(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
).reshape(max_seqlen_q, -1)
return x
try:
from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3
except ImportError:
flash_attn_varlen_func_v3 = None
def flash_attn3(q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None):
x = flash_attn_varlen_func_v3(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
)[0].reshape(max_seqlen_q, -1)
return x
import torch
if torch.cuda.get_device_capability(0) == (8, 9):
try:
from sageattention import sageattn_qk_int8_pv_fp16_triton as sageattn
except ImportError:
        sageattn = None
else:
try:
from sageattention import sageattn
except ImportError:
sageattn = None
def sage_attn2(q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls="hunyuan"):
q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
if model_cls == "hunyuan":
x1 = sageattn(
q[: cu_seqlens_q[1]].unsqueeze(0),
k[: cu_seqlens_kv[1]].unsqueeze(0),
v[: cu_seqlens_kv[1]].unsqueeze(0),
tensor_layout="NHD",
)
x2 = sageattn(
q[cu_seqlens_q[1] :].unsqueeze(0),
k[cu_seqlens_kv[1] :].unsqueeze(0),
v[cu_seqlens_kv[1] :].unsqueeze(0),
tensor_layout="NHD",
)
x = torch.cat((x1, x2), dim=1)
x = x.view(max_seqlen_q, -1)
elif model_cls in ["wan2.1", "wan2.1_distill", "wan2.1_causvid", "wan2.1_df"]:
x = sageattn(
q.unsqueeze(0),
k.unsqueeze(0),
v.unsqueeze(0),
tensor_layout="NHD",
)
x = x.view(max_seqlen_q, -1)
return x
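For reference, the hunyuan branch above splits the packed sequence at cu_seqlens[1]: the first segment holds the image tokens and the second the text tokens. With the sizes used in the accuracy test later in this commit, the cumulative lengths would be (illustrative sketch):

import torch

# 32411 image tokens followed by 245 text tokens, packed into one sequence of 32656 tokens
cu_seqlens = torch.tensor([0, 32411, 32656], dtype=torch.int32, device="cuda")
img_len = cu_seqlens[1]                  # boundary used by the q[: cu_seqlens_q[1]] / q[cu_seqlens_q[1] :] split
txt_len = cu_seqlens[2] - cu_seqlens[1]  # length of the text segment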
import torch
import torch.nn.functional as F
def torch_sdpa(
q,
k,
v,
drop_rate=0,
attn_mask=None,
causal=False,
):
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
if attn_mask is not None and attn_mask.dtype != torch.bool:
attn_mask = attn_mask.to(q.dtype)
x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
x = x.transpose(1, 2)
b, s, a, d = x.shape
out = x.reshape(b, s, -1)
return out
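A small CPU-only sketch of the layout torch_sdpa expects (a [batch, seq, heads, head_dim] input; sizes are illustrative):

import torch
from lightx2v.attentions.common.torch_sdpa import torch_sdpa

q = torch.randn(2, 16, 4, 32)  # [batch, seq, heads, head_dim]
k = torch.randn(2, 16, 4, 32)
v = torch.randn(2, 16, 4, 32)
out = torch_sdpa(q, k, v)
print(out.shape)  # torch.Size([2, 16, 128]): heads and head_dim flattened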
import torch
import torch.distributed as dist
from lightx2v.attentions import attention
def partial_heads_attn(attention_type, q, k, v, cu_seqlens_qkv, max_seqlen_qkv):
num_heads = q.shape[-2]
cur_rank = dist.get_rank()
world_size = dist.get_world_size()
    num_chunk_heads = num_heads // world_size
if cur_rank == world_size - 1:
q = q[:, num_chunk_heads * cur_rank :, :]
k = k[:, num_chunk_heads * cur_rank :, :]
v = v[:, num_chunk_heads * cur_rank :, :]
else:
q = q[:, num_chunk_heads * cur_rank : num_chunk_heads * (cur_rank + 1), :]
k = k[:, num_chunk_heads * cur_rank : num_chunk_heads * (cur_rank + 1), :]
v = v[:, num_chunk_heads * cur_rank : num_chunk_heads * (cur_rank + 1), :]
output = attention(
attention_type=attention_type,
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens_qkv,
cu_seqlens_kv=cu_seqlens_qkv,
max_seqlen_q=max_seqlen_qkv,
max_seqlen_kv=max_seqlen_qkv,
)
gathered_outputs = [torch.empty_like(output) for _ in range(world_size)]
dist.all_gather(gathered_outputs, output)
combined_output = torch.cat(gathered_outputs, dim=1)
return combined_output
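The head split above can be checked without a process group; a minimal sketch of the bookkeeping, emulating the per-rank slices locally (illustrative sizes, heads divisible by world_size):

import torch

world_size = 2
q = torch.randn(32656, 24, 128)              # [total_tokens, heads, head_dim]
num_chunk_heads = q.shape[-2] // world_size  # heads owned by each rank
shards = [q[:, r * num_chunk_heads : (r + 1) * num_chunk_heads, :] for r in range(world_size)]
# Concatenating the per-rank head shards recovers the full tensor, which is why
# all_gather followed by torch.cat along the flattened head dimension reproduces
# the single-GPU output (this is what test_acc.py below verifies).
assert torch.equal(torch.cat(shards, dim=1), q)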
export PYTHONPATH=/workspace/lightx2v:$PYTHONPATH
export CUDA_VISIBLE_DEVICES=0,1
torchrun --nproc_per_node=2 test_acc.py
import torch
import torch.distributed as dist
from lightx2v.attentions import attention
from lightx2v.utils.utils import seed_all
from loguru import logger
seed_all(42)
def prepare_tensors():
    cur_rank = dist.get_rank()  # rank of the current process
    torch.cuda.set_device(cur_rank)  # bind this process to its CUDA device
q = torch.randn(32656, 24, 128, dtype=torch.bfloat16).cuda()
k = torch.randn(32656, 24, 128, dtype=torch.bfloat16).cuda()
v = torch.randn(32656, 24, 128, dtype=torch.bfloat16).cuda()
cu_seqlens_qkv = torch.tensor([0, 32411, 32656], dtype=torch.int32).cuda()
max_seqlen_qkv = 32656
return q, k, v, cu_seqlens_qkv, max_seqlen_qkv
def test_part_head():
q, k, v, cu_seqlens_qkv, max_seqlen_qkv = prepare_tensors()
    # First compute the full result on a single GPU as the reference.
single_gpu_output = attention(
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens_qkv,
cu_seqlens_kv=cu_seqlens_qkv,
max_seqlen_q=max_seqlen_qkv,
max_seqlen_kv=max_seqlen_qkv,
)
num_heads = q.shape[-2]
cur_rank = dist.get_rank()
world_size = dist.get_world_size()
    num_chunk_heads = num_heads // world_size
if cur_rank == world_size - 1:
q = q[:, num_chunk_heads * cur_rank :, :]
k = k[:, num_chunk_heads * cur_rank :, :]
v = v[:, num_chunk_heads * cur_rank :, :]
else:
q = q[:, num_chunk_heads * cur_rank : num_chunk_heads * (cur_rank + 1), :]
k = k[:, num_chunk_heads * cur_rank : num_chunk_heads * (cur_rank + 1), :]
v = v[:, num_chunk_heads * cur_rank : num_chunk_heads * (cur_rank + 1), :]
output = attention(
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens_qkv,
cu_seqlens_kv=cu_seqlens_qkv,
max_seqlen_q=max_seqlen_qkv,
max_seqlen_kv=max_seqlen_qkv,
)
gathered_outputs = [torch.empty_like(output) for _ in range(world_size)]
dist.all_gather(gathered_outputs, output)
combined_output = torch.cat(gathered_outputs, dim=1)
    # Verify that the gathered multi-GPU result matches the single-GPU reference.
    if cur_rank == 0:
        logger.info(f"Outputs match: {torch.allclose(single_gpu_output, combined_output, rtol=1e-3, atol=1e-3)}")
if __name__ == "__main__":
    # Initialize the distributed environment.
dist.init_process_group(backend="nccl")
test_part_head()
from lightx2v.attentions.distributed.partial_heads_attn.attn import partial_heads_attn
def parallelize_hunyuan(hunyuan_model):
hunyuan_model.transformer_infer.parallel_attention = partial_heads_attn
import torch
import torch.distributed as dist
import torch.nn.functional as F
# from lightx2v.attentions import attention
from lightx2v.attentions.distributed.comm.ring_comm import RingComm
try:
import flash_attn
from flash_attn.flash_attn_interface import _flash_attn_forward
except ImportError:
flash_attn = None
_flash_attn_forward = None
from typing import Optional, Tuple
# RING_COMM = None
# def init_ring_comm():
# global RING_COMM
# RING_COMM = RingComm()
@torch.jit.script
def _update_out_and_lse(
out: torch.Tensor,
lse: torch.Tensor,
block_out: torch.Tensor,
block_lse: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
block_out = block_out.to(torch.float32)
block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)
# new_lse = lse + torch.log(1 + torch.exp(block_lse - lse))
# torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out
# For additional context and discussion, please refer to:
# https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795
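    # Derivation: with new_lse = lse + log(1 + exp(block_lse - lse)),
    #   exp(lse - new_lse)       = 1 / (1 + exp(block_lse - lse)) = 1 - sigmoid(block_lse - lse)
    #   exp(block_lse - new_lse) = sigmoid(block_lse - lse)
    # so the LSE-weighted combination of the two partial outputs collapses to the
    # sigmoid/logsigmoid form below, avoiding exp() of potentially large values.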
out = out - F.sigmoid(block_lse - lse) * (out - block_out)
lse = lse - F.logsigmoid(lse - block_lse)
return out, lse
def update_out_and_lse(
out: Optional[torch.Tensor],
lse: Optional[torch.Tensor],
block_out: torch.Tensor,
block_lse: torch.Tensor,
slice_=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
if out is None:
if slice_ is not None:
raise RuntimeError("first update_out_and_lse should not pass slice_ args")
out = block_out.to(torch.float32)
lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)
elif slice_ is not None:
slice_out, slice_lse = out[slice_], lse[slice_]
slice_out, slice_lse = _update_out_and_lse(slice_out, slice_lse, block_out, block_lse)
out[slice_], lse[slice_] = slice_out, slice_lse
else:
out, lse = _update_out_and_lse(out, lse, block_out, block_lse)
return out, lse
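A pure-PyTorch sanity check of the streaming merge above, as a sketch for a scratch script (it does not depend on flash-attn): merging two KV blocks through update_out_and_lse should reproduce full softmax attention over the concatenated blocks.

import torch
from lightx2v.attentions.distributed.ring.attn import update_out_and_lse

def naive_block_attn(q, k, v, scale):
    # per-block softmax attention output and log-sum-exp, using flash-attn's [B, H, S] lse layout
    scores = torch.einsum("bshd,bthd->bhst", q, k) * scale
    lse = torch.logsumexp(scores, dim=-1)
    out = torch.einsum("bhst,bthd->bshd", scores.softmax(-1), v)
    return out, lse

torch.manual_seed(0)
B, S, H, D = 1, 4, 2, 8
scale = D ** -0.5
q = torch.randn(B, S, H, D)
k1, v1 = torch.randn(B, S, H, D), torch.randn(B, S, H, D)
k2, v2 = torch.randn(B, S, H, D), torch.randn(B, S, H, D)

out, lse = update_out_and_lse(None, None, *naive_block_attn(q, k1, v1, scale))
out, lse = update_out_and_lse(out, lse, *naive_block_attn(q, k2, v2, scale))

ref, _ = naive_block_attn(q, torch.cat([k1, k2], dim=1), torch.cat([v1, v2], dim=1), scale)
assert torch.allclose(out, ref, atol=1e-5)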
def ring_attn_sub(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1), softcap=0.0, alibi_slopes=None, return_softmax=False):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
if flash_attn.__version__ < "2.6.3":
block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(
q,
k,
v,
dropout_p=dropout_p,
softmax_scale=softmax_scale,
causal=causal,
window_size=window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax,
)
else:
block_out, block_lse, _, _ = _flash_attn_forward(
q,
k,
v,
dropout_p=dropout_p,
softmax_scale=softmax_scale,
causal=causal,
window_size_left=window_size[0],
window_size_right=window_size[1],
softcap=softcap,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax,
)
return block_out, block_lse
def ring_attn(q, k, v, img_qkv_len, cu_seqlens_qkv, attention_type="flash_attn2"):
"""
执行 Ring 注意力机制,结合图像和文本的查询、键和值。
参数:
q (torch.Tensor): 查询张量,形状为 [shard_seqlen, heads, hidden_dims]
k (torch.Tensor): 键张量,形状为 [shard_seqlen, heads, hidden_dims]
v (torch.Tensor): 值张量,形状为 [shard_seqlen, heads, hidden_dims]
img_qkv_len (int): 图像查询、键和值的长度
cu_seqlens_qkv (torch.Tensor): 累积序列长度,包含文本和图像的长度信息
attention_type (str): 注意力类型,默认为 "flash_attn2"
返回:
torch.Tensor: 计算得到的注意力结果
"""
    # Rank of the current process and the world size
cur_rank = dist.get_rank()
world_size = dist.get_world_size()
    if len(cu_seqlens_qkv) == 3:
        txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len  # length of the text query/key/value
        txt_mask_len = cu_seqlens_qkv[2] - img_qkv_len  # length of the text mask
    elif len(cu_seqlens_qkv) == 2:
        txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len  # length of the text query/key/value
        txt_mask_len = 0
# if RING_COMM is None:
# init_ring_comm()
RING_COMM = RingComm()
q = q.unsqueeze(0)
k = k.unsqueeze(0)
v = v.unsqueeze(0)
img_q, img_k, img_v = q[:, :img_qkv_len, :, :].contiguous(), k[:, :img_qkv_len, :, :].contiguous(), v[:, :img_qkv_len, :, :].contiguous()
txt_q, txt_k, txt_v = (
q[:, img_qkv_len : img_qkv_len + txt_qkv_len, :, :].contiguous(),
k[:, img_qkv_len : img_qkv_len + txt_qkv_len, :, :].contiguous(),
v[:, img_qkv_len : img_qkv_len + txt_qkv_len, :, :].contiguous(),
)
out, lse, next_k, next_v = None, None, None, None
if len(cu_seqlens_qkv) == 3:
q = torch.cat((img_q, txt_q), dim=1)
k = img_k
v = img_v
for step in range(world_size):
if step + 1 != world_size:
next_k = RING_COMM.send_recv(k)
next_v = RING_COMM.send_recv(v)
RING_COMM.commit()
if step + 1 == world_size:
k = torch.cat((k, txt_k), dim=1)
v = torch.cat((v, txt_v), dim=1)
block_out, block_lse = ring_attn_sub(q, k, v)
out, lse = update_out_and_lse(out, lse, block_out, block_lse)
if step + 1 != world_size:
RING_COMM.wait()
k = next_k
v = next_v
attn1 = out.to(torch.bfloat16).squeeze(0).reshape(img_qkv_len + txt_qkv_len, -1)
if txt_mask_len > 0:
attn2, *_ = _flash_attn_forward(
q[:, -(txt_mask_len - txt_qkv_len) :, :, :].contiguous(),
k[:, -(txt_mask_len - txt_qkv_len) :, :, :].contiguous(),
v[:, -(txt_mask_len - txt_qkv_len) :, :, :].contiguous(),
dropout_p=0.0,
softmax_scale=q.shape[-1] ** (-0.5),
causal=False,
window_size_left=-1,
window_size_right=-1,
softcap=0.0,
alibi_slopes=None,
return_softmax=False,
)
attn2 = attn2.to(torch.bfloat16).squeeze(0).reshape((txt_mask_len - txt_qkv_len), -1)
attn1 = torch.cat([attn1, attn2], dim=0)
return attn1
import torch
import torch.distributed as dist
from lightx2v.attentions import attention
from flash_attn.flash_attn_interface import _flash_attn_varlen_forward
from lightx2v.attentions.distributed.ring.attn import ring_attn_sub, update_out_and_lse
from lightx2v.attentions.distributed.comm.ring_comm import RingComm
RING_COMM = None
def init_ring_comm():
global RING_COMM
RING_COMM = RingComm()
def base_attention(q, k, v, cu_seqlens_q, cu_seqlens_k, lq, lk):
attn_out = attention(
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_k,
max_seqlen_q=lq,
max_seqlen_kv=lk,
)
return attn_out
def ring_attention(q, k, v, cu_seqlens_q, cu_seqlens_k, lq, lk, ring_size):
out, lse = None, None
# q = torch.chunk(q, ring_size)
q = q.unsqueeze(0)
k = k.unsqueeze(0)
v = v.unsqueeze(0)
k = torch.chunk(k, ring_size, dim=1)
v = torch.chunk(v, ring_size, dim=1)
for i in range(ring_size):
k_block, v_block = k[i], v[i]
block_out, block_lse = ring_attn_sub(q, k_block, v_block)
out, lse = update_out_and_lse(out, lse, block_out, block_lse)
attn_out = out.to(torch.bfloat16).squeeze(0).reshape(lq, -1)
return attn_out
def ring_attention_dist(q, k, v, cu_seqlens_q, cu_seqlens_k, lq, lk):
if RING_COMM is None:
init_ring_comm()
out, lse = None, None
# q = torch.chunk(q, ring_size)
cur_rank = dist.get_rank()
world_size = dist.get_world_size()
out, lse, next_k, next_v = None, None, None, None
q = q.unsqueeze(0)
k = k.unsqueeze(0)
v = v.unsqueeze(0)
k = torch.chunk(k, world_size, dim=1)[cur_rank]
v = torch.chunk(v, world_size, dim=1)[cur_rank]
for step in range(world_size):
if step + 1 != world_size:
next_k = RING_COMM.send_recv(k)
next_v = RING_COMM.send_recv(v)
RING_COMM.commit()
block_out, block_lse = ring_attn_sub(q, k, v)
out, lse = update_out_and_lse(out, lse, block_out, block_lse)
if step + 1 != world_size:
RING_COMM.wait()
k = next_k
v = next_v
attn_out = out.to(torch.bfloat16).squeeze(0).reshape(lq, -1)
return attn_out
def test():
q = torch.randn((32760, 12, 128), dtype=torch.bfloat16, device="cuda")
k = torch.randn((32760, 12, 128), dtype=torch.bfloat16, device="cuda")
v = torch.randn((32760, 12, 128), dtype=torch.bfloat16, device="cuda")
cu_seqlens_q = torch.tensor([0, 32760], dtype=torch.int32, device="cuda")
cu_seqlens_k = torch.tensor([0, 32760], dtype=torch.int32, device="cuda")
lq = 32760
lk = 32760
base_attn = base_attention(q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, lq=lq, lk=lk)
ring_attn = ring_attention(q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, lq=lq, lk=lk, ring_size=4)
    # Check that the single-process reference and the chunked ring emulation agree numerically.
    assert torch.allclose(base_attn, ring_attn, rtol=1e-3, atol=1e-3), "base_attn and ring_attn do not match!"
if __name__ == "__main__":
# dist.init_process_group(backend="nccl")
test()
lightx2v_path=""
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
python3 test.py
import functools
from lightx2v.attentions.distributed.ring.attn import ring_attn
def parallelize_hunyuan(hunyuan_model):
    """Parallelize the Hunyuan model's inference with ring attention.

    Args:
        hunyuan_model: the Hunyuan model instance, exposing its inference method and related attributes.
    """
    from lightx2v.attentions.distributed.utils.hunyuan.processor import pre_process, post_process

    # Replace the model's parallel attention hook with ring attention.
    hunyuan_model.transformer_infer.parallel_attention = ring_attn
    # Keep a reference to the original inference method so it can still be called.
original_infer = hunyuan_model.infer
    @functools.wraps(hunyuan_model.__class__.infer)  # preserve the original infer method's metadata
def new_infer(self, text_encoders_output, image_encoder_output, args):
"""新的推理方法,处理输入并调用原始推理方法。
参数:
self: Hunyuan 模型实例
text_encoders_output: 文本编码器的输出
args: 其他参数
返回:
None
"""
        # Save the original latents and rotary frequency tables so they can be restored afterwards.
self.scheduler.ori_latents, self.scheduler.ori_freqs_cos, self.scheduler.ori_freqs_sin = (self.scheduler.latents, self.scheduler.freqs_cos, self.scheduler.freqs_sin)
        # Pre-process (shard) the inputs for parallel computation.
self.scheduler.latents, self.scheduler.freqs_cos, self.scheduler.freqs_sin, split_dim = pre_process(self.scheduler.latents, self.scheduler.freqs_cos, self.scheduler.freqs_sin)
        # Run the original inference on the sharded inputs.
original_infer(text_encoders_output, image_encoder_output, args)
        # Post-process (gather) the noise prediction.
self.scheduler.noise_pred = post_process(self.scheduler.noise_pred, split_dim)
        # Restore the original latents and rotary frequency tables.
self.scheduler.latents, self.scheduler.freqs_cos, self.scheduler.freqs_sin = (self.scheduler.ori_latents, self.scheduler.ori_freqs_cos, self.scheduler.ori_freqs_sin)
    # Bind the new inference method to this model instance.
new_infer = new_infer.__get__(hunyuan_model)
    hunyuan_model.infer = new_infer  # replace the original inference method
def parallelize_wan(wan_model):
from lightx2v.attentions.distributed.utils.wan.processor import pre_process, post_process
wan_model.transformer_infer.parallel_attention = ring_attn
original_infer = wan_model.transformer_infer.infer
    @functools.wraps(wan_model.transformer_infer.__class__.infer)  # preserve the original infer method's metadata
def new_infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
x = pre_process(x)
x = original_infer(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
x = post_process(x)
return x
new_infer = new_infer.__get__(wan_model.transformer_infer)
    wan_model.transformer_infer.infer = new_infer  # replace the original inference method
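Both parallelize functions above swap in a wrapped infer by binding a plain function to one instance with __get__; a minimal, self-contained illustration of that pattern (hypothetical class, not part of lightx2v):

class _Model:
    def infer(self):
        return "original"

model = _Model()
original_infer = model.infer  # bound method; still runs the original implementation

def new_infer(self):
    return "wrapped -> " + original_infer()

model.infer = new_infer.__get__(model)  # bind to this instance only; the class stays untouched
assert model.infer() == "wrapped -> original"
assert _Model().infer() == "original"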
import torch
import torch.distributed as dist
from lightx2v.attentions import attention
from lightx2v.attentions.distributed.comm.all2all import all2all_seq2head, all2all_head2seq
def ulysses_attn(q, k, v, img_qkv_len, cu_seqlens_qkv, attention_type="flash_attn2"):
"""
执行 Ulysses 注意力机制,结合图像和文本的查询、键和值。
参数:
q (torch.Tensor): 查询张量,形状为 [shard_seqlen, heads, hidden_dims]
k (torch.Tensor): 键张量,形状为 [shard_seqlen, heads, hidden_dims]
v (torch.Tensor): 值张量,形状为 [shard_seqlen, heads, hidden_dims]
img_qkv_len (int): 图像查询、键和值的长度
cu_seqlens_qkv (torch.Tensor): 累积序列长度,包含文本和图像的长度信息
attention_type (str): 注意力类型,默认为 "flash_attn2"
返回:
torch.Tensor: 计算得到的注意力结果
"""
    # Rank of the current process and the world size
cur_rank = dist.get_rank()
world_size = dist.get_world_size()
    # Sequence length and the text-related lengths
    seq_len = q.shape[0]
    if len(cu_seqlens_qkv) == 3:
        txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len  # length of the text query/key/value
        txt_mask_len = cu_seqlens_qkv[2] - img_qkv_len  # length of the text mask
    elif len(cu_seqlens_qkv) == 2:
        txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len  # length of the text query/key/value
        txt_mask_len = None
    # Heads and hidden dimension of the query tensor
    _, heads, hidden_dims = q.shape
    shard_heads = heads // world_size  # heads handled by each rank
    shard_seqlen = img_qkv_len  # sequence length handled by each rank
    # Split the queries, keys and values into their image and text parts.
img_q, img_k, img_v = q[:img_qkv_len, :, :].contiguous(), k[:img_qkv_len, :, :].contiguous(), v[:img_qkv_len, :, :].contiguous()
txt_q, txt_k, txt_v = q[img_qkv_len:, :, :].contiguous(), k[img_qkv_len:, :, :].contiguous(), v[img_qkv_len:, :, :].contiguous()
    # Redistribute the image q/k/v from sequence-sharded to head-sharded layout (all-to-all).
img_q = all2all_seq2head(img_q)
img_k = all2all_seq2head(img_k)
img_v = all2all_seq2head(img_v)
    torch.cuda.synchronize()  # make sure the CUDA ops have finished
    # For the text q/k/v, keep only the head shard owned by this rank.
txt_q = txt_q[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]
txt_k = txt_k[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]
txt_v = txt_v[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]
    # Merge the image and text q/k/v along the sequence dimension.
q = torch.cat((img_q, txt_q), dim=0)
k = torch.cat((img_k, txt_k), dim=0)
v = torch.cat((img_v, txt_v), dim=0)
    # Build the cumulative sequence lengths for the merged (image + text) sequence.
    cu_seqlens_qkv = torch.zeros([2], dtype=torch.int32, device="cuda")
    s = txt_qkv_len + img_q.shape[0]  # total length of text plus image tokens
    s1 = s  # end position of the current sample
    cu_seqlens_qkv[1] = s1  # cumulative sequence length
if txt_mask_len:
        s2 = txt_mask_len + img_q.shape[0]  # end position including the text mask
        cu_seqlens_qkv = torch.cat((cu_seqlens_qkv, torch.tensor([int(s2)], dtype=torch.int32, device=cu_seqlens_qkv.device)))
    max_seqlen_qkv = img_q.shape[0] + txt_q.shape[0]  # maximum sequence length
    # Run the attention kernel on the merged sequence.
    attn = attention(attention_type=attention_type, q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_qkv, cu_seqlens_kv=cu_seqlens_qkv, max_seqlen_q=max_seqlen_qkv, max_seqlen_kv=max_seqlen_qkv)
    # Split the attention output back into its image and text parts.
    img_attn, txt_attn = attn[: img_q.shape[0], :], attn[img_q.shape[0] :, :]
    # Gather the text attention outputs from all ranks.
gathered_txt_attn = [torch.empty_like(txt_attn) for _ in range(world_size)]
dist.all_gather(gathered_txt_attn, txt_attn)
    # Convert the image attention output from head-sharded back to sequence-sharded layout.
    img_attn = img_attn.reshape(world_size * shard_seqlen, shard_heads, hidden_dims)  # restore [full_seq, shard_heads, dim]
    img_attn = all2all_head2seq(img_attn)  # all-to-all: back to all heads on the local sequence shard
    img_attn = img_attn.reshape(shard_seqlen, -1)  # flatten to [shard_seqlen, heads * dim]
    torch.cuda.synchronize()  # make sure the CUDA ops have finished
    txt_attn = torch.cat(gathered_txt_attn, dim=1)  # concatenate the gathered text attention across ranks
    # Concatenate the image and text attention outputs along the sequence dimension.
    attn = torch.cat([img_attn, txt_attn], dim=0)
    return attn  # final attention output
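For orientation, a single-process sketch of the layout change that all2all_seq2head / all2all_head2seq are used for above (the real implementations exchange data across ranks with torch.distributed; sizes here are illustrative and the per-rank views are emulated with torch.chunk):

import torch

world_size, shard_seqlen, heads, dim = 2, 6, 4, 8
full_seq = torch.randn(world_size * shard_seqlen, heads, dim)

# Before the exchange, rank r holds its sequence shard with all heads:
seq_shard_rank0 = torch.chunk(full_seq, world_size, dim=0)[0]   # [shard_seqlen, heads, dim]
# After all2all_seq2head, rank r holds the full sequence but only its head shard:
head_shard_rank0 = torch.chunk(full_seq, world_size, dim=1)[0]  # [world_size * shard_seqlen, heads // world_size, dim]

assert seq_shard_rank0.shape == (shard_seqlen, heads, dim)
assert head_shard_rank0.shape == (world_size * shard_seqlen, heads // world_size, dim)
# all2all_head2seq performs the inverse exchange, restoring the [shard_seqlen, heads, dim] layout.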