test_flex_attention.py 9.65 KB
Newer Older
cmx's avatar
cmx committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
import pytest
import torch
import torch.nn.functional as F

from test.utils import assert_verbose_allclose
from test.utils import set_seed
from test.utils import supports_bfloat16
from torch.nn.attention.flex_attention import create_block_mask
from torch.nn.attention.flex_attention import create_mask
from torch.nn.attention.flex_attention import flex_attention

from liger_kernel.utils import infer_device


def causal_mask(b, h, q_idx, kv_idx):
    """mask_mod that lets a query see only keys at or before its own position.

    ``b`` and ``h`` (batch and head indices) are part of the mask_mod
    signature but are not used here.
    """
    return kv_idx <= q_idx


def prefix_mask(b, h, q_idx, kv_idx, rejected_index, chosen_index):
    """Causal mask_mod for a shared-prefix sequence laid out as prefix | chosen | rejected.

    For batch ``b``, chosen tokens occupy ``[chosen_index[b], rejected_index[b])``
    and rejected tokens start at ``rejected_index[b]``. A query at or after the
    rejected start is blocked from attending to any chosen key; every other
    (query, key) pair follows the ordinary causal rule.

    Args:
        b, h: batch and head indices (``h`` is unused).
        q_idx, kv_idx: query and key positions.
        rejected_index: per-batch start offset of the rejected segment, indexed by ``b``.
        chosen_index: per-batch start offset of the chosen segment, indexed by ``b``.
    """
    rejected_start = rejected_index[b]
    chosen_start = chosen_index[b]

    query_is_rejected = q_idx >= rejected_start
    key_is_chosen = (chosen_start <= kv_idx) & (kv_idx < rejected_start)
    is_causal = q_idx >= kv_idx

    return ~(query_is_rejected & key_is_chosen) & is_causal


# Resolve the accelerator once at import time. Note: the helpers and tests in
# this file call infer_device() themselves (or take a `device` argument) rather
# than reading this module-level name.
device = infer_device()
# Seed RNGs once at import time (presumably torch/python/numpy — see
# test.utils.set_seed). The input-generating helpers additionally reseed with
# torch.manual_seed(0) before drawing their tensors.
set_seed(42)


def _test_correctness_flex(B, H, S, D, mask_func, dtype, atol, rtol, device=infer_device()):
    """
    Compare FlexAttention (block mask) against SDPA (equivalent dense mask).

    Both the forward outputs and the gradients w.r.t. query, key and value are
    checked. Inputs are drawn right after ``torch.manual_seed(0)`` so the
    comparison is reproducible.

    Parameters:
        B (int): Batch size
        H (int): Number of attention heads
        S (int): Sequence length (shared by queries and keys)
        D (int): Hidden dimension per head
        mask_func: mask_mod ``(b, h, q_idx, kv_idx)`` used to build both masks
        dtype: Data type of the generated inputs
        atol (float): Absolute tolerance for comparison
        rtol (float): Relative tolerance for comparison
        device: Device the inputs and masks are created on
    """
    torch.manual_seed(0)

    shape = (B, H, S, D)
    # Random q/k/v (the attention-head inputs after projection), drawn in
    # q, k, v order, as autograd leaves for the SDPA reference path.
    ref_qkv = [torch.randn(*shape, device=device, dtype=dtype, requires_grad=True) for _ in range(3)]
    # Independent leaf copies for the flex path so each path accumulates its own .grad.
    flex_qkv = [tensor.clone().detach().requires_grad_(True) for tensor in ref_qkv]

    block_mask = create_block_mask(mask_func, B, H, S, S, device=device)  # block-sparse mask for flex_attention
    dense_mask = create_mask(mask_func, B, H, S, S, device=device)  # dense mask for SDPA

    # For a purely causal mask, SDPA could instead be called with `is_causal=True`.
    reference = F.scaled_dot_product_attention(*ref_qkv, attn_mask=dense_mask)
    candidate = flex_attention(*flex_qkv, block_mask=block_mask)

    # Forward outputs must agree.
    assert_verbose_allclose(candidate, reference, atol=atol, rtol=rtol)

    upstream = torch.ones_like(reference)
    reference.backward(upstream)
    candidate.backward(upstream)

    # Gradients must agree for query, key and value, in that order.
    for flex_leaf, ref_leaf in zip(flex_qkv, ref_qkv):
        assert_verbose_allclose(flex_leaf.grad, ref_leaf.grad, atol=atol, rtol=rtol)


