test_tensor_parallel.py 15.1 KB
Newer Older
chenzk's avatar
v1.0.8  
chenzk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
import os

import pytest
import torch
from helpers.utils import available_gpus, init_distributed, rerun_if_address_is_in_use
from nanotron import distributed as dist
from nanotron.distributed import get_global_rank
from nanotron.parallel import ParallelContext
from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode
from nanotron.parallel.tensor_parallel.nn import (
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    TensorParallelRowLinear,
)
from torch import nn as torch_nn


@pytest.mark.parametrize("tp,dp,pp", [pytest.param(i, 1, 1) for i in range(1, min(4, available_gpus()) + 1)])
@pytest.mark.parametrize("tp_mode", list(TensorParallelLinearMode))
@pytest.mark.parametrize("async_communication", [False, True])
@pytest.mark.parametrize("tp_recompute_allgather", [False, True])
@rerun_if_address_is_in_use()
def test_column_linear(
    tp: int,
    dp: int,
    pp: int,
    tp_mode: TensorParallelLinearMode,
    async_communication: bool,
    tp_recompute_allgather: bool,
):
    """Run the column-linear check for one (tp, tp_mode, async, recompute) combination.

    Combinations that pair ALL_REDUCE with async communication or with
    `tp_recompute_allgather` are skipped. Every other combination is handed
    to `_test_column_linear` through the `init_distributed(...)` wrapper.
    """
    uses_all_reduce = tp_mode is TensorParallelLinearMode.ALL_REDUCE
    if uses_all_reduce and async_communication:
        pytest.skip("ALL_REDUCE mode does not support async communication")
    if uses_all_reduce and tp_recompute_allgather:
        pytest.skip("ALL_REDUCE mode is unaffected by tp_recompute_allgather")

    # `init_distributed` is used as a decorator factory: configure it first,
    # wrap the per-rank body, then invoke the wrapped callable.
    launcher = init_distributed(tp=tp, dp=dp, pp=pp)
    wrapped_test = launcher(_test_column_linear)
    wrapped_test(
        tp_mode=tp_mode,
        async_communication=async_communication,
        tp_recompute_allgather=tp_recompute_allgather,
    )


