test_nccl_backend.py 3.41 KB
Newer Older
aiss's avatar
aiss committed
1
2
'''Copyright The Microsoft DeepSpeed Team'''

3
import torch
aiss's avatar
aiss committed
4
import deepspeed.comm as dist
5
import numpy as np
Conglong Li's avatar
Conglong Li committed
6
import argparse
7
import deepspeed
Conglong Li's avatar
Conglong Li committed
8
import os
9

Conglong Li's avatar
Conglong Li committed
10
from deepspeed.runtime.comm.nccl import NcclBackend
aiss's avatar
aiss committed
11
from deepspeed.accelerator import get_accelerator
12

Conglong Li's avatar
Conglong Li committed
13
14
15
# Command-line interface: distributed launchers pass --local_rank to every
# worker process; -1 means "not provided" (it is overwritten from the
# environment below anyway).
parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', default=-1, type=int)
args = parser.parse_args()
16

aiss's avatar
aiss committed
17
# Initialize the process group using the accelerator's native communication
# backend (e.g. "nccl" on CUDA devices).
deepspeed.init_distributed(dist_backend=get_accelerator().communication_backend_name())
# The launcher exports LOCAL_RANK; it takes precedence over the CLI value.
# NOTE(review): raises KeyError when run outside a distributed launcher —
# presumably intentional, since the test is meaningless single-process.
args.local_rank = int(os.environ['LOCAL_RANK'])

# Pin this process to its local device and build the matching torch.device.
get_accelerator().set_device(args.local_rank)
device = torch.device(get_accelerator().device_name(), args.local_rank)
22

Conglong Li's avatar
Conglong Li committed
23
24
# Communicator geometry for this run.
size = dist.get_world_size()
rank = dist.get_rank()

# Backend under test; its compressed_allreduce is exercised below against
# the torch_sim reference implementation.
backend = NcclBackend()
local_rank = args.local_rank
28

Conglong Li's avatar
Conglong Li committed
29

aiss's avatar
aiss committed
30
# A simulated compression function using deepspeed.comm
def torch_sim(a):
    """Reference (uncompressed-communication) simulation of the two-stage
    1-bit compressed all-reduce.

    Args:
        a: 1-D tensor to reduce. Not modified; a compressed copy is reduced.

    Returns:
        (a_server_compressed, worker_error, server_error):
        the simulated all-reduce result, the worker-side compression error
        for the full tensor, and the server-side compression error for this
        rank's chunk.
    """
    # Map sign(a) in {-1, 0, +1} to {-1, +1} (zeros become +1):
    # sign -> +1 -> {0,1,2} -> bool -> {0.,1.} -> {-0.5,0.5} -> {-1.,1.}
    a_sign = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
    # Scale chosen so scale * a_sign has the same L2 norm as a
    # (||scale * sign|| = scale * sqrt(numel) = ||a||).
    scale = a.norm() / np.sqrt(a.numel())
    a_compressed = scale * a_sign
    a_sign = None
    # Worker-side error feedback: what 1-bit compression lost.
    worker_error = a - a_compressed
    dist.all_reduce(a_compressed)
    a_compressed.mul_(1 / dist.get_world_size())
    # Second compression stage applied to the averaged tensor ("server" side),
    # this time with a per-chunk scale (one chunk per rank).
    a_server_sign = a_compressed.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
    a_list = torch.chunk(a_compressed, chunks=dist.get_world_size())
    server_scale = [chunk_a.norm() / np.sqrt(chunk_a.numel()) for chunk_a in a_list]
    a_sign_list = torch.chunk(a_server_sign, dist.get_world_size())
    a_server_compressed = torch.cat(
        [server_scale[i] * a_sign_list[i] for i in range(dist.get_world_size())])
    rank = dist.get_rank()
    # Server-side error feedback for the chunk owned by this rank only.
    server_error = a_list[rank] - server_scale[rank] * a_sign_list[rank]
    # Ensure device work is done and all ranks reach this point before returning.
    get_accelerator().synchronize()
    dist.barrier()
    return a_server_compressed, worker_error, server_error