def _is_flex_attention_supported():
    """Return True only when the inferred device is CUDA, the one backend these tests run on."""
    return infer_device() == "cuda"


@pytest.mark.skipif(not _is_flex_attention_supported(), reason="FlexAttention is only supported on CUDA devices")
@pytest.mark.parametrize(
    "B, H, S, D",
    [
        (2, 8, 1024, 32),
        (3, 12, 2048, 64),
    ],
)
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        pytest.param(
            torch.bfloat16,
            3e-2,
            5e-1,
            marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
        ),
        (torch.float16, 1e-2, 5e-3),
        (torch.float32, 1e-3, 5e-4),
    ],
)
def test_correctness_flex(B, H, S, D, dtype, atol, rtol):
    """Check FlexAttention against SDPA for a causal mask and a random prefix-sharing mask.

    The skip reason states CUDA only, matching ``_is_flex_attention_supported``,
    which accepts nothing but ``"cuda"``.
    """
    _test_correctness_flex(B, H, S, D, causal_mask, dtype, atol, rtol)

    # Roughly generate custom per-batch start offsets: chosen tokens start in
    # [0, S // 2) and rejected tokens start in [S // 2, S).
    chosen_index = torch.randint(0, S // 2, (B,), device=infer_device())
    rejected_index = torch.randint(S // 2, S, (B,), device=infer_device())

    def wrapped_prefix_mask(b, h, q_idx, kv_idx):
        # Bind the per-batch offsets so the mask has the same 4-argument
        # signature as `causal_mask`.
        return prefix_mask(b, h, q_idx, kv_idx, rejected_index, chosen_index)

    _test_correctness_flex(B, H, S, D, wrapped_prefix_mask, dtype, atol, rtol)


def _test_correctness_prefix(
    B=2,
    H=8,
    P=512,
    C=256,
    R=256,
    D=32,
    dtype=torch.float32,
    atol=1e-3,
    rtol=5e-4,
    device=infer_device(),
):
    """
    Test that prefix-sharing attention matches separate computations (i.e. two separate causal-masked attentions, prefix+chosen and prefix+rejected).
    The mental model is:

    A. prefix + chosen
    P
    P P
    P P P
    P P P C
    P P P C C
    P P P C C C

    B. prefix + rejected
    P
    P P
    P P P
    P P P R
    P P P R R
    P P P R R R

    C. shared prefix + chosen + rejected
    P
    P P
    P P P
    P P P C
    P P P C C
    P P P C C C
    P P P       R
    P P P       R R
    P P P       R R R


    We compare rows of the attention OUTPUT (shape (B, H, seq_len, D)), i.e. we
    slice the sequence axis only, to ensure attention value equivalence:
    1. prefix rows of shared attn (upper of C.) == prefix rows of chosen attn (upper of A.)
    2. prefix rows of shared attn (upper of C.) == prefix rows of rejected attn (upper of B.)
    P       P
    P P   = P P
    P P P   P P P

    3. chosen rows of shared attn (middle right of C.) == chosen rows of prefix+chosen attn (lower right of A.)
    C       C
    C C   = C C
    C C C   C C C

    4. rejected rows of shared attn (lower right of C.) == rejected rows of prefix+rejected attn (lower right of B.)
    R       R
    R R   = R R
    R R R   R R R

    Args:
        B: batch size
        H: number of heads
        P: prefix length
        C: chosen response length
        R: rejected response length
        D: hidden dimension per head
        dtype: data type of the generated inputs
        atol: absolute tolerance for comparison
        rtol: relative tolerance for comparison
        device: device the inputs and masks are created on
    """
    torch.manual_seed(0)

    # Total sequence length for shared version
    S = P + C + R

    # Initialize input tensors, i.e. the tensors after q, k, and v projections of hidden states (attention head input)
    query = torch.randn(B, H, S, D, device=device, dtype=dtype)
    key = torch.randn(B, H, S, D, device=device, dtype=dtype)
    value = torch.randn(B, H, S, D, device=device, dtype=dtype)

    # Split tensors along the sequence axis (dim 2) for separate computation
    query_prefix = query[:, :, :P, :]
    key_prefix = key[:, :, :P, :]
    value_prefix = value[:, :, :P, :]

    query_chosen = query[:, :, P : P + C, :]
    key_chosen = key[:, :, P : P + C, :]
    value_chosen = value[:, :, P : P + C, :]

    query_rejected = query[:, :, P + C :, :]
    key_rejected = key[:, :, P + C :, :]
    value_rejected = value[:, :, P + C :, :]

    # `prefix_mask` takes per-batch START offsets: chosen tokens occupy
    # [chosen_index, rejected_index) and rejected tokens start at rejected_index.
    # Passing the chosen END (P + C) and S here instead would never trigger the
    # rejected->chosen exclusion, degenerating the mask to plain causal.
    chosen_index = torch.full((B,), P, device=device)
    rejected_index = torch.full((B,), P + C, device=device)

    def wrapped_prefix_mask(b, h, q_idx, kv_idx):
        return prefix_mask(b, h, q_idx, kv_idx, rejected_index, chosen_index)

    block_mask = create_block_mask(wrapped_prefix_mask, B, H, S, S, device=device)
    shared_out = flex_attention(query, key, value, block_mask=block_mask)

    # Compute attention for prefix + chosen separately (module-level causal_mask)
    PC = P + C
    query_pc = torch.cat([query_prefix, query_chosen], dim=2)
    key_pc = torch.cat([key_prefix, key_chosen], dim=2)
    value_pc = torch.cat([value_prefix, value_chosen], dim=2)

    pc_block_mask = create_block_mask(causal_mask, B, H, PC, PC, device=device)
    pc_out = flex_attention(query_pc, key_pc, value_pc, block_mask=pc_block_mask)

    # Compute attention for prefix + rejected separately
    PR = P + R
    query_pr = torch.cat([query_prefix, query_rejected], dim=2)
    key_pr = torch.cat([key_prefix, key_rejected], dim=2)
    value_pr = torch.cat([value_prefix, value_rejected], dim=2)

    pr_block_mask = create_block_mask(causal_mask, B, H, PR, PR, device=device)
    pr_out = flex_attention(query_pr, key_pr, value_pr, block_mask=pr_block_mask)

    # The outputs are (B, H, seq_len, D) like the inputs, so slice only the
    # sequence axis. The last axis is the head dimension, not a key axis;
    # slicing it with sequence offsets yields empty tensors that trivially "match".
    shared_prefix = shared_out[:, :, :P, :]
    shared_chosen = shared_out[:, :, P : P + C, :]
    shared_rejected = shared_out[:, :, P + C :, :]

    separate_prefix_c = pc_out[:, :, :P, :]
    separate_chosen = pc_out[:, :, P:, :]
    separate_prefix_r = pr_out[:, :, :P, :]
    separate_rejected = pr_out[:, :, P:, :]

    def _assert_match(actual, expected, msg):
        # torch.allclose broadcasts, so compare shapes explicitly first.
        assert actual.shape == expected.shape, f"{msg} (shape {tuple(actual.shape)} vs {tuple(expected.shape)})"
        assert torch.allclose(actual, expected, atol=atol, rtol=rtol), msg

    # Verify prefix outputs are identical
    _assert_match(
        shared_prefix,
        separate_prefix_c,
        "Prefix attention from shared computation doesn't match prefix+chosen computation",
    )
    _assert_match(
        shared_prefix,
        separate_prefix_r,
        "Prefix attention from shared computation doesn't match prefix+rejected computation",
    )

    # Verify chosen and rejected outputs
    _assert_match(
        shared_chosen,
        separate_chosen,
        "Chosen response attention doesn't match between shared and separate computation",
    )
    _assert_match(
        shared_rejected,
        separate_rejected,
        "Rejected response attention doesn't match between shared and separate computation",
    )


@pytest.mark.skipif(not _is_flex_attention_supported(), reason="FlexAttention is only supported on CUDA devices")
@pytest.mark.parametrize(
    "B, H, P, C, R, D",
    [
        (2, 8, 512, 256, 256, 32),
        (3, 12, 1024, 512, 512, 64),
    ],
)
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        pytest.param(
            torch.bfloat16,
            3e-2,
            5e-1,
            marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
        ),
        (torch.float16, 1e-2, 5e-3),
        (torch.float32, 1e-3, 5e-4),
    ],
)
def test_correctness_prefix(B, H, P, C, R, D, dtype, atol, rtol):
    """Parametrized shared-prefix vs. separate-attention check over sizes and dtypes.

    The skip reason states CUDA only, matching ``_is_flex_attention_supported``,
    which accepts nothing but ``"cuda"``.
    """
    _test_correctness_prefix(B=B, H=H, P=P, C=C, R=R, D=D, dtype=dtype, atol=atol, rtol=rtol)