import pytest
import torch
import torch.nn.functional as F

import colossalai
from colossalai.device.device_mesh import DeviceMesh
from colossalai.nn._ops._utils import gather_forward_split_backward
from colossalai.tensor import ColoParameter, ColoTensor, ProcessGroup
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.testing import rerun_if_address_is_in_use, spawn


def run_dist(rank, world_size, port):
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
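    # this test assumes 4 processes (one per CUDA device) so that the 2 x 2
    # device mesh built below can be formed; the nccl backend requires GPUs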

    # create the input, weight, and bias for the linear layer
    x = ColoTensor.from_torch_tensor(torch.rand(4, 4, 8, requires_grad=True)).cuda()
    w = ColoParameter.from_torch_tensor(torch.rand(16, 8, requires_grad=True)).cuda()
    b = ColoParameter.from_torch_tensor(torch.rand(16, requires_grad=True)).cuda()

    # run normal forward
    out = F.linear(x, w, b)
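    # `out` is the unsharded reference output; every sharded case below is
    # compared against a chunk of (or all of) this tensor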

    # create the device mesh metadata
    # the mesh has the following topology:
    # [[0, 1],
    #  [2, 3]]
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
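    # map this rank to its coordinate in the 2 x 2 mesh,
    # e.g. rank 3 sits at row 1, column 1 of [[0, 1], [2, 3]]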
    row_id = rank // 2
    column_id = rank % 2

    # create pg
    row_process_group = None
    col_process_group = None
    row_to_ranks = {0: [0, 1], 1: [2, 3]}
    col_to_ranks = {0: [0, 2], 1: [1, 3]}
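    # the rank lists mirror the rows and columns of the mesh above; each group
    # spans the two ranks along one mesh axis, hence tp_degree=2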

    for idx in range(2):
        # row ranks
        row_ranks = row_to_ranks[idx]
        row_pg = ProcessGroup(ranks=row_ranks, tp_degree=2)

        # col ranks
        col_ranks = col_to_ranks[idx]
        col_pg = ProcessGroup(ranks=col_ranks, tp_degree=2)

        if rank in row_ranks:
            row_process_group = row_pg

        if rank in col_ranks:
            col_process_group = col_pg
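
    # sanity check: with world_size 4 every rank lands in exactly one
    # row group and one column group
    assert row_process_group is not None and col_process_group is not None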

    ########################
    #  RRR x RS0 -> RRS0   #
    ########################
    # w will be transposed in F.linear
    x_replica = x.detach().clone()
    w_shard = torch.chunk(w.detach().clone(), chunks=2, dim=0)[row_id]
    b_shard = torch.chunk(b.detach().clone(), chunks=2, dim=0)[row_id]

    # adding sharding spec
    x_replica.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={})
    w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={0: [0]})
    b_shard.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={0: [0]})
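    # dim_partition_dict maps a tensor dim to the mesh axes it is sharded over:
    # an empty dict means fully replicated (R), while {0: [0]} shards dim 0
    # across mesh axis 0 (S0), as the sharding_sequence asserts below confirm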

    # check sharding spec
    assert str(x_replica.sharding_spec.sharding_sequence) == "[R, R, R]"
    assert str(w_shard.sharding_spec.sharding_sequence) == "[S0, R]"
    assert str(b_shard.sharding_spec.sharding_sequence) == "[S0]"

    w_shard.pg_axis0 = col_process_group
    w_shard.pg_axis1 = row_process_group
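    # note: attaching the column and row process groups to the weight as
    # pg_axis0/pg_axis1 is, by assumption here, how the sharded linear op learns
    # which groups to use for collectives along each mesh axis; the same pattern
    # is repeated before every sharded F.linear call below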

    out_shard = F.linear(x_replica, w_shard, b_shard)
    assert str(out_shard.sharding_spec.sharding_sequence) == "[R, R, S0]"

    # the output feature dimension (dim 2) is sharded along mesh axis 0,
    # so each mesh row holds its own slice of the output
    expected_out_shard = torch.chunk(out, chunks=2, dim=2)[row_id]
    assert torch.allclose(out_shard, expected_out_shard)

    ########################
    #  S0RR x RS1 -> S0RS1 #
    ########################
    # w will be transposed in F.linear
    x_shard = torch.chunk(x.detach().clone(), chunks=2, dim=0)[row_id]
    w_shard = torch.chunk(w.detach().clone(), chunks=2, dim=0)[column_id]
    b_shard = torch.chunk(b.detach().clone(), chunks=2, dim=0)[column_id]

    # adding sharding spec
    x_shard.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={0: [0]})
    w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={0: [1]})
    b_shard.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={0: [1]})

    # check sharding spec
    assert str(x_shard.sharding_spec.sharding_sequence) == "[S0, R, R]"
    assert str(w_shard.sharding_spec.sharding_sequence) == "[S1, R]"
    assert str(b_shard.sharding_spec.sharding_sequence) == "[S1]"

    w_shard.pg_axis0 = col_process_group
    w_shard.pg_axis1 = row_process_group

    out_shard = F.linear(x_shard, w_shard, b_shard)

    # the batch dimension is sharded across mesh rows and the output features
    # across mesh columns
    expected_out_shard = torch.chunk(out, chunks=2, dim=0)[row_id]
    expected_out_shard = torch.chunk(expected_out_shard, chunks=2, dim=2)[column_id]
    assert torch.allclose(out_shard, expected_out_shard)

    ########################
    #  S0RS1 x S1R -> S0RR #
    ########################
    # w will be transposed in F.linear
    x_shard = torch.chunk(x.clone(), chunks=2, dim=0)[row_id]
    x_shard = torch.chunk(x_shard, chunks=2, dim=2)[column_id]
    w_shard = torch.chunk(w.clone(), chunks=2, dim=1)[column_id]
    b_replica = b.clone()

    # adding sharding spec
    x_shard.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={0: [0], 2: [1]})
    w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={1: [1]})
    b_replica.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={})

    # check sharding spec
    assert str(x_shard.sharding_spec.sharding_sequence) == "[S0, R, S1]"
    assert str(w_shard.sharding_spec.sharding_sequence) == "[R, S1]"
    assert str(b_replica.sharding_spec.sharding_sequence) == "[R]"

    w_shard.pg_axis0 = col_process_group
    w_shard.pg_axis1 = row_process_group

    out_shard = F.linear(x_shard, w_shard, b_replica)

    # only the batch dimension stays sharded, so each mesh row holds its own mini-batch
    expected_out_shard = torch.chunk(out, chunks=2, dim=0)[row_id]
    assert torch.allclose(out_shard, expected_out_shard)

    ########################
    #  RRS0 x S0R -> RRR   #
    ########################
    # w will be transposed in F.linear
    x_shard = torch.chunk(x.clone(), chunks=2, dim=2)[row_id]
    w_shard = torch.chunk(w.clone(), chunks=2, dim=1)[row_id]
    b_replica = b.clone()

    # adding sharding spec
    x_shard.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={2: [0]})
    w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={1: [0]})
    b_replica.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={})

    # check sharding spec
    assert str(x_shard.sharding_spec.sharding_sequence) == "[R, R, S0]"
    assert str(w_shard.sharding_spec.sharding_sequence) == "[R, S0]"
    assert str(b_replica.sharding_spec.sharding_sequence) == "[R]"

    w_shard.pg_axis0 = col_process_group
    w_shard.pg_axis1 = row_process_group

    out_shard = F.linear(x_shard, w_shard, b_replica)

    # the contraction dimension is fully reduced, so the output is replicated on every rank
    expected_out_shard = out
    assert torch.allclose(out_shard, expected_out_shard)

    ########################
    #  RS0S1 x S1R -> RS0R #
    ########################
    # w will be transposed in F.linear
    x_shard = torch.chunk(x.clone(), chunks=2, dim=1)[row_id]
    x_shard = torch.chunk(x_shard, chunks=2, dim=2)[column_id]
    w_shard = torch.chunk(w.clone(), chunks=2, dim=1)[column_id]
    b_replica = b.clone()

    # adding sharding spec
    x_shard.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={1: [0], 2: [1]})
    w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={1: [1]})
    b_replica.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={})

    # check sharding spec
    assert str(x_shard.sharding_spec.sharding_sequence) == "[R, S0, S1]"
    assert str(w_shard.sharding_spec.sharding_sequence) == "[R, S1]"
    assert str(b_replica.sharding_spec.sharding_sequence) == "[R]"

    w_shard.pg_axis0 = col_process_group
    w_shard.pg_axis1 = row_process_group

    out_shard = F.linear(x_shard, w_shard, b_replica)

    # dim 1 of the output stays sharded along mesh axis 0, so each mesh row holds its own slice
    expected_out_shard = torch.chunk(out, chunks=2, dim=1)[row_id]
    assert torch.allclose(out_shard, expected_out_shard)

    ########################
    #  RRS0 x S0S1 -> RRS1 #
    ########################
    # w will be transposed in F.linear
    x_shard = torch.chunk(x.clone(), chunks=2, dim=2)[row_id]
    w_shard = torch.chunk(w.clone(), chunks=2, dim=1)[row_id]
    w_shard = torch.chunk(w_shard, chunks=2, dim=0)[column_id]
    b_shard = torch.chunk(b.clone(), chunks=2, dim=0)[column_id]

    # adding sharding spec
    x_shard.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={2: [0]})
    w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={0: [1], 1: [0]})
    b_shard.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={0: [1]})

    # check sharding spec
    assert str(x_shard.sharding_spec.sharding_sequence) == "[R, R, S0]"
    assert str(w_shard.sharding_spec.sharding_sequence) == "[S1, S0]"
    assert str(b_shard.sharding_spec.sharding_sequence) == "[S1]"

    w_shard.pg_axis0 = col_process_group
    w_shard.pg_axis1 = row_process_group

    out_shard = F.linear(x_shard, w_shard, b_shard)

    # the output feature dimension ends up sharded along mesh axis 1,
    # so each mesh column holds its own slice
    expected_out_shard = torch.chunk(out, chunks=2, dim=2)[column_id]
    assert torch.allclose(out_shard, expected_out_shard)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [4])
@rerun_if_address_is_in_use()
def test_sharded_mlp(world_size):
    spawn(run_dist, world_size)


if __name__ == '__main__':
    test_sharded_mlp(4)