# test_mappings.py: unit tests for the autograd mapping functions in
# megatron.core.tensor_parallel.mappings.
import torch

from megatron.core.tensor_parallel import mappings
from tests.unit_tests.test_utilities import Utils


def test_CopyToModelParallelRegion():
    """Check _CopyToModelParallelRegion: identity forward, summed (all-reduce) backward.

    The expected values assume an 8-rank launch in which ranks 0-3 and 4-7 form
    the two tensor-parallel groups created by ``initialize_model_parallel(4, 2)``
    (0+1+2+3 == 6 and 4+5+6+7 == 22).
    """
    Utils.initialize_model_parallel(4, 2)

    # Backward: every rank contributes its rank id; the result is the sum over
    # its tensor-parallel group.
    input_data = torch.ones(1).cuda() * Utils.rank
    output_data = mappings._CopyToModelParallelRegion.backward(None, input_data)
    result = torch.ones(1).cuda()
    result = result * 22 if Utils.rank >= 4 else result * 6
    assert torch.equal(output_data, result)

    # Forward / symbolic must be the identity. Use a fresh input and an
    # independent expected tensor: the backward call above may have reduced
    # `input_data` in place, and comparing a tensor with itself (or with an
    # alias returned by the op) would pass vacuously.
    input_data = torch.ones(1).cuda() * Utils.rank
    expected = torch.ones(1).cuda() * Utils.rank
    assert torch.equal(mappings.copy_to_tensor_model_parallel_region(input_data), expected)
    assert torch.equal(mappings._CopyToModelParallelRegion.symbolic(None, input_data), expected)

    Utils.destroy_model_parallel()


def test_ReduceFromModelParallelRegion():
    """Check _ReduceFromModelParallelRegion: summed (all-reduce) forward, identity backward.

    The expected values assume an 8-rank launch in which ranks 0-3 and 4-7 form
    the two tensor-parallel groups created by ``initialize_model_parallel(4, 2)``
    (0+1+2+3 == 6 and 4+5+6+7 == 22).
    """
    Utils.initialize_model_parallel(4, 2)
    result = torch.ones(1).cuda()
    result = result * 22 if Utils.rank >= 4 else result * 6

    # Symbolic / forward: each rank's value is summed over its group.
    input_data = torch.ones(1).cuda() * Utils.rank
    output_data = mappings._ReduceFromModelParallelRegion.symbolic(None, input_data)
    assert torch.equal(output_data, result)

    # Rebuild the input: the reduction above may have modified it in place.
    input_data = torch.ones(1).cuda() * Utils.rank
    assert torch.equal(mappings.reduce_from_tensor_model_parallel_region(input_data), result)

    # Backward must pass the gradient through unchanged. Compare against an
    # independent tensor rather than the (possibly already reduced) input, so an
    # in-place reduction in backward cannot make the check pass vacuously.
    grad_output = torch.ones(1).cuda() * Utils.rank
    expected_grad = torch.ones(1).cuda() * Utils.rank
    assert torch.equal(
        mappings._ReduceFromModelParallelRegion.backward(None, grad_output), expected_grad
    )

    Utils.destroy_model_parallel()


