test_mpi_backend.py
from mpi4py import MPI
import torch
import torch.distributed as dist
import numpy as np
import deepspeed

from deepspeed.runtime.comm.mpi import MpiBackend

comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()

deepspeed.init_distributed(dist_backend='nccl')
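# torch.distributed (NCCL) backs the torch_sim reference below; the
# MpiBackend under test communicates through mpi4py instead.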

# Change cuda_aware to True to test out CUDA-Aware MPI communication
backend = MpiBackend(cuda_aware=False)

device = torch.device('cuda', rank % torch.cuda.device_count())
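
# This script is typically launched under MPI (an assumption about the
# environment, not something this file pins down), e.g.:
#   mpirun -np 2 python test_mpi_backend.py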


# A simulated compression function using torch.distributed
def torch_sim(a):
    # Worker-side compression: map every element to +-1 (zeros land on -1)
    # and rescale by the tensor's RMS so the l2 norm is preserved.
    a_sign = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
    scale = a.norm() / np.sqrt(a.numel())
    a_compressed = scale * a_sign
    a_sign = None
    # Error feedback: the compression residual stays on the worker.
    worker_error = a - a_compressed
    # Average the compressed gradients across all ranks.
    dist.all_reduce(a_compressed)
    a_compressed.mul_(1 / dist.get_world_size())
    # Server-side compression: re-compress the averaged tensor with one
    # scale per rank-owned chunk.
    a_server_sign = a_compressed.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
    a_list = torch.chunk(a_compressed, chunks=dist.get_world_size())
    server_scale = [chunk_a.norm() / np.sqrt(chunk_a.numel()) for chunk_a in a_list]
    a_sign_list = torch.chunk(a_server_sign, dist.get_world_size())
    a_server_compressed = torch.cat(
        [server_scale[i] * a_sign_list[i] for i in range(dist.get_world_size())])
    rank = dist.get_rank()
    # Server error for the chunk this rank is responsible for.
    server_error = a_list[rank] - server_scale[rank] * a_sign_list[rank]
    torch.cuda.synchronize()
    dist.barrier()
    return a_server_compressed, worker_error, server_error
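

# A minimal sketch of the worker-side compression above on a toy tensor;
# the `toy` names are illustrative and not part of the original test.
if rank == 0:
    toy = torch.tensor([0.3, -1.2, 0.7, -0.1], device=device)
    toy_sign = toy.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
    toy_compressed = (toy.norm() / np.sqrt(toy.numel())) * toy_sign
    # With error feedback, residual + compressed reconstructs the input,
    # so the compression loses nothing permanently.
    assert torch.allclose(toy, (toy - toy_compressed) + toy_compressed)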


tensor_size = 100 * 2**20
server_size = tensor_size // size
# The backend packs signs into bytes (8 bits each) and splits the buffer
# evenly across ranks, so the padded size must be a multiple of 8 * size.
if tensor_size % (8 * size) != 0:
    right_tensor_size = tensor_size + (8 * size - (tensor_size % (8 * size)))
else:
    right_tensor_size = tensor_size
right_server_size = right_tensor_size // size
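# For example (hypothetical world size, assuming size == 3):
#   104857600 % 24 == 16, so 8 elements of padding give
#   right_tensor_size = 104857608, which is divisible by both 8 and size.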

# Bias the initialization of the simulated gradient so that no element is
# too close to zero, where sign() becomes numerically unstable.
a = (torch.rand(tensor_size, device=device) - 0.5) + 0.01 * rank

# Error-feedback buffers handed to the backend; they carry the compression
# residuals between calls.
worker_error = torch.zeros(right_tensor_size, device=device)
server_error = torch.zeros(right_server_size, device=device)

# Compute the torch.distributed reference first, then release cached GPU
# memory before running the MPI backend on the same input.
a_torch, worker_error_torch, server_error_torch = torch_sim(a)
torch.cuda.empty_cache()
local_rank = rank % torch.cuda.device_count()

a_after = backend.compressed_allreduce(a, worker_error, server_error, local_rank)
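# The call above is expected to mirror torch_sim: compress `a` with error
# feedback, all-reduce over MPI, and return the server-compressed average,
# updating `worker_error` and `server_error` in place.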

threshold = 1e-6
magnitude_threshold = 1e-6
# Positions where the MPI result exceeds the torch.distributed reference
# by more than the threshold.
diff_mask = (a_after - a_torch) > threshold
diff_server_mask = torch.chunk(diff_mask, size)[rank]
# Error-compensated server values for the chunk owned by this rank.
mpi_server = torch.chunk(a_after, size)[rank] + server_error
torch_server = torch.chunk(a_torch, size)[rank] + server_error_torch

test_correctness = True

# If a value in the compensated server tensor is very small (e.g. 1e-8),
# sign() can flip between the two implementations, so the test skips
# differing positions whose compensated magnitude is below
# magnitude_threshold.
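# For example (hypothetical numbers): a position where the two results
# differ by 2e-6 trips diff_mask, but if its compensated server value is
# only 5e-7, the discrepancy is attributed to sign() noise and skipped.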
if test_correctness:
    if torch.sum(diff_server_mask) == 0:
        print('Successfully passed the test for MPI Backend at Rank {}'.format(rank))
    else:
        # Among the differing positions, only flag those whose compensated
        # magnitude is large enough for sign() to be trusted.
        check_mag_mask = mpi_server[diff_server_mask] > magnitude_threshold
        if torch.sum(check_mag_mask) == 0:
            print('Successfully passed the test for MPI Backend at Rank {}'.format(rank))
        else:
            print('Fails at {} positions'.format(torch.sum(check_mag_mask)))