# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/losses/test_cross_entropy_parallel.py

import math

import torch
import torch.nn.functional as F
import pytest

from apex.transformer import parallel_state
from apex.transformer import tensor_parallel

from flash_attn.losses.cross_entropy import CrossEntropyLoss

# True when the current CUDA device has compute capability major version >= 8;
# used below to decide whether bfloat16 is added to the dtype parametrization.
# The is_available() guard keeps test collection from crashing on CUDA-less hosts.
is_sm8x = torch.cuda.is_available() and torch.cuda.get_device_capability('cuda')[0] >= 8


@pytest.mark.parametrize('dtype', [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('inplace_backward', [False, True])
# @pytest.mark.parametrize('inplace_backward', [False])
@pytest.mark.parametrize('smoothing', [0.0, 0.9])
# @pytest.mark.parametrize('smoothing', [0.9])
@pytest.mark.parametrize('vocab_size', [50264])
@pytest.mark.parametrize('world_size', [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
def test_cross_entropy_loss_parallel(vocab_size, world_size, smoothing, inplace_backward, dtype):
    """Compare vocab-parallel CrossEntropyLoss with torch.nn.CrossEntropyLoss.

    The full logits are built identically on every rank and split along the
    vocab (last) dimension across ``world_size`` tensor-parallel ranks. Each
    rank feeds its shard to flash_attn's CrossEntropyLoss together with the
    tensor-parallel process group. Per-token losses must match the
    single-device reference. Each rank's logit gradient must match its vocab
    slice of the reference gradient.

    Skipped when the job was launched with fewer than ``world_size`` processes.
    """
    assert vocab_size % world_size == 0
    # Gradient tolerances loosen with the precision of the input logits.
    rtol, atol = ((1e-5, 1e-6) if dtype == torch.float32
                  else ((1e-3, 1e-4) if dtype == torch.float16 else (1e-2, 3e-3)))
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
    # Skip (rather than fail) tensor-parallel sizes larger than the launch.
    # Every rank sees the same world size, so all ranks skip together.
    if world_size > torch.distributed.get_world_size():
        pytest.skip(f'world_size={world_size} needs at least that many processes, '
                    f'got {torch.distributed.get_world_size()}')
    partition_vocab_size = vocab_size // world_size
    # NOTE(review): maps the global rank straight to a GPU index, which assumes
    # a single-node launch like the torchrun command in the file header.
    device = f'cuda:{torch.distributed.get_rank()}'
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    # Tear the model-parallel groups down even when an assertion fails, so the
    # next parametrization does not start with groups still initialized.
    try:
        rank = parallel_state.get_tensor_model_parallel_rank()
        # Same seed on every rank, so every rank builds the same full logits and labels.
        torch.random.manual_seed(0)
        batch_size = 8
        seqlen = 128
        x_pt = (torch.randn(batch_size * seqlen, vocab_size, device=device,
                            dtype=dtype) * 10).requires_grad_()
        # This rank's vocab shard, re-made as an independent leaf so it gets its own .grad.
        x = tensor_parallel.scatter_to_tensor_model_parallel_region(x_pt).detach().clone().requires_grad_()
        y = torch.randint(0, vocab_size, (batch_size * seqlen,), dtype=torch.long, device=device)
        # Mark 10 random positions with -100, which is torch.nn.CrossEntropyLoss's
        # default ignore_index, to exercise ignored targets.
        y[torch.randperm(batch_size * seqlen)[:10]] = -100

        model_pt = torch.nn.CrossEntropyLoss(label_smoothing=smoothing, reduction='none')
        model = CrossEntropyLoss(label_smoothing=smoothing, reduction='none',
                                 inplace_backward=inplace_backward,
                                 process_group=parallel_state.get_tensor_model_parallel_group())
        out = model(x, y)
        out_pt = model_pt(x_pt.float(), y)
        # The reference runs on fp32-upcast logits, so the forward check uses a
        # fixed fp32 tolerance for every dtype.
        # NOTE(review): presumes CrossEntropyLoss also computes the loss in fp32 -- confirm.
        assert torch.allclose(out, out_pt, rtol=1e-5, atol=1e-6)

        g = torch.randn_like(out)
        out_pt.backward(g)
        out.backward(g)
        vocab_slice = slice(rank * partition_vocab_size, (rank + 1) * partition_vocab_size)
        assert torch.allclose(x.grad, x_pt.grad[:, vocab_slice], rtol=rtol, atol=atol)
    finally:
        parallel_state.destroy_model_parallel()