def _test_column_linear(
    parallel_context: ParallelContext,
    tp_mode: TensorParallelLinearMode,
    async_communication: bool,
    tp_recompute_allgather: bool,
):
    """Compare `TensorParallelColumnLinear` against an un-sharded `torch.nn.Linear`.

    The sharded layer's weight and bias are all-gathered into the reference
    layer so both compute the same function. The forward output, the
    weight/bias gradients and the input gradient are then checked with
    `torch.testing.assert_close`.

    Args:
        parallel_context: Provides the tensor-parallel process group (`tp_pg`).
        tp_mode: ALL_REDUCE (same input on every TP rank) or REDUCE_SCATTER
            (each rank holds one batch shard of the input).
        async_communication: Forwarded to the layer; when true,
            `CUDA_DEVICE_MAX_CONNECTIONS` is set to "1" first.
        tp_recompute_allgather: Forwarded to the layer.

    Raises:
        ValueError: If `tp_mode` is neither ALL_REDUCE nor REDUCE_SCATTER.
        AssertionError: If any sharded result differs from the reference.
    """
    if async_communication:
        os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"

    tp_pg = parallel_context.tp_pg
    tp_size = tp_pg.size()
    tp_rank = dist.get_rank(tp_pg)

    in_features = 2
    out_features_per_tp_rank = 3
    out_features = tp_size * out_features_per_tp_rank
    # Output features owned by this TP rank; used for the forward output and
    # for the weight/bias gradients.
    hidden_dim_slice = slice(
        tp_rank * out_features_per_tp_rank,
        (tp_rank + 1) * out_features_per_tp_rank,
    )

    # Sharded
    column_linear = TensorParallelColumnLinear(
        in_features=in_features,
        out_features=out_features,
        pg=tp_pg,
        mode=tp_mode,
        device="cuda",
        async_communication=async_communication,
        tp_recompute_allgather=tp_recompute_allgather,
    )

    # Un-sharded
    reference_linear = torch_nn.Linear(in_features=in_features, out_features=out_features, device="cuda")

    # Copy weights/bias from sharded to un-sharded: each rank's shard is
    # gathered into the matching split (a view) of the reference parameters.
    with torch.inference_mode():
        dist.all_gather(
            tensor_list=list(reference_linear.weight.split(out_features_per_tp_rank, dim=0)),
            tensor=column_linear.weight,
            group=tp_pg,
        )
        dist.all_gather(
            tensor_list=list(reference_linear.bias.split(out_features_per_tp_rank, dim=0)),
            tensor=column_linear.bias,
            group=tp_pg,
        )

    # Generate random input
    random_input: torch.Tensor
    sharded_random_input: torch.Tensor
    if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
        batch_size = 5
        random_input = torch.randn(batch_size, in_features, device="cuda")
        # synchronize random_input across tp
        dist.all_reduce(random_input, op=dist.ReduceOp.AVG, group=tp_pg)
        sharded_random_input = random_input
    elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
        sharded_batch_size = 5
        sharded_random_input = torch.randn(sharded_batch_size, in_features, device="cuda")
        if tp_size > 1:
            # The reference input is the concatenation of every rank's shard.
            random_input = torch.empty(
                sharded_batch_size * tp_size,
                *(sharded_random_input.shape[1:]),
                device=sharded_random_input.device,
                dtype=sharded_random_input.dtype,
            )
            dist.all_gather_into_tensor(random_input, sharded_random_input, group=tp_pg)
        else:
            random_input = sharded_random_input
    else:
        # The exception was previously constructed but never raised, so an
        # unsupported mode fell through to an unrelated UnboundLocalError.
        raise ValueError(f"Unsupported mode: {tp_mode}")
    # It's important that `random_input` and `sharded_random_input` are two separate tensors with separate storage
    sharded_random_input = sharded_random_input.clone()
    random_input.requires_grad = True
    sharded_random_input.requires_grad = True

    # Test that we get the same output after forward pass
    sharded_output = column_linear(sharded_random_input)
    reference_output = reference_linear(random_input)
    # TODO @thomasw21: Tune tolerance
    try:
        torch.testing.assert_close(
            sharded_output,
            reference_output[:, hidden_dim_slice],
        )
    except BaseException:
        print(f"Rank {tp_rank}: FAIL.")
        # Still reach the barrier so ranks that passed are not left waiting.
        dist.barrier()
        raise
    print(f"Rank {tp_rank}: SUCCESS.")
    dist.barrier()

    # Test that we get the same gradient after backward pass
    sharded_output.sum().backward()
    reference_output.sum().backward()
    torch.testing.assert_close(
        column_linear.weight.grad,
        reference_linear.weight.grad[hidden_dim_slice],
    )
    torch.testing.assert_close(
        column_linear.bias.grad,
        reference_linear.bias.grad[hidden_dim_slice],
    )
    if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
        torch.testing.assert_close(
            sharded_random_input.grad,
            random_input.grad,
        )
    elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
        # Each rank only owns the gradient rows of its own batch shard.
        batch_dim_slice = slice(
            tp_rank * sharded_batch_size,
            (tp_rank + 1) * sharded_batch_size,
        )
        torch.testing.assert_close(
            sharded_random_input.grad,
            random_input.grad[batch_dim_slice],
        )
    else:
        raise ValueError(f"Unsupported mode: {tp_mode}")

    parallel_context.destroy()


@pytest.mark.parametrize("tp,dp,pp", [pytest.param(i, 1, 1) for i in range(1, min(4, available_gpus()) + 1)])
@pytest.mark.parametrize("tp_mode", list(TensorParallelLinearMode))
@pytest.mark.parametrize("async_communication", [False, True])
@pytest.mark.parametrize("tp_recompute_allgather", [False, True])
@rerun_if_address_is_in_use()
def test_row_linear(
    tp: int,
    dp: int,
    pp: int,
    tp_mode: TensorParallelLinearMode,
    async_communication: bool,
    tp_recompute_allgather: bool,
):
    """Run the row-linear check for one (tp, tp_mode, async, recompute) combination.

    Combinations that pair ALL_REDUCE with async communication or with
    `tp_recompute_allgather` are skipped. Every other combination is handed
    to `_test_row_linear` through the `init_distributed(...)` wrapper.
    """
    uses_all_reduce = tp_mode is TensorParallelLinearMode.ALL_REDUCE
    if uses_all_reduce and async_communication:
        pytest.skip("ALL_REDUCE mode does not support async communication")
    if uses_all_reduce and tp_recompute_allgather:
        pytest.skip("ALL_REDUCE mode is not affected by tp_recompute_allgather")

    # Configure the wrapper, wrap the per-rank body, then invoke it.
    launcher = init_distributed(tp=tp, dp=dp, pp=pp)
    wrapped_test = launcher(_test_row_linear)
    wrapped_test(
        tp_mode=tp_mode,
        async_communication=async_communication,
        tp_recompute_allgather=tp_recompute_allgather,
    )


