# attn.py
import torch
import torch.distributed as dist
from lightx2v.attentions import attention


def ring_attn(q, k, v, img_qkv_len, cu_seqlens_qkv, attention_type="flash_attn2"):
    """
    Perform the Ulysses attention mechanism over combined image and text queries, keys, and values.

    Args:
        q (torch.Tensor): Query tensor of shape [shard_seqlen, heads, hidden_dims]
        k (torch.Tensor): Key tensor of shape [shard_seqlen, heads, hidden_dims]
        v (torch.Tensor): Value tensor of shape [shard_seqlen, heads, hidden_dims]
        img_qkv_len (int): Length of the image portion of the query/key/value
        cu_seqlens_qkv (torch.Tensor): Cumulative sequence lengths covering the text and image segments
        attention_type (str): Attention backend, defaults to "flash_attn2"

    Returns:
        torch.Tensor: The computed attention output
    """
    # Rank of the current process and total number of processes
    cur_rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Sequence length and the text-related lengths
    seq_len = q.shape[0]
    txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len  # length of the text query/key/value
    txt_mask_len = cu_seqlens_qkv[2] - img_qkv_len  # length of the text mask

    # Number of heads and hidden dimension of the query tensor
    _, heads, hidden_dims = q.shape
    shard_heads = heads // world_size  # heads handled by each process
    shard_seqlen = img_qkv_len  # sequence length handled by each process

    # Split the image and text queries, keys, and values
    img_q, img_k, img_v = q[:img_qkv_len, :, :].contiguous(), k[:img_qkv_len, :, :].contiguous(), v[:img_qkv_len, :, :].contiguous()
    txt_q, txt_k, txt_v = q[img_qkv_len:, :, :].contiguous(), k[img_qkv_len:, :, :].contiguous(), v[img_qkv_len:, :, :].contiguous()

    # Gather the image key/value shards from every rank so each rank holds the full image K/V sequence
    gathered_img_k = [torch.empty_like(img_k) for _ in range(world_size)]
    gathered_img_v = [torch.empty_like(img_v) for _ in range(world_size)]

    dist.all_gather(gathered_img_k, img_k)
    dist.all_gather(gathered_img_v, img_v)
    torch.cuda.synchronize()

    # q keeps only this rank's shard; k and v are rebuilt from the gathered image shards plus the local text tokens
    k = torch.cat(gathered_img_k + [txt_k], dim=0)
    v = torch.cat(gathered_img_v + [txt_v], dim=0)

    # Build the cumulative sequence lengths for the queries
    cu_seqlens_q = torch.zeros([3], dtype=torch.int32, device="cuda")
    s = txt_qkv_len + img_q.shape[0]  # total length of the text plus the local image tokens
    s1 = s  # end position of the current sample
    s2 = txt_mask_len + img_q.shape[0]  # end position of the text mask
    cu_seqlens_q[1] = s1  # cumulative sequence length up to the sample
    cu_seqlens_q[2] = s2  # cumulative sequence length up to the mask
    max_seqlen_q = img_q.shape[0] + txt_q.shape[0]  # maximum query sequence length

    # Build the cumulative sequence lengths for the keys/values
    cu_seqlens_kv = torch.zeros([3], dtype=torch.int32, device="cuda")
    s = txt_qkv_len + img_k.shape[0] * world_size  # total length of the text plus the gathered image tokens
    s1 = s  # end position of the current sample
    s2 = txt_mask_len + img_k.shape[0] * world_size  # end position of the text mask
    cu_seqlens_kv[1] = s1  # cumulative sequence length up to the sample
    cu_seqlens_kv[2] = s2  # cumulative sequence length up to the mask
    max_seqlen_kv = img_k.shape[0] * world_size + txt_q.shape[0]  # maximum key/value sequence length

    # Run variable-length attention over the local queries against the full gathered keys/values
    attn = attention(attention_type=attention_type, q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv)

    return attn
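

# A minimal, illustrative sketch of how ring_attn might be driven. It assumes the
# script is launched with `torchrun --nproc_per_node=<N> attn.py`, that each rank has
# an NCCL-capable GPU, and that the "flash_attn2" backend of
# lightx2v.attentions.attention accepts bfloat16 inputs; all shapes and lengths below
# are made up for illustration.
if __name__ == "__main__":
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

    heads, hidden_dims = 16, 64
    img_qkv_len, txt_len = 1024, 77  # per-rank image tokens and text tokens (illustrative)
    shard_seqlen = img_qkv_len + txt_len

    q = torch.randn(shard_seqlen, heads, hidden_dims, device="cuda", dtype=torch.bfloat16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    # cu_seqlens_qkv[1] - img_qkv_len is the text QKV length and
    # cu_seqlens_qkv[2] - img_qkv_len is the text mask length (here both equal txt_len)
    cu_seqlens_qkv = torch.tensor([0, img_qkv_len + txt_len, img_qkv_len + txt_len], dtype=torch.int32, device="cuda")

    out = ring_attn(q, k, v, img_qkv_len, cu_seqlens_qkv)
    print(f"rank {dist.get_rank()}: output shape {tuple(out.shape)}")

    dist.destroy_process_group()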