test_cross_entropy_apex.py 1.45 KB
Newer Older
Tri Dao's avatar
Tri Dao committed
1
2
3
4
5
6
7
8
import math

import torch
import torch.nn.functional as F
import pytest

from einops import rearrange

9
from flash_attn.losses.cross_entropy_apex import CrossEntropyLossApex
Tri Dao's avatar
Tri Dao committed
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39

# True when a CUDA device is present and its compute-capability major version is >= 8;
# bfloat16 is only added to the dtype sweep below in that case. The is_available() guard
# keeps this module importable on machines without CUDA (where get_device_capability
# raises), so pytest collection of the rest of the suite is not aborted.
is_sm8x = torch.cuda.is_available() and torch.cuda.get_device_capability('cuda')[0] >= 8


@pytest.mark.skipif(not torch.cuda.is_available(), reason='CrossEntropyLossApex test requires CUDA')
@pytest.mark.parametrize('dtype', [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else []))
@pytest.mark.parametrize('inplace_backward', [False, True])
@pytest.mark.parametrize('vocab_size', [50257])
def test_cross_entropy_loss_apex(vocab_size, inplace_backward, dtype):
    """Compare CrossEntropyLossApex with torch.nn.CrossEntropyLoss (forward and backward).

    The reference loss is computed on a float32 upcast of the logits, so reduced-precision
    dtypes are checked with looser tolerances. Ten random targets are set to -100
    (torch.nn.CrossEntropyLoss's default ignore_index), so both implementations must
    agree on how ignored positions are treated.

    Args:
        vocab_size: number of classes, i.e. columns of the logits matrix.
        inplace_backward: passed straight through to CrossEntropyLossApex.
        dtype: dtype of the input logits.
    """
    device = 'cuda'
    # Reduced-precision inputs get looser tolerances than float32.
    rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
    # Seed so logits, targets, ignored positions and the upstream grad are reproducible.
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 128
    x_pt = torch.randn(batch_size * seqlen, vocab_size, device=device, dtype=dtype, requires_grad=True)
    # Separate leaf with identical values so each implementation accumulates its own .grad.
    x = x_pt.detach().clone().requires_grad_()
    y = torch.randint(0, vocab_size, (batch_size * seqlen,), dtype=torch.long, device=device)
    # Mark 10 random positions as ignored to exercise ignore_index handling.
    y[torch.randperm(batch_size * seqlen)[:10]] = -100
    model_pt = torch.nn.CrossEntropyLoss()
    model = CrossEntropyLossApex(inplace_backward=inplace_backward)
    out = model(x, y)
    # Reference in float32; the .float() upcast is differentiable, so x_pt.grad is in `dtype`.
    out_pt = model_pt(x_pt.float(), y)
    # Same |actual - expected| <= atol + rtol * |expected| criterion as torch.allclose,
    # but a failure reports the number of mismatches and the max abs/rel difference.
    torch.testing.assert_close(out, out_pt, rtol=rtol, atol=atol)

    g = torch.randn_like(out)
    out_pt.backward(g)
    out.backward(g)
    torch.testing.assert_close(x.grad, x_pt.grad, rtol=rtol, atol=atol)