import pytest
import torch
import torch.nn.functional as F

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.layer.dist_crossentropy import applyDistCrossEntropy
from colossalai.testing import rerun_if_address_is_in_use, spawn

# Launch config: 1-D tensor parallelism of size 2; data and pipeline parallel disabled.
CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),)


def check_dist_crossentropy(rank, world_size, port, ignore_index):
    """Worker body: check the distributed cross-entropy against ``F.cross_entropy``.

    Each rank receives a shard of the logits along the last (class) dimension;
    the distributed loss must match the single-process reference loss computed
    on the full, unsharded logits.

    Args:
        rank: rank of this worker process.
        world_size: total number of worker processes (= tensor-parallel size).
        port: free port used by the colossalai launcher.
        ignore_index: label value excluded from the loss.
    """
    disable_existing_loggers()
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, port=port, host='localhost', backend='nccl')

    # prepare data: logits of shape (batch=2, seq=4, num_classes=8)
    pred = torch.randn(2, 4, 8, requires_grad=True)
    labels = torch.randint(8, (2, 4))
    # mark one position with the ignore index so the masking path is exercised
    labels[0, -1] = ignore_index

    # reference loss on the full logits; pass ignore_index explicitly instead
    # of relying on it coinciding with F.cross_entropy's default of -100
    org_pred = pred.view(-1, 8)
    org_labels = labels.view(-1)
    org_loss = F.cross_entropy(org_pred, org_labels, ignore_index=ignore_index)

    # each rank holds its chunk of the class dimension
    dist_pred = pred.chunk(world_size, -1)[rank]
    dist_loss = applyDistCrossEntropy(dist_pred.to('cuda'), labels.to('cuda'), ignore_index=ignore_index)

    # move the distributed loss back to CPU: torch.allclose raises on
    # mixed-device operands (org_loss is on CPU, dist_loss on CUDA)
    assert torch.allclose(org_loss, dist_loss.cpu(),
                          atol=1e-5), f"dist cross entropy loss is not equal to origin loss\n{org_loss}\n{dist_loss}"


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_dist_crossentropy():
    """Spawn two tensor-parallel workers and compare the distributed
    cross-entropy loss with the single-process reference."""
    spawn(check_dist_crossentropy, 2, ignore_index=-100)


# Allow running this test directly, without invoking pytest.
if __name__ == '__main__':
    test_dist_crossentropy()