def test_ScatterToModelParallelRegion():
    """Check _ScatterToModelParallelRegion: split last dim forward, gather backward.

    Assumes an 8-rank launch in which ranks 0-3 and 4-7 form the two
    tensor-parallel groups of size 4 created by ``initialize_model_parallel(4, 2)``.
    """
    Utils.initialize_model_parallel(4, 2)
    tp_rank = Utils.rank % 4  # rank within the tensor-parallel group

    # Forward / symbolic: an (8, 4) tensor split along its last dimension leaves
    # each rank with its own column, kept 2-D as (8, 1).
    input_data = torch.rand((8, 4)).cuda()
    output_data = mappings.scatter_to_tensor_model_parallel_region(input_data)
    assert torch.equal(output_data, input_data[:, tp_rank].reshape((8, 1)))
    output_data = mappings._ScatterToModelParallelRegion.symbolic(None, input_data)
    assert torch.equal(output_data, input_data[:, tp_rank].reshape((8, 1)))

    # Backward: the per-rank gradients are concatenated in group-rank order.
    input_data = torch.ones(8).cuda() * Utils.rank
    actual_output_data = mappings._ScatterToModelParallelRegion.backward(None, input_data)
    expected_output = torch.cat(
        (torch.ones(8) * 0, torch.ones(8) * 1, torch.ones(8) * 2, torch.ones(8) * 3)
    ).cuda()
    if Utils.rank >= 4:
        expected_output = expected_output + 4
    assert torch.equal(actual_output_data, expected_output)

    Utils.destroy_model_parallel()


def test_GatherFromModelParallelRegion():
    """Check _GatherFromModelParallelRegion: gather forward, split last dim backward.

    Assumes an 8-rank launch in which ranks 0-3 and 4-7 form the two
    tensor-parallel groups of size 4 created by ``initialize_model_parallel(4, 2)``.
    """
    Utils.initialize_model_parallel(4, 2)
    tp_rank = Utils.rank % 4  # rank within the tensor-parallel group

    # Backward: an (8, 4) gradient split along its last dimension leaves each
    # rank with its own column, kept 2-D as (8, 1).
    input_data = torch.rand((8, 4)).cuda()
    output_data = mappings._GatherFromModelParallelRegion.backward(None, input_data)
    assert torch.equal(output_data, input_data[:, tp_rank].reshape((8, 1)))

    # Forward / symbolic: per-rank tensors are concatenated in group-rank order.
    input_data = torch.ones(8).cuda() * Utils.rank
    actual_output_data = mappings.gather_from_tensor_model_parallel_region(input_data)
    expected_output = torch.cat(
        (torch.ones(8) * 0, torch.ones(8) * 1, torch.ones(8) * 2, torch.ones(8) * 3)
    ).cuda()
    if Utils.rank >= 4:
        expected_output = expected_output + 4
    assert torch.equal(actual_output_data, expected_output)
    assert torch.equal(
        mappings._GatherFromModelParallelRegion.symbolic(None, input_data), expected_output
    )

    Utils.destroy_model_parallel()


def test_ScatterToSequenceParallelRegion():
    """Check _ScatterToSequenceParallelRegion: split dim 0 forward, gather dim 0 backward.

    Assumes an 8-rank launch in which ranks 0-3 and 4-7 form the two
    tensor-parallel groups of size 4 created by ``initialize_model_parallel(4, 2)``.
    """
    Utils.initialize_model_parallel(4, 2)
    # An (8, 4) tensor split four ways along dim 0 gives each group rank two rows.
    start_row = (Utils.rank % 4) * 2

    input_data = torch.rand((8, 4)).cuda()
    output_data = mappings._ScatterToSequenceParallelRegion.symbolic(None, input_data)
    assert torch.equal(output_data, input_data[start_row : start_row + 2, :])
    output_data = mappings.scatter_to_sequence_parallel_region(input_data)
    assert torch.equal(output_data, input_data[start_row : start_row + 2, :])

    # Backward: per-rank gradients are concatenated in group-rank order. Call this
    # class's own backward so the test actually covers it (it previously called
    # _ScatterToModelParallelRegion.backward by mistake).
    input_data = torch.ones(4).cuda() * Utils.rank
    output_data = mappings._ScatterToSequenceParallelRegion.backward(None, input_data)
    expected_output = torch.cat(
        (torch.ones(4) * 0, torch.ones(4) * 1, torch.ones(4) * 2, torch.ones(4) * 3)
    ).cuda()
    if Utils.rank >= 4:
        expected_output = expected_output + 4
    assert torch.equal(output_data, expected_output)

    Utils.destroy_model_parallel()


