# test_mpi_backend.py
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from mpi4py import MPI
import numpy as np
import torch

import deepspeed
import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.comm.mpi import MpiBackend

# MPI world information for this process.
comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()

# Initialize deepspeed's distributed backend; torch_sim below uses it through
# ``dist`` (deepspeed.comm) to compute the reference result.
deepspeed.init_distributed(dist_backend=get_accelerator().communication_backend_name())

# Change cuda_aware to True to test out CUDA-Aware MPI communication
backend = MpiBackend(cuda_aware=False)

# Pick an accelerator device for this rank by wrapping the global rank around
# the number of visible devices.
local_rank = rank % get_accelerator().device_count()
device = torch.device(get_accelerator().device_name(), local_rank)


def _one_bit_sign(t):
    """Return a float tensor with +1.0 for non-negative entries of ``t`` and -1.0 for negative ones.

    Zeros map to +1.0. ``t`` itself is not modified: the in-place ops act on the
    fresh tensor returned by ``t.sign()``.
    """
    return t.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)


def _rms_scale(t):
    """Return ``t.norm() / sqrt(t.numel())``, the scale multiplied onto the sign tensor."""
    return t.norm() / np.sqrt(t.numel())


# A simulated compression function using deepspeed.comm
def torch_sim(a):
    """Compute a reference 1-bit compressed allreduce of ``a`` with plain torch ops.

    The script compares this result against ``MpiBackend.compressed_allreduce``.

    Worker stage: ``a`` is compressed to ``_rms_scale(a) * _one_bit_sign(a)``.
    The residual is kept as ``worker_error``. The compressed tensors are combined
    across ranks with ``dist.all_reduce`` (default op — presumably SUM) and
    divided by ``world_size``.

    Server stage: the averaged tensor is split with ``torch.chunk`` into
    ``world_size`` chunks, and each chunk is compressed again with its own scale.

    Ends with an accelerator synchronize and a ``dist.barrier()``.

    Args:
        a: This rank's input tensor; chunked along dim 0. ``a`` is not modified.

    Returns:
        Tuple ``(a_server_compressed, worker_error, server_error)``:
        - the concatenation of every server-compressed chunk;
        - this rank's worker-stage residual;
        - the server-stage residual of this rank's own chunk.

    Raises:
        ValueError: if ``torch.chunk`` yields fewer than ``world_size`` chunks.
            It can do so for inputs too short to split into that many pieces.
    """
    world_size = dist.get_world_size()
    rank = dist.get_rank()

    # Worker stage. The sign tensor is a temporary, so it is released right away.
    a_compressed = _rms_scale(a) * _one_bit_sign(a)
    worker_error = a - a_compressed
    dist.all_reduce(a_compressed)
    a_compressed.mul_(1 / world_size)

    # Server stage: one chunk per rank, each compressed with its own scale.
    a_list = torch.chunk(a_compressed, chunks=world_size)
    a_sign_list = torch.chunk(_one_bit_sign(a_compressed), world_size)
    if len(a_list) != world_size:
        # torch.chunk may return fewer chunks than requested; fail clearly
        # instead of with an IndexError further down.
        raise ValueError('torch_sim: input of {} elements cannot be split into {} chunks'.format(
            a_compressed.numel(), world_size))
    server_scale = [_rms_scale(chunk) for chunk in a_list]
    a_server_compressed = torch.cat([scale * signs for scale, signs in zip(server_scale, a_sign_list)])
    server_error = a_list[rank] - server_scale[rank] * a_sign_list[rank]

    get_accelerator().synchronize()
    dist.barrier()
    return a_server_compressed, worker_error, server_error


tensor_size = 100 * 2**20
# Round the buffer length up to the next multiple of 8 * size. The error-feedback
# buffers handed to compressed_allreduce use this padded length (presumably the
# backend packs signs 8 per byte and splits them evenly across ranks — TODO confirm).
pad_unit = 8 * size
right_tensor_size = ((tensor_size + pad_unit - 1) // pad_unit) * pad_unit
right_server_size = right_tensor_size // size

# Adding bias to the initialization of the gradient we are communicating
# In order to get rid of the case where some elements in the gradient are too small
a = (torch.rand(tensor_size, device=device) - 0.5) + 0.01 * rank

worker_error = torch.zeros(right_tensor_size, device=device)
server_error = torch.zeros(right_server_size, device=device)

# Reference result computed with deepspeed.comm.
a_torch, worker_error_torch, server_error_torch = torch_sim(a)
get_accelerator().empty_cache()

# Result under test.
a_after = backend.compressed_allreduce(a, worker_error, server_error, local_rank)

threshold = 1e-6
magnitude_threshold = 1e-6
# Compare absolute differences: a one-sided ``> threshold`` would miss every
# position where the MPI result is below the reference.
diff_mask = (a_after - a_torch).abs() > threshold
diff_server_mask = torch.chunk(diff_mask, size)[rank]
# This rank's chunk of the MPI result plus ``server_error``.
mpi_server = torch.chunk(a_after, size)[rank] + server_error

test_correctness = True

# If a value in mpi_server is very close to zero (e.g. 1e-8), calling sign()
# on it is unreliable. Mismatches at such positions are therefore tolerated.
# Only mismatches whose mpi_server magnitude exceeds magnitude_threshold count
# as failures.
if test_correctness:
    if torch.sum(diff_server_mask) == 0:
        print('Successfully passed the test for MPI Backend at Rank {}'.format(rank))
    else:
        # Use the magnitude so large negative values are not mistaken for "too small".
        check_mag_mask = mpi_server[diff_server_mask].abs() > magnitude_threshold
        num_failures = int(torch.sum(check_mag_mask))
        if num_failures == 0:
            print('Successfully passed the test for MPI Backend at Rank {}'.format(rank))
        else:
            print('Fails at {} of positions'.format(num_failures))