test_rotary.py
import math

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange

from flash_attn.layers.rotary import apply_rotary_emb_func, apply_rotary_emb_torch

is_sm8x = torch.cuda.get_device_capability("cuda") >= (8, 0)


@pytest.mark.parametrize(
    "dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
)
# @pytest.mark.parametrize('dtype', ([torch.float16]))
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
# @pytest.mark.parametrize('rotary_fraction', [0.5])
@pytest.mark.parametrize("inplace", [False, True])
# @pytest.mark.parametrize('inplace', [False])
def test_rotary_single_tensor(inplace, rotary_fraction, dtype):
    rtol = 1e-3
    batch_size = 32
    nheads = 4
    seqlen = 217
    headdim = 128
    x = torch.randn(
        batch_size, seqlen, nheads, headdim, dtype=dtype, device="cuda", requires_grad=True
    )
    x_pt = x.detach().clone().requires_grad_()
    rotary_dim = int(rotary_fraction * headdim)
    assert rotary_dim % 2 == 0
    angle = torch.randn(seqlen, rotary_dim // 2, device="cuda")
    cos = torch.cos(angle).to(dtype=dtype)
    sin = torch.sin(angle).to(dtype=dtype)
    out = apply_rotary_emb_func(x, cos, sin, inplace)
    out_pt = apply_rotary_emb_torch(x_pt, cos, sin)
    # Estimate the rounding error of the target dtype: in exact arithmetic
    # (out + 0.3 - 0.3) equals out, so the residual gives the error scale to use
    # as the absolute tolerance.
    atol = ((out + 0.3 - 0.3) - out).abs().max().item()
    assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
    g = torch.randn_like(out)
    g_pt = g.clone()  # If inplace=True, we might modify the gradient inplace
    out.backward(g)
    out_pt.backward(g_pt)
    atol = ((x_pt.grad + 0.3 - 0.3) - x_pt.grad).abs().max().item()
    assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=2 * atol)
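

# For context, a minimal sketch of the (non-interleaved) rotary embedding that
# apply_rotary_emb_torch is assumed to compute: the first rotary_dim channels of each
# head are split into halves (x1, x2) and rotated per position by (cos, sin); any
# remaining channels pass through unchanged. This is an illustrative assumption for
# readers of the test, not the library implementation itself.
def _rotary_reference_sketch(x, cos, sin):
    # x: (batch, seqlen, nheads, headdim); cos, sin: (seqlen, rotary_dim // 2)
    rotary_dim = 2 * cos.shape[-1]
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    x1, x2 = x_rot.chunk(2, dim=-1)
    # Broadcast cos/sin over the batch and head dimensions
    cos = rearrange(cos, "s d -> s 1 d")
    sin = rearrange(sin, "s d -> s 1 d")
    rotated = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
    return torch.cat([rotated, x_pass], dim=-1)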