Conglong Li's avatar
Conglong Li committed
52
# Number of elements in the tensor being reduced (300 * 2**20 elements).
tensor_size = 300 * 2**20

# Per-rank ("server") chunk size before padding.
# Fix: use integer floor division instead of int(tensor_size / size) — the
# original routed an exact integer quotient through float division.
# NOTE(review): server_size is not used below; kept for script parity.
server_size = tensor_size // size
# Pad the tensor length up to a multiple of 8 * world_size — presumably
# because the backend bit-packs signs into bytes per rank; confirm against
# NcclBackend.compressed_allreduce.
if tensor_size % (8 * size) != 0:
    right_tensor_size = tensor_size + (8 * size - (tensor_size % (8 * size)))
else:
    right_tensor_size = tensor_size
right_server_size = right_tensor_size // size
Conglong Li's avatar
Conglong Li committed
59

60
61
62
# Adding bias to the initialization of the gradient we are communicating
# In order to get rid of the case where some elements in the gradient are too small
a = (torch.rand(tensor_size, device=device) - 0.5) + 0.01 * rank

# Persistent error-feedback buffers handed to compressed_allreduce; they use
# the padded sizes so the backend's bit-packed layout fits.
worker_error = torch.zeros(right_tensor_size, device=device)
server_error = torch.zeros(right_server_size, device=device)
Conglong Li's avatar
Conglong Li committed
66

67
# Reference result and per-stage compression errors from the simulation.
a_torch, worker_error_torch, server_error_torch = torch_sim(a)
# Release cached device memory before the real backend allocates its buffers.
get_accelerator().empty_cache()

# Result from the backend under test.
# NOTE(review): worker_error/server_error appear to be updated in place by
# the backend (they are read again below) — confirm against
# NcclBackend.compressed_allreduce.
a_after = backend.compressed_allreduce(a, worker_error, server_error, local_rank)

72
73
74
75
76
77
78
# Element-wise tolerance vs the reference, and the magnitude below which
# sign() is numerically unreliable (used by the check further down).
threshold = 1e-6
magnitude_threshold = 1e-6
# Elements where backend and reference disagree beyond the tolerance.
# Fix: compare absolute differences — the original one-sided
# `(a_after - a_torch) > threshold` silently ignored elements where the
# backend result was *smaller* than the reference by more than the tolerance.
diff_mask = (a_after - a_torch).abs() > threshold
# Restrict the check to this rank's server chunk.
diff_server_mask = torch.chunk(diff_mask, size)[rank]
# Error-compensated server chunks: result plus the server-side residual.
mpi_server = torch.chunk(a_after, size)[rank] + server_error
torch_server = torch.chunk(a_torch, size)[rank] + server_error_torch

Conglong Li's avatar
Conglong Li committed
79
80
test_correctness = True

81
82
# If the number in the compensated_server_m is too small (e.g 1e-8), then calling sign() might be problematic
# The test would skip those numbers that are too small in compensated_server_m
Conglong Li's avatar
Conglong Li committed
83
84
85
if test_correctness:
    if torch.sum(diff_server_mask) == 0:
        # No element of this rank's chunk deviates beyond the tolerance.
        print('Successfully passed the test for NCCL Backend at Rank {}'.format(rank))
    else:
        # Some elements differ; excuse those whose compensated value is too
        # small in magnitude for sign() to be reliable (see comment above).
        # Fix: gate on absolute magnitude — the original one-sided
        # `mpi_server[...] > magnitude_threshold` let large *negative*
        # compensated values slip through the magnitude gate unflagged.
        check_mag_mask = mpi_server[diff_server_mask].abs() > magnitude_threshold
        if torch.sum(check_mag_mask) == 0:
            print(
                'Successfully passed the test for NCCL Backend at Rank {}'.format(rank))
        else:
            print('Fails at {} of positions'.format(torch.sum(check_mag_mask)))