test_rotary.py

import torch
import pytest

from flash_attn.layers.rotary import apply_rotary_emb_func, apply_rotary_emb_torch
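

# For reference, a minimal sketch of the rotation these two functions compute,
# assuming the non-interleaved (GPT-NeoX-style) layout in which the first and
# second halves of the rotary channels are paired. The helper name
# `_rotary_reference` is illustrative only and is not used by the test below;
# apply_rotary_emb_torch serves as the actual reference implementation.
def _rotary_reference(x, cos, sin):
    """Rotate the first 2 * cos.shape[-1] channels of x and pass the rest through.

    x: (batch, seqlen, nheads, headdim); cos, sin: (seqlen, rotary_dim // 2).
    """
    rotary_dim = 2 * cos.shape[-1]
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    x1, x2 = x_rot.chunk(2, dim=-1)
    # Add a singleton head dim so cos/sin broadcast over batch and heads
    cos, sin = cos[:, None, :], sin[:, None, :]
    # Standard 2D rotation applied channel-pair-wise:
    #   x1' = x1 * cos - x2 * sin,  x2' = x1 * sin + x2 * cos
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos, x_pass], dim=-1)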


# bfloat16 is only supported on Ampere (SM 8.x) and later GPUs; guard the
# capability query so test collection doesn't crash on CPU-only machines
is_sm8x = torch.cuda.is_available() and torch.cuda.get_device_capability('cuda') >= (8, 0)

@pytest.mark.parametrize('dtype', ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize('rotary_fraction', [1.0, 0.5])
@pytest.mark.parametrize('inplace', [False, True])
def test_rotary_single_tensor(inplace, rotary_fraction, dtype):
    rtol = 1e-3
    batch_size = 32
    nheads = 4
    seqlen = 217
    headdim = 128
    x = torch.randn(batch_size, seqlen, nheads, headdim, dtype=dtype, device='cuda',
                    requires_grad=True)
    # Independent leaf copy of x for the pure-PyTorch reference path
    x_pt = x.detach().clone().requires_grad_()
    rotary_dim = int(rotary_fraction * headdim)
    assert rotary_dim % 2 == 0  # rotary rotates channel pairs, so rotary_dim must be even
    # Random rotation angle for each (position, channel pair)
    angle = torch.randn(seqlen, rotary_dim // 2, device='cuda')
    cos = torch.cos(angle).to(dtype=dtype)
    sin = torch.sin(angle).to(dtype=dtype)
    out = apply_rotary_emb_func(x, cos, sin, inplace=inplace)
    out_pt = apply_rotary_emb_torch(x_pt, cos, sin)
    # Estimate the numerical error floor: adding and subtracting a constant is a
    # mathematical no-op, so any residual is pure rounding error at this dtype
    atol = ((out + 0.3 - 0.3) - out).abs().max().item()
    assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
    g = torch.randn_like(out)
    g_pt = g.clone()  # the inplace backward may modify g, so keep a copy for the reference
    out.backward(g)
    out_pt.backward(g_pt)
    # Same numerical error floor estimate as above, now for the gradient
    atol = ((x_pt.grad + 0.3 - 0.3) - x_pt.grad).abs().max().item()
    assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=2 * atol)
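

# To run this test (requires a CUDA GPU and the flash-attn extension built), e.g.:
#   pytest -q test_rotary.py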