# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import argparse
import os

import numpy as np
import torch

import deepspeed
import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.comm.nccl import NcclBackend

# Command-line setup: DeepSpeed launchers historically passed --local_rank;
# the actual value is taken from the LOCAL_RANK environment variable below.
parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=-1)
args = parser.parse_args()

# Initialize the distributed process group using the accelerator's
# communication backend (NCCL on CUDA devices).
deepspeed.init_distributed(dist_backend=get_accelerator().communication_backend_name())
args.local_rank = int(os.environ['LOCAL_RANK'])

# Bind this process to its local device.
get_accelerator().set_device(args.local_rank)
device = torch.device(get_accelerator().device_name(), args.local_rank)

size = dist.get_world_size()
rank = dist.get_rank()

# The compressed-allreduce implementation under test.
backend = NcclBackend()
local_rank = args.local_rank
31

Conglong Li's avatar
Conglong Li committed
32

aiss's avatar
aiss committed
33
# A simulated compression function using deepspeed.comm
34
35
36
37
38
39
40
41
42
43
44
45
def torch_sim(a):
    a_sign = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
    scale = a.norm() / np.sqrt(a.numel())
    a_compressed = scale * a_sign
    a_sign = None
    worker_error = a - a_compressed
    dist.all_reduce(a_compressed)
    a_compressed.mul_(1 / dist.get_world_size())
    a_server_sign = a_compressed.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
    a_list = torch.chunk(a_compressed, chunks=dist.get_world_size())
    server_scale = [chunk_a.norm() / np.sqrt(chunk_a.numel()) for chunk_a in a_list]
    a_sign_list = torch.chunk(a_server_sign, dist.get_world_size())
aiss's avatar
aiss committed
46
    a_server_compressed = torch.cat([server_scale[i] * a_sign_list[i] for i in range(dist.get_world_size())])
47
48
    rank = dist.get_rank()
    server_error = a_list[rank] - server_scale[rank] * a_sign_list[rank]
aiss's avatar
aiss committed
49
50
    get_accelerator().synchronize()
    dist.barrier()
51
52
53
    return a_server_compressed, worker_error, server_error
# Gradient size under test. The backend packs signs into bytes per rank, so
# the buffers are padded up to the next multiple of 8 * world_size.
tensor_size = 300 * 2**20
server_size = int(tensor_size / size)
if tensor_size % (8 * size) != 0:
    right_tensor_size = tensor_size + (8 * size - (tensor_size % (8 * size)))
else:
    right_tensor_size = tensor_size
right_server_size = right_tensor_size // size

# Adding bias to the initialization of the gradient we are communicating
# In order to get rid of the case where some elements in the gradient are too small
a = (torch.rand(tensor_size, device=device) - 0.5) + 0.01 * rank

# Error-feedback buffers consumed and updated in place by compressed_allreduce.
worker_error = torch.zeros(right_tensor_size, device=device)
server_error = torch.zeros(right_server_size, device=device)

# Reference result computed with the simulated implementation above.
a_torch, worker_error_torch, server_error_torch = torch_sim(a)
get_accelerator().empty_cache()

# Result from the NcclBackend implementation under test.
a_after = backend.compressed_allreduce(a, worker_error, server_error, local_rank)

# Element-wise comparison of the two results; only this rank's server chunk
# (and its compensated value) is inspected below.
threshold = 1e-6
magnitude_threshold = 1e-6
diff_mask = (a_after - a_torch) > threshold
diff_server_mask = torch.chunk(diff_mask, size)[rank]
mpi_server = torch.chunk(a_after, size)[rank] + server_error
torch_server = torch.chunk(a_torch, size)[rank] + server_error_torch
test_correctness = True

# If the number in the compensated_server_m is too small (e.g 1e-8), then calling sign() might be problematic
# The test would skip those numbers that are too small in compensated_server_m
if test_correctness:
    if torch.sum(diff_server_mask) == 0:
        # No element of this rank's chunk differs beyond the threshold.
        print('Successfully passed the test for NCCL Backend at Rank {}'.format(rank))
    else:
        # Only count mismatches whose compensated magnitude is large enough
        # for sign() to be numerically meaningful.
        check_mag_mask = mpi_server[diff_server_mask] > magnitude_threshold
        if torch.sum(check_mag_mask) == 0:
            print('Successfully passed the test for NCCL Backend at Rank {}'.format(rank))
        else:
            print('Fails at {} of positions'.format(torch.sum(check_mag_mask)))