import torch
import torch.distributed as dist
from lightx2v.attentions import attention
from flash_attn.flash_attn_interface import _flash_attn_varlen_forward
from lightx2v.attentions.distributed.ring.attn import ring_attn_sub, update_out_and_lse
from lightx2v.attentions.distributed.comm.ring_comm import RingComm

RING_COMM = None


def init_ring_comm():
    global RING_COMM
    RING_COMM = RingComm()


def base_attention(q, k, v, cu_seqlens_q, cu_seqlens_k, lq, lk):
    attn_out = attention(
        q=q,
        k=k,
        v=v,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_kv=cu_seqlens_k,
        max_seqlen_q=lq,
        max_seqlen_kv=lk,
    )
    return attn_out


def ring_attention(q, k, v, cu_seqlens_q, cu_seqlens_k, lq, lk, ring_size):
    # Single-process simulation of ring attention: keep the full q, split k/v
    # into ring_size blocks, and merge the per-block results.
    out, lse = None, None
    q = q.unsqueeze(0)
    k = k.unsqueeze(0)
    v = v.unsqueeze(0)

    k = torch.chunk(k, ring_size, dim=1)
    v = torch.chunk(v, ring_size, dim=1)

    for i in range(ring_size):
        k_block, v_block = k[i], v[i]
        block_out, block_lse = ring_attn_sub(q, k_block, v_block)
        # Accumulate block outputs with a running log-sum-exp so the chunked
        # computation matches full attention.
        out, lse = update_out_and_lse(out, lse, block_out, block_lse)

    attn_out = out.to(torch.bfloat16).squeeze(0).reshape(lq, -1)
    return attn_out


def ring_attention_dist(q, k, v, cu_seqlens_q, cu_seqlens_k, lq, lk):
    if RING_COMM is None:
        init_ring_comm()

    cur_rank = dist.get_rank()
    world_size = dist.get_world_size()

    out, lse, next_k, next_v = None, None, None, None

    q = q.unsqueeze(0)
    k = k.unsqueeze(0)
    v = v.unsqueeze(0)

    # Each rank keeps the full q but only its own k/v block.
    k = torch.chunk(k, world_size, dim=1)[cur_rank]
    v = torch.chunk(v, world_size, dim=1)[cur_rank]

    for step in range(world_size):
        # Start passing the current k/v block around the ring so communication
        # overlaps with the local attention computation.
        if step + 1 != world_size:
            next_k = RING_COMM.send_recv(k)
            next_v = RING_COMM.send_recv(v)
            RING_COMM.commit()
        block_out, block_lse = ring_attn_sub(q, k, v)
        out, lse = update_out_and_lse(out, lse, block_out, block_lse)

        if step + 1 != world_size:
            RING_COMM.wait()
            k = next_k
            v = next_v

    attn_out = out.to(torch.bfloat16).squeeze(0).reshape(lq, -1)
    return attn_out


def test():
    q = torch.randn((32760, 12, 128), dtype=torch.bfloat16, device="cuda")
    k = torch.randn((32760, 12, 128), dtype=torch.bfloat16, device="cuda")
    v = torch.randn((32760, 12, 128), dtype=torch.bfloat16, device="cuda")
    cu_seqlens_q = torch.tensor([0, 32760], dtype=torch.int32, device="cuda")
    cu_seqlens_k = torch.tensor([0, 32760], dtype=torch.int32, device="cuda")
    lq = 32760
    lk = 32760

    base_attn = base_attention(q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, lq=lq, lk=lk)

    ring_attn = ring_attention(q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, lq=lq, lk=lk, ring_size=4)
    # Assert that the two implementations agree numerically.
    assert torch.allclose(base_attn, ring_attn, rtol=1e-3, atol=1e-3), "base_attn and ring_attn do not match!"
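

# Hedged sketch (not part of the original file): a multi-process counterpart of
# test() that would exercise ring_attention_dist. It assumes one process per GPU,
# launched e.g. with `torchrun --nproc_per_node=4 test.py`; the name test_dist and
# the launch command are illustrative assumptions, not part of the repository.
def test_dist():
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank())
    # Identical seed on every rank so all ranks build the same q/k/v and each
    # rank's output can be checked against the single-GPU reference.
    torch.manual_seed(0)
    q = torch.randn((32760, 12, 128), dtype=torch.bfloat16, device="cuda")
    k = torch.randn((32760, 12, 128), dtype=torch.bfloat16, device="cuda")
    v = torch.randn((32760, 12, 128), dtype=torch.bfloat16, device="cuda")
    cu_seqlens = torch.tensor([0, 32760], dtype=torch.int32, device="cuda")

    base_attn = base_attention(q=q, k=k, v=v, cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens, lq=32760, lk=32760)
    ring_attn = ring_attention_dist(q=q, k=k, v=v, cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens, lq=32760, lk=32760)
    assert torch.allclose(base_attn, ring_attn, rtol=1e-3, atol=1e-3), "base_attn and ring_attention_dist do not match!"

    dist.destroy_process_group()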


if __name__ == "__main__":
    # dist.init_process_group(backend="nccl")
    test()