# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
import triton
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd, decode_attention_v1, decode_attention_v2

def cdiv(a, b):
    """Return ``a / b`` rounded up, computed with integer floor division.

    Equivalent to ``ceil(a / b)`` for a positive divisor ``b``.
    """
    padded = a + b - 1
    return padded // b


@pytest.mark.parametrize("B", [3, 5])
@pytest.mark.parametrize("L", [1027, 1025])
@pytest.mark.parametrize("H_Q", [32])
@pytest.mark.parametrize("H_KV", [32, 8])
@pytest.mark.parametrize("D_QK", [128, 192, 576])
@pytest.mark.parametrize("D_V", [128, 512])
@pytest.mark.parametrize("CACHE_SIZE", [16384])
@pytest.mark.parametrize("PAGE_SIZE", [1, 16])
def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
    """Compare paged decode-attention kernels against a page-size-1 run.

    ``decode_attention_fwd`` is first run on a flat KV cache (one token per
    slot, indexed through ``req_to_token``) to produce the reference output
    ``o``. The same cache is then viewed as ``PAGE_SIZE``-token pages and
    passed, together with ``req_to_page``, to ``decode_attention_fwd``,
    ``decode_attention_v1`` and ``decode_attention_v2``. Each kernel writes
    into its own freshly zeroed output tensor and must reproduce ``o``.
    """
    assert CACHE_SIZE % PAGE_SIZE == 0
    dtype = torch.bfloat16
    seq_len = L  # This represents the number of tokens already in the sequence
    sm_scale = 1.0 / (D_QK**0.5)
    num_kv_splits = 8

    # Pages needed to hold seq_len tokens, rounded up
    # (e.g. PAGE_SIZE=16, seq_len=1027 -> (1027 + 16 - 1) // 16 == 65).
    num_pages_per_batch = cdiv(seq_len, PAGE_SIZE)

    # Random page ids in [0, CACHE_SIZE // PAGE_SIZE), one per page of each
    # request; shape (B, num_pages_per_batch, 1).
    req_to_page = torch.randint(0,
                                CACHE_SIZE // PAGE_SIZE,
                                (B, num_pages_per_batch, 1),
                                device="cuda")

    # Translate page ids into flat token-slot ids: broadcast each page id
    # from (B, num_pages_per_batch, 1) to (B, num_pages_per_batch, PAGE_SIZE),
    # add the in-page offset, flatten per request and keep seq_len slots.
    req_to_token = req_to_page * PAGE_SIZE
    req_to_token = req_to_token.expand(B, num_pages_per_batch, PAGE_SIZE)
    req_to_token = req_to_token + torch.arange(PAGE_SIZE, device="cuda").view(
        1, 1, -1)
    req_to_token = req_to_token.view(B, -1)
    req_to_token = req_to_token[:, :seq_len].contiguous()

    # q represents the new token being generated, one per batch
    q = torch.randn(B, H_Q, D_QK, dtype=dtype, device="cuda")

    # k_buffer and v_buffer represent all previous tokens.
    # They are allocated flat (page size 1) and re-viewed as pages below.
    k_buffer = torch.randn(CACHE_SIZE, H_KV, D_QK, dtype=dtype, device="cuda")
    v_buffer = torch.randn(CACHE_SIZE, H_KV, D_V, dtype=dtype, device="cuda")

    # Reference output: one D_V-sized vector per (batch, query head).
    # Note this differs from q's shape whenever D_V != D_QK.
    o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")

    b_seq_len = torch.full((B, ), seq_len, device="cuda")

    # Extra inputs for decode_attention_v1: an fp16 scratch buffer of shape
    # (H_Q, total_slots) and exactly B int32 start offsets into it, evenly
    # spaced total_slots // B apart. (torch.arange(0, total, total // B)
    # would yield B + 1 entries whenever total is not divisible by B.)
    # NOTE(review): total_slots is CACHE_SIZE * PAGE_SIZE, i.e. PAGE_SIZE
    # times the number of cache slots; kept as-is to preserve the buffer
    # size previously used -- confirm what decode_attention_v1 requires.
    total_slots = CACHE_SIZE * PAGE_SIZE
    b_start_loc = (torch.arange(B, device="cuda") *
                   (total_slots // B)).to(torch.int32)
    attn_logits_v1 = torch.empty((H_Q, total_slots),
                                 dtype=torch.float16,
                                 device="cuda")

    attn_logits = torch.empty(
        (B, H_Q, num_kv_splits, D_V + 1),
        dtype=torch.float32,
        device="cuda",
    )

    # Call the original implementation (page size 1) to get the reference.
    decode_attention_fwd(
        q,
        k_buffer,
        v_buffer,
        o,
        req_to_token,
        b_seq_len,
        attn_logits,
        num_kv_splits,
        sm_scale,
    )

    # Page size can be larger than 1.
    k_buffer = k_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_QK)
    v_buffer = v_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_V)

    o1 = torch.zeros_like(o)
    decode_attention_fwd(
        q,
        k_buffer,
        v_buffer,
        o1,
        req_to_page,
        b_seq_len,
        attn_logits,
        num_kv_splits,
        sm_scale,
        PAGE_SIZE,
    )
    assert torch.allclose(o, o1)

    # Each kernel below gets its own zeroed output tensor. Reusing o1 (which
    # already matches o) would let a kernel that writes nothing still pass.
    o_v1 = torch.zeros_like(o)
    decode_attention_v1(
        q,
        k_buffer,
        v_buffer,
        o_v1,
        req_to_page,
        b_start_loc,
        b_seq_len,
        attn_logits_v1,
        num_kv_splits,
        sm_scale,
        PAGE_SIZE,
    )
    assert torch.allclose(o, o_v1, atol=1e-2, rtol=1e-2)

    o_v2 = torch.zeros_like(o)
    decode_attention_v2(
        q,
        k_buffer,
        v_buffer,
        o_v2,
        req_to_page,
        b_seq_len,
        attn_logits,
        num_kv_splits,
        sm_scale,
        PAGE_SIZE,
    )
    assert torch.allclose(o, o_v2, atol=1e-2, rtol=1e-2)