attn.py 3.27 KB
Newer Older
helloyongyang's avatar
helloyongyang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import torch
import torch.distributed as dist
from lightx2v.attentions import attention



def ring_attn(q, k, v, img_qkv_len, cu_seqlens_qkv, attention_type="flash_attn2"):
    '''
    执行 Ulysses 注意力机制,结合图像和文本的查询、键和值。

    参数:
        q (torch.Tensor): 查询张量,形状为 [shard_seqlen, heads, hidden_dims]
        k (torch.Tensor): 键张量,形状为 [shard_seqlen, heads, hidden_dims]
        v (torch.Tensor): 值张量,形状为 [shard_seqlen, heads, hidden_dims]
        img_qkv_len (int): 图像查询、键和值的长度
        cu_seqlens_qkv (torch.Tensor): 累积序列长度,包含文本和图像的长度信息
        attention_type (str): 注意力类型,默认为 "flash_attn2"

    返回:
        torch.Tensor: 计算得到的注意力结果
    '''
    # 获取当前进程的排名和全局进程数
    cur_rank = dist.get_rank()
    world_size = dist.get_world_size()

    # 获取序列长度和文本相关的长度
    seq_len = q.shape[0]
    txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len  # 文本查询、键和值的长度
    txt_mask_len = cu_seqlens_qkv[2] - img_qkv_len  # 文本掩码长度
    
    # 获取查询张量的头数和隐藏维度
    _, heads, hidden_dims = q.shape
    shard_heads = heads // world_size  # 每个进程处理的头数
    shard_seqlen = img_qkv_len  # 每个进程处理的序列长度

    # 分割图像和文本的查询、键和值
    img_q, img_k, img_v = q[:img_qkv_len,:,:].contiguous(), k[:img_qkv_len,:,:].contiguous(), v[:img_qkv_len,:,:].contiguous()
    txt_q, txt_k, txt_v = q[img_qkv_len:,:,:].contiguous(), k[img_qkv_len:,:,:].contiguous(), v[img_qkv_len:,:,:].contiguous()

    gathered_img_k = [torch.empty_like(img_k) for _ in range(world_size)]
    gathered_img_v = [torch.empty_like(img_v) for _ in range(world_size)]

    dist.all_gather(gathered_img_k, img_k)
    dist.all_gather(gathered_img_v, img_v)
    torch.cuda.synchronize()

    q = q
    k = torch.cat(gathered_img_k+[txt_k], dim=0)
    v = torch.cat(gathered_img_v+[txt_v], dim=0)

    # 初始化累积序列长度张量
    cu_seqlens_q = torch.zeros([3], dtype=torch.int32, device="cuda")
    s = txt_qkv_len + img_q.shape[0]  # 计算文本和图像的总长度
    s1 = s  # 当前样本的结束位置
    s2 = txt_mask_len + img_q.shape[0]  # 文本掩码的结束位置
    cu_seqlens_q[1] = s1  # 设置累积序列长度
    cu_seqlens_q[2] = s2  # 设置累积序列长度
    max_seqlen_q = img_q.shape[0] + txt_q.shape[0]  # 最大序列长度

    # 初始化累积序列长度张量
    cu_seqlens_kv = torch.zeros([3], dtype=torch.int32, device="cuda")
    s = txt_qkv_len + img_k.shape[0]*world_size  # 计算文本和图像的总长度
    s1 = s  # 当前样本的结束位置
    s2 = txt_mask_len + img_k.shape[0]*world_size  # 文本掩码的结束位置
    cu_seqlens_kv[1] = s1  # 设置累积序列长度
    cu_seqlens_kv[2] = s2  # 设置累积序列长度
    max_seqlen_kv = img_k.shape[0]*world_size + txt_q.shape[0]  # 最大序列长度

    attn = attention(
        attention_type=attention_type,
        q=q,
        k=k,
        v=v,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_kv=cu_seqlens_kv,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_kv=max_seqlen_kv
    )

    return attn