def test_GatherFromSequenceParallelRegion():
    """Check _GatherFromSequenceParallelRegion: gather forward, reduce-scatter backward.

    Assumes an 8-rank launch in which ranks 0-3 and 4-7 form the two
    tensor-parallel groups of size 4 created by ``initialize_model_parallel(4, 2)``.
    """
    Utils.initialize_model_parallel(4, 2)

    # Forward / symbolic: per-rank tensors are concatenated in group-rank order.
    input_data = torch.ones(4).cuda() * Utils.rank
    output_data = mappings.gather_from_sequence_parallel_region(input_data)
    expected_output = torch.cat(
        (torch.ones(4) * 0, torch.ones(4) * 1, torch.ones(4) * 2, torch.ones(4) * 3)
    ).cuda()
    if Utils.rank >= 4:
        expected_output = expected_output + 4
    assert torch.equal(output_data, expected_output)
    assert torch.equal(
        mappings._GatherFromSequenceParallelRegion.symbolic(None, input_data), expected_output
    )

    # Backward with tensor_parallel_output_grad=True: every rank passes the same
    # (4, 4) gradient whose row i is filled with i; group rank r is expected to
    # receive row r summed over the 4 ranks of its group, i.e. 4 * r.
    input_data = torch.vstack(
        (torch.ones(4) * 0, torch.ones(4) * 1, torch.ones(4) * 2, torch.ones(4) * 3)
    ).cuda()

    class Ctx:
        # Minimal stand-in for the autograd context; presumably these are the
        # attributes backward reads.
        tensor_parallel_output_grad = True
        output_split_sizes = None

    output_data = mappings._GatherFromSequenceParallelRegion.backward(Ctx(), input_data)
    expected_output = torch.ones((1, 4)).cuda() * 4 * (Utils.rank % 4)
    # backward returns a sequence of gradients; the first one is for the input.
    assert torch.equal(output_data[0], expected_output)

    Utils.destroy_model_parallel()


def test_ReduceScatterToSequenceParallelRegion():
    """Check _ReduceScatterToSequenceParallelRegion: reduce-scatter forward, gather backward.

    Assumes an 8-rank launch in which ranks 0-3 and 4-7 form the two
    tensor-parallel groups of size 4 created by ``initialize_model_parallel(4, 2)``.
    """
    Utils.initialize_model_parallel(4, 2)

    # Forward / symbolic: every rank passes the same (4, 4) tensor whose row i is
    # filled with i; group rank r is expected to hold a (1, 4) tensor equal to
    # row r summed over the 4 ranks of its group, i.e. 4 * r.
    input_data = torch.vstack(
        (torch.ones(4) * 0, torch.ones(4) * 1, torch.ones(4) * 2, torch.ones(4) * 3)
    ).cuda()
    output_data = mappings.reduce_scatter_to_sequence_parallel_region(input_data)
    expected_output = torch.ones(4).cuda() * 4 * (Utils.rank % 4)
    assert torch.equal(output_data[0], expected_output)
    assert torch.equal(
        mappings._ReduceScatterToSequenceParallelRegion.symbolic(None, input_data),
        expected_output.reshape((1, 4)),
    )

    # Backward: per-rank gradients are concatenated in group-rank order.
    input_data = torch.ones(4).cuda() * Utils.rank

    class Ctx:
        # Minimal stand-in for the autograd context; presumably the only
        # attribute backward reads.
        input_split_sizes = None

    output_data, _ = mappings._ReduceScatterToSequenceParallelRegion.backward(Ctx(), input_data)
    expected_output = torch.cat(
        (torch.ones(4) * 0, torch.ones(4) * 1, torch.ones(4) * 2, torch.ones(4) * 3)
    ).cuda()
    if Utils.rank >= 4:
        expected_output = expected_output + 4
    assert torch.equal(output_data, expected_output)

    Utils.destroy_model_parallel()