def _test_row_linear(
    parallel_context: ParallelContext,
    tp_mode: TensorParallelLinearMode,
    async_communication: bool,
    tp_recompute_allgather: bool,
):
    """Compare `TensorParallelRowLinear` against an un-sharded `torch.nn.Linear`.

    The reference weight is summed across TP ranks so every rank sees the
    same matrix, and each rank copies its own input-feature columns into the
    sharded layer. The forward output, the weight/bias gradients and the
    input gradient are then checked with `torch.testing.assert_close`.

    Args:
        parallel_context: Provides the tensor-parallel process group (`tp_pg`).
        tp_mode: ALL_REDUCE (full batch output on every rank) or
            REDUCE_SCATTER (each rank receives one batch shard of the output).
        async_communication: Forwarded to the layer; when true,
            `CUDA_DEVICE_MAX_CONNECTIONS` is set to "1" first.
        tp_recompute_allgather: Accepted so the signature matches the
            parametrized caller; it is not used in this body.

    Raises:
        ValueError: If `tp_mode` is neither ALL_REDUCE nor REDUCE_SCATTER.
        AssertionError: If any sharded result differs from the reference.
    """
    if async_communication:
        os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"

    tp_pg = parallel_context.tp_pg
    tp_size = tp_pg.size()
    tp_rank = dist.get_rank(tp_pg)

    out_features = 3
    in_features_per_rank = 2
    in_features = tp_size * in_features_per_rank
    # Input features (columns) owned by this TP rank.
    in_features_slice = slice(
        tp_rank * in_features_per_rank,
        (tp_rank + 1) * in_features_per_rank,
    )

    # Sharded
    row_linear = TensorParallelRowLinear(
        in_features=in_features,
        out_features=out_features,
        pg=tp_pg,
        mode=tp_mode,
        device="cuda",
        async_communication=async_communication,
    )

    # Un-sharded
    reference_linear = torch_nn.Linear(in_features=in_features, out_features=out_features, device="cuda")

    # Make the reference weight identical on every rank (sum over TP), then
    # copy this rank's columns from the reference into the sharded layer.
    with torch.inference_mode():
        dist.all_reduce(tensor=reference_linear.weight, op=dist.ReduceOp.SUM, group=tp_pg)
        row_linear.weight.copy_(reference_linear.weight[:, in_features_slice])
        # broadcast bias from rank 0, and the other don't have bias
        if tp_rank == 0:
            row_linear.bias.copy_(reference_linear.bias)
        dist.broadcast(
            tensor=reference_linear.bias,
            src=get_global_rank(group=tp_pg, group_rank=0),
            group=tp_pg,
        )

    # Generate random input
    if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
        batch_size = 5
    elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
        # The batch must be divisible by the TP size so the output can be
        # split evenly across ranks.
        batch_size = 5 * tp_size
    else:
        # Same message as the other unsupported-mode checks in this file.
        raise ValueError(f"Unsupported mode: {tp_mode}")
    random_input = torch.randn(batch_size, in_features, device="cuda")
    # synchronize random_input across tp
    dist.all_reduce(random_input, op=dist.ReduceOp.AVG, group=tp_pg)
    random_input.requires_grad = True
    # Row linear receives as input sharded input; detach + clone so it is a
    # separate leaf tensor with its own `.grad`.
    random_sharded_input = random_input[:, in_features_slice].detach().clone()
    random_sharded_input.requires_grad = True

    # Test that we get the same output after forward pass
    # TODO @kunhao: We may want to have our custom error type
    sharded_output = row_linear(random_sharded_input)
    reference_output = reference_linear(random_input)

    if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
        sharded_reference_output = reference_output
    elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
        assert batch_size % tp_size == 0
        sharded_batch_size = batch_size // tp_size
        sharded_reference_output = reference_output[
            tp_rank * sharded_batch_size : (tp_rank + 1) * sharded_batch_size
        ]
    else:
        raise ValueError(f"Unsupported mode: {tp_mode}")

    # TODO @thomasw21: Tune tolerance
    torch.testing.assert_close(
        sharded_output,
        sharded_reference_output,
    )

    # Test that we get the same gradient after backward pass
    sharded_output.sum().backward()
    reference_output.sum().backward()
    torch.testing.assert_close(
        row_linear.weight.grad,
        reference_linear.weight.grad[:, in_features_slice],
    )
    if tp_rank == 0:
        torch.testing.assert_close(
            row_linear.bias.grad,
            reference_linear.bias.grad,
        )
    else:
        # Only TP rank 0 is expected to hold a bias.
        assert row_linear.bias is None

    torch.testing.assert_close(
        random_sharded_input.grad,
        random_input.grad[:, in_features_slice],
    )

    parallel_context.destroy()


