import logging

import torch
from torch.testing._internal import common_utils

from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import mappings
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase

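# Silence verbose torch/apex logging so the test output stays readable.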
logging.getLogger("torch").setLevel(logging.WARNING)
logging.getLogger("apex").setLevel(logging.WARNING)


class MappingTestBase:
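    """Tests for the collectives in apex.transformer.tensor_parallel.mappings.

    Each test iterates over every tensor model parallel world size that evenly
    divides the overall world size; concrete subclasses supply the distributed
    backend via their test base class.
    """
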
    def test_reduce(self):
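        """_reduce should sum-reduce the tensor across the tensor model parallel group.

        Every element starts at 50, so after the reduction each element is
        expected to equal 50 * tensor_model_parallel_world_size.
        """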
        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            if self.world_size % tensor_model_parallel_world_size > 0:
                continue
            with self.subTest(
                tensor_model_parallel_world_size=tensor_model_parallel_world_size
            ):
                parallel_state.initialize_model_parallel(
                    tensor_model_parallel_size_=tensor_model_parallel_world_size
                )
                t = torch.full((10, 10, 10, 10), 50, device=f"cuda:{self.rank}")
                expected = torch.full(
                    (10, 10, 10, 10),
                    50 * tensor_model_parallel_world_size,
                    device=f"cuda:{self.rank}",
                )
                self.assertTrue(torch.equal(mappings._reduce(t), expected))
                parallel_state.destroy_model_parallel()

    def test_split(self):
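        """_split_along_last_dim should hand each rank its slice of the last dimension.

        The input is built by concatenating one known (10, 1) tensor per rank,
        so the output on this rank must equal the chunk at its tensor model
        parallel rank.
        """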
        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            if self.world_size % tensor_model_parallel_world_size > 0:
                continue
            with self.subTest(
                tensor_model_parallel_world_size=tensor_model_parallel_world_size
            ):
                parallel_state.initialize_model_parallel(
                    tensor_model_parallel_size_=tensor_model_parallel_world_size
                )

                tensors = [
                    torch.randn(10, 1)
                    for _ in range(tensor_model_parallel_world_size)
                ]
                x = torch.cat(tensors, 1)
                out = mappings._split_along_last_dim(x)
                self.assertTrue(
                    torch.equal(
                        out, tensors[parallel_state.get_tensor_model_parallel_rank()]
                    )
                )
                parallel_state.destroy_model_parallel()

    def test_gather(self):
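        """_gather_along_last_dim should concatenate every rank's tensor along the last dimension.

        Each rank contributes a single-element tensor holding its tensor model
        parallel rank, so the gathered result is [0, 1, ..., group_size - 1].
        """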
        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            if self.world_size % tensor_model_parallel_world_size > 0:
                continue
            with self.subTest(
                tensor_model_parallel_world_size=tensor_model_parallel_world_size
            ):
                parallel_state.initialize_model_parallel(
                    tensor_model_parallel_size_=tensor_model_parallel_world_size
                )
                device = f"cuda:{self.rank}"
                gathered = mappings._gather_along_last_dim(
                    torch.tensor(
                        [parallel_state.get_tensor_model_parallel_rank()], device=device
                    )
                )
                expected = torch.tensor(
                    [rank for rank in range(tensor_model_parallel_world_size)],
                    device=device,
                )
                self.assertTrue(torch.equal(gathered, expected))
                parallel_state.destroy_model_parallel()


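# Concrete test cases: run the same mapping tests against the NCCL and UCC backends.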
class NcclMappingTest(MappingTestBase, NcclDistributedTestBase):
    pass


class UccMappingTest(MappingTestBase, UccDistributedTestBase):
    pass


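# `common_utils.run_tests()` runs the test cases defined above, e.g. when this
# file is executed directly with `python test_mapping.py`.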
if __name__ == "__main__":
    common_utils.run_tests()