# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/losses/test_cross_entropy_parallel.py

import math

import torch
import torch.nn.functional as F
import pytest

from apex.transformer import parallel_state
from apex.transformer import tensor_parallel

from flash_attn.losses.cross_entropy_parallel import CrossEntropyLossParallel

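# bfloat16 is only exercised on GPUs with compute capability >= 8.x (Ampere or newer),
# where bf16 tensor operations are supported.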
is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8


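# Check that the vocab-parallel cross entropy matches torch.nn.CrossEntropyLoss
# computed on the full (unsharded) logits, in both the forward and backward pass.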
@pytest.mark.parametrize('dtype', [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize('inplace_backward', [False, True])
# @pytest.mark.parametrize('inplace_backward', [False])
@pytest.mark.parametrize('vocab_size', [50264])
@pytest.mark.parametrize('world_size', [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
def test_cross_entropy_loss_apex(vocab_size, world_size, inplace_backward, dtype):
    assert vocab_size % world_size == 0
    rtol, atol = ((1e-5, 1e-6) if dtype == torch.float32
                  else ((1e-3, 1e-4) if dtype == torch.float16 else (1e-2, 3e-3)))
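    # Initialize the NCCL process group once (torchrun provides the env:// variables);
    # later parametrizations reuse the same group.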
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
    partition_vocab_size = vocab_size // world_size
    device = f'cuda:{torch.distributed.get_rank()}'
    assert world_size <= torch.distributed.get_world_size()
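    # Build a tensor-model-parallel group of `world_size` ranks; the vocab dimension
    # of the logits is partitioned across this group.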
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 128
    x_pt = (torch.randn(batch_size * seqlen, vocab_size, device=device,
                        dtype=dtype) * 10).requires_grad_()
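    # Shard the logits along the vocab (last) dimension; each rank keeps its slice
    # as a leaf tensor so its gradient can be compared against the reference.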
    x = tensor_parallel.scatter_to_tensor_model_parallel_region(x_pt).detach().clone().requires_grad_()
    y = torch.randint(0, vocab_size, (batch_size * seqlen,), dtype=torch.long, device=device)
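    # Mark a few random positions with -100, the default ignore_index of
    # CrossEntropyLoss, so they are excluded from the loss.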
    y[torch.randperm(batch_size * seqlen)[:10]] = -100
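    # Reference: standard CrossEntropyLoss on the full logits, upcast to fp32.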
    model_pt = torch.nn.CrossEntropyLoss(reduction='none')
    model = CrossEntropyLossParallel(reduction='none', inplace_backward=inplace_backward)
    out = model(x, y)
    out_pt = model_pt(x_pt.float(), y)
    assert torch.allclose(out, out_pt, rtol=1e-5, atol=1e-6)

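    # Backprop a random upstream gradient (reduction='none' gives a per-token loss).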
    g = torch.randn_like(out)
    out_pt.backward(g)
    out.backward(g)
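    # Each rank only holds the gradient for its own vocab partition, so compare
    # against the matching slice of the reference gradient.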
    assert torch.allclose(x.grad, x_pt.grad[:, (rank * partition_vocab_size):(rank + 1) * partition_vocab_size], rtol=rtol, atol=atol)

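    # Tear down the model-parallel state so the next parametrization can
    # re-initialize with a different world_size.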
    parallel_state.destroy_model_parallel()