@pytest.mark.parametrize("tp,dp,pp", [pytest.param(i, 1, 1) for i in range(1, min(4, available_gpus()) + 1)])
@pytest.mark.parametrize("tp_mode", list(TensorParallelLinearMode))
@rerun_if_address_is_in_use()
def test_tensor_parallel_embedding(tp: int, dp: int, pp: int, tp_mode: TensorParallelLinearMode):
    """Run the sharded-embedding check for one (tp, tp_mode) combination."""
    # Configure the wrapper, wrap the per-rank body, then invoke it.
    launcher = init_distributed(tp=tp, dp=dp, pp=pp)
    wrapped_test = launcher(_test_tensor_parallel_embedding)
    wrapped_test(tp_mode=tp_mode)


def _test_tensor_parallel_embedding(parallel_context: ParallelContext, tp_mode: TensorParallelLinearMode):
    """Compare `TensorParallelEmbedding` against an un-sharded `torch.nn.Embedding`.

    The reference table is summed across TP ranks and each rank copies its
    own block of rows into the sharded embedding. The forward output and the
    weight gradient are then compared exactly (`atol=0`, `rtol=0`).

    Args:
        parallel_context: Provides the tensor-parallel process group (`tp_pg`).
        tp_mode: ALL_REDUCE (full batch output on every rank) or
            REDUCE_SCATTER (each rank receives one batch shard of the output).

    Raises:
        ValueError: If `tp_mode` is neither ALL_REDUCE nor REDUCE_SCATTER.
    """
    tp_pg = parallel_context.tp_pg
    tp_size = tp_pg.size()
    tp_rank = dist.get_rank(tp_pg)

    num_embeddings_per_rank = 100
    embedding_dim = 3
    num_embeddings = tp_size * num_embeddings_per_rank
    # Rows of the embedding table owned by this TP rank.
    vocab_slice = slice(
        tp_rank * num_embeddings_per_rank,
        (tp_rank + 1) * num_embeddings_per_rank,
    )

    # Sharded
    sharded_embedding = TensorParallelEmbedding(
        num_embeddings=num_embeddings,
        embedding_dim=embedding_dim,
        pg=tp_pg,
        mode=tp_mode,
        device="cuda",
    )

    # Un-sharded
    reference_embedding = torch_nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim, device="cuda")

    # Sum the reference table over TP ranks, then copy this rank's rows into
    # the sharded table.
    with torch.inference_mode():
        dist.all_reduce(tensor=reference_embedding.weight, op=dist.ReduceOp.SUM, group=tp_pg)
        sharded_embedding.weight.copy_(reference_embedding.weight[vocab_slice, :])

    # Generate random input
    if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
        batch_size = 5
    elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
        batch_size = 5 * tp_size
    else:
        raise ValueError(f"Unsupported mode: {tp_mode}")
    random_input: torch.Tensor = torch.randint(low=0, high=num_embeddings, size=(batch_size,), device="cuda")
    # NOTE(review): this is an AVG all-reduce over an integer index tensor,
    # presumably to give every rank the same indices — confirm the backend's
    # integer AVG semantics.
    dist.all_reduce(random_input, op=dist.ReduceOp.AVG, group=tp_pg)

    # Test that we get the same output after forward pass
    sharded_output = sharded_embedding(random_input)
    reference_output = reference_embedding(random_input)
    # Per-sample weights (0..batch_size-1) applied to the outputs before the
    # backward pass.
    weights = torch.arange(batch_size, device="cuda")[:, None]

    if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
        sharded_reference_output = reference_output
        sharded_weights = weights
    elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
        assert batch_size % tp_size == 0
        sharded_batch_size = batch_size // tp_size
        first_row = tp_rank * sharded_batch_size
        last_row = (tp_rank + 1) * sharded_batch_size
        sharded_reference_output = reference_output[first_row:last_row]
        sharded_weights = weights[first_row:last_row]
    else:
        raise ValueError(f"Unsupported mode: {tp_mode}")

    # TODO @thomasw21: Tune tolerance
    torch.testing.assert_close(sharded_output, sharded_reference_output, atol=0, rtol=0)

    # Test that we get the same gradient after backward pass
    (sharded_output * sharded_weights).sum().backward()
    (reference_output * weights).sum().backward()
    torch.testing.assert_close(
        sharded_embedding.weight.grad,
        reference_embedding.weight.grad[vocab_slice, :],
        atol=0,
        rtol=0,
    )

    parallel_context.destroy()