attn.py 4.32 KB
Newer Older
helloyongyang's avatar
helloyongyang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28

import torch
import torch.distributed as dist
from lightx2v.attentions import attention
from lightx2v.attentions.distributed.comm.all2all import all2all_seq2head, all2all_head2seq


def ulysses_attn(q, k, v, img_qkv_len, cu_seqlens_qkv, attention_type="flash_attn2"):
    '''
    执行 Ulysses 注意力机制,结合图像和文本的查询、键和值。

    参数:
        q (torch.Tensor): 查询张量,形状为 [shard_seqlen, heads, hidden_dims]
        k (torch.Tensor): 键张量,形状为 [shard_seqlen, heads, hidden_dims]
        v (torch.Tensor): 值张量,形状为 [shard_seqlen, heads, hidden_dims]
        img_qkv_len (int): 图像查询、键和值的长度
        cu_seqlens_qkv (torch.Tensor): 累积序列长度,包含文本和图像的长度信息
        attention_type (str): 注意力类型,默认为 "flash_attn2"

    返回:
        torch.Tensor: 计算得到的注意力结果
    '''
    # 获取当前进程的排名和全局进程数
    cur_rank = dist.get_rank()
    world_size = dist.get_world_size()
    
    # 获取序列长度和文本相关的长度
    seq_len = q.shape[0]
Xinchi Huang's avatar
Xinchi Huang committed
29
30
31
32
33
34
    if len(cu_seqlens_qkv) == 3:
        txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len  # 文本查询、键和值的长度
        txt_mask_len = cu_seqlens_qkv[2] - img_qkv_len  # 文本掩码长度
    elif len(cu_seqlens_qkv) == 2:
        txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len  # 文本查询、键和值的长度
        txt_mask_len = None
helloyongyang's avatar
helloyongyang committed
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
    
    # 获取查询张量的头数和隐藏维度
    _, heads, hidden_dims = q.shape
    shard_heads = heads // world_size  # 每个进程处理的头数
    shard_seqlen = img_qkv_len  # 每个进程处理的序列长度

    # 分割图像和文本的查询、键和值
    img_q, img_k, img_v = q[:img_qkv_len,:,:].contiguous(), k[:img_qkv_len,:,:].contiguous(), v[:img_qkv_len,:,:].contiguous()
    txt_q, txt_k, txt_v = q[img_qkv_len:,:,:].contiguous(), k[img_qkv_len:,:,:].contiguous(), v[img_qkv_len:,:,:].contiguous()

    # 将图像的查询、键和值转换为头的格式
    img_q = all2all_seq2head(img_q)
    img_k = all2all_seq2head(img_k)
    img_v = all2all_seq2head(img_v)
    torch.cuda.synchronize()  # 确保CUDA操作完成

    # 处理文本的查询、键和值,选择当前进程的头
    txt_q = txt_q[:,cur_rank*shard_heads:(cur_rank+1)*shard_heads,:]
    txt_k = txt_k[:,cur_rank*shard_heads:(cur_rank+1)*shard_heads,:]
    txt_v = txt_v[:,cur_rank*shard_heads:(cur_rank+1)*shard_heads,:]

    # 合并图像和文本的查询、键和值
    q = torch.cat((img_q, txt_q), dim=0)
    k = torch.cat((img_k, txt_k), dim=0)
    v = torch.cat((img_v, txt_v), dim=0)

    # 初始化累积序列长度张量
    cu_seqlens_qkv = torch.zeros([3], dtype=torch.int32, device="cuda")
    s = txt_qkv_len + img_q.shape[0]  # 计算文本和图像的总长度
    s1 = s  # 当前样本的结束位置
    cu_seqlens_qkv[1] = s1  # 设置累积序列长度
Xinchi Huang's avatar
Xinchi Huang committed
66
67
68
    if txt_mask_len:
        s2 = txt_mask_len + img_q.shape[0]  # 文本掩码的结束位置
        cu_seqlens_qkv[2] = s2  # 设置累积序列长度
helloyongyang's avatar
helloyongyang committed
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
    max_seqlen_qkv = img_q.shape[0] + txt_q.shape[0]  # 最大序列长度

    # 调用注意力函数计算注意力结果
    attn = attention(
        attention_type=attention_type,
        q=q,
        k=k,
        v=v,
        cu_seqlens_q=cu_seqlens_qkv,
        cu_seqlens_kv=cu_seqlens_qkv,
        max_seqlen_q=max_seqlen_qkv,
        max_seqlen_kv=max_seqlen_qkv
    )

    # 分割图像和文本的注意力结果
    img_attn, txt_attn = attn[:img_q.shape[0],:], attn[img_q.shape[0]:,]

    # 收集所有进程的文本注意力结果
    gathered_txt_attn = [torch.empty_like(txt_attn) for _ in range(world_size)]
    dist.all_gather(gathered_txt_attn, txt_attn)

    # 处理图像注意力结果
    img_attn = img_attn.reshape(world_size*shard_seqlen, shard_heads, hidden_dims)  # 重塑图像注意力结果
    img_attn = all2all_head2seq(img_attn)  # 将头的格式转换回序列格式
    img_attn = img_attn.reshape(shard_seqlen, -1)  # 重塑为 [shard_seqlen, -1] 形状

    torch.cuda.synchronize()  # 确保CUDA操作完成
    txt_attn = torch.cat(gathered_txt_attn, dim=1)  # 合并所有进程的文本注意力结果

    # 合并图像和文本的注意力结果
    attn = torch.cat([img_attn, txt_attn], dim=0)

    return attn  # 返回最终的注意力结果