'''Copyright The Microsoft DeepSpeed Team'''
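# test_mpi_perf.py: micro-benchmark for the MPI-based compressed allreduce
# (deepspeed.runtime.comm.mpi.MpiBackend), the communication primitive behind
# DeepSpeed's 1-bit optimizers. Run with an MPI launcher, one process per
# accelerator; the process count below is illustrative:
#
#   mpirun -np 4 python test_mpi_perf.py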

from mpi4py import MPI
import torch
import deepspeed

from deepspeed.runtime.comm.mpi import MpiBackend

# Configure wall clock timer
from deepspeed.utils.timer import SynchronizedWallClockTimer
from deepspeed.accelerator import get_accelerator

from statistics import mean

timers = SynchronizedWallClockTimer()

comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()

deepspeed.init_distributed(dist_backend=get_accelerator().communication_backend_name())
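
# Note: init_distributed brings up torch.distributed with the accelerator's
# default communication backend; the compressed allreduce benchmarked below
# performs its own communication through mpi4py.
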
# Change cuda_aware to True to test out CUDA-Aware MPI communication
backend = MpiBackend(cuda_aware=False)
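# With cuda_aware=False the backend stages device buffers through host memory;
# a CUDA-aware MPI build can send and receive directly from device memory.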

# Map each MPI rank to a local accelerator (assumes ranks are packed per node)
local_rank = rank % get_accelerator().device_count()
device = torch.device(get_accelerator().device_name(), local_rank)

tensor_size = 300 * 2**20
# Pad the length up to a multiple of 8 * size so the tensor splits evenly
# across ranks into chunks whose length is a multiple of 8 (a whole number
# of bytes' worth of sign bits under 1-bit compression).
if tensor_size % (8 * size) != 0:
    right_tensor_size = tensor_size + (8 * size - (tensor_size % (8 * size)))
else:
    right_tensor_size = tensor_size
right_server_size = right_tensor_size // size
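# For example: tensor_size = 300 * 2**20 = 314,572,800. With size = 4 ranks,
# 8 * size = 32 divides it exactly, so no padding is needed. With size = 7
# ranks, 314,572,800 % 56 = 24, so 32 elements of padding give
# right_tensor_size = 314,572,832 and right_server_size = 44,938,976,
# which is a multiple of 8.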

# Bias the random gradient we communicate by 0.01 * rank so that its elements
# are not vanishingly small (and differ across ranks)
a = (torch.rand(tensor_size, device=device) - 0.5) + 0.01 * rank

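# Persistent error-feedback buffers: compressed_allreduce updates these in
# place across calls to compensate for compression error on the worker and
# server sides respectively.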
worker_error = torch.zeros(right_tensor_size, device=device)
server_error = torch.zeros(right_server_size, device=device)

warmup = 10
iters = 10

# Warmup: absorb one-time setup costs (allocations, connections) before timing
for i in range(warmup):
    backend.compressed_allreduce(a, worker_error, server_error, local_rank)

time_list = []

for i in range(iters):
    timers('compressed_allreduce').start()
    backend.compressed_allreduce(a, worker_error, server_error, local_rank)
    timers('compressed_allreduce').stop()
    time_list.append(timers('compressed_allreduce').elapsed())

timer_names = ['compressed_allreduce']
timers.log(names=timer_names, normalizer=1, memory_breakdown=None)

places = 2  # decimal places for reported latencies
convert = 1e3  # timer reports seconds; convert to milliseconds

if rank == 0:
    for i in range(iters):
        lat = time_list[i]
        print("latency = {} ms".format(round(lat * convert, places)))

    # Summarize on rank 0 only to avoid duplicate output across ranks
    minlat = round(min(time_list) * convert, places)
    maxlat = round(max(time_list) * convert, places)
    meanlat = round(mean(time_list) * convert, places)
    print("min, max, and mean = {} ms, {} ms, {} ms".format(minlat, maxlat, meanlat))