# test_server_error.py
from mpi4py import MPI
import time
import torch
import torch.distributed as dist
import numpy as np
import deepspeed
from deepspeed.runtime.fp16.onebit_adam import OnebitAdam

comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()

# NCCL process group; the rendezvous address (worker-0:2245) is hard-coded
# and may need to be adjusted for a different cluster.
torch.distributed.init_process_group(backend='nccl',
                                     init_method='tcp://worker-0:2245',
                                     world_size=size,
                                     rank=rank)

# A dummy model/optimizer pair; only the optimizer's Compressed_Allreduce
# method is exercised by this test.
dummy_model = [torch.nn.Parameter(torch.ones(10))]
dummy_optim = OnebitAdam(dummy_model, cuda_aware=False)

device = torch.device('cuda', rank % torch.cuda.device_count())


def torch_sim(a):
    """Simulate the two-stage 1-bit compressed allreduce with plain
    torch.distributed collectives: worker-side compression, server-side
    averaging and re-compression, and the two error-feedback terms."""
    # Worker-side 1-bit compression: strict +/-1 signs scaled to preserve
    # the l2 norm of the tensor.
    a_sign = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
    scale = a.norm() / np.sqrt(a.numel())
    a_compressed = scale * a_sign
    a_sign = None
    # Worker error feedback: what the compression lost.
    worker_error = a - a_compressed
    # Server-side averaging of the compressed tensors.
    dist.all_reduce(a_compressed)
    a_compressed.mul_(1 / dist.get_world_size())
    # Server-side re-compression, one chunk per rank.
    a_server_sign = a_compressed.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
    a_list = torch.chunk(a_compressed, chunks=dist.get_world_size())
    server_scale = [chunk_a.norm() / np.sqrt(chunk_a.numel()) for chunk_a in a_list]
    a_sign_list = torch.chunk(a_server_sign, dist.get_world_size())
    a_server_compressed = torch.cat(
        [server_scale[i] * a_sign_list[i] for i in range(dist.get_world_size())])
    # Server error feedback for the chunk owned by this rank.
    rank = dist.get_rank()
    server_error = a_list[rank] - server_scale[rank] * a_sign_list[rank]
    torch.cuda.synchronize()
    torch.distributed.barrier()
    return a_server_compressed, worker_error, server_error
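

# A minimal single-process sketch (no collectives) of the worker-side step
# that torch_sim simulates: 1-bit sign compression with an l2-norm-preserving
# scale, plus the error-feedback residual. The helper name below is
# illustrative only and is not used by the test.
def _local_onebit_sketch(x):
    # Map sign(x) in {-1, 0, 1} to strict {-1, +1} (zeros become +1).
    x_sign = x.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
    # One scalar scale keeps the l2 norm: ||scale * x_sign|| == ||x||.
    scale = x.norm() / np.sqrt(x.numel())
    x_compressed = scale * x_sign
    # Residual that error feedback would add back on the next iteration.
    error = x - x_compressed
    return x_compressed, error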


# Input tensor size (number of elements)
tensor_size = 100 * 2**20

server_size = int(tensor_size / size)
# Round the size up to a multiple of 8 * size so the tensor splits into
# equally sized, byte-aligned per-rank chunks for the 1-bit packing.
if tensor_size % (8 * size) != 0:
    right_tensor_size = tensor_size + (8 * size - (tensor_size % (8 * size)))
else:
    right_tensor_size = tensor_size

right_server_size = right_tensor_size // size

# torch.rand samples lie in [0, 1); the -0.5 shift centers them so both
# signs occur and is required for avoiding sign flips/errors.
a = torch.rand(tensor_size, device=device) - 0.5

# Error-feedback buffers, sized for the padded tensor and per-rank chunk.
worker_error = torch.zeros(right_tensor_size, device=device)
server_error = torch.zeros(right_server_size, device=device)
# Reference result from the torch.distributed simulation above.
a_torch, worker_error_torch, server_error_torch = torch_sim(a)
torch.cuda.empty_cache()
local_rank = rank % torch.cuda.device_count()

# Run the compressed allreduce of the 1-bit Adam optimizer; the call fills
# in the worker/server error-feedback buffers.
a_after = dummy_optim.Compressed_Allreduce(a,
                                           worker_error,
                                           server_error,
                                           rank,
                                           size,
                                           comm,
                                           local_rank)

# The compressed allreduce is acceptable for training if its relative error
# stays below this threshold.
threshold = 1e-6

diff_pos = ((a_after - a_torch) > threshold)

# On rank 0, add the server error back onto this rank's chunk of the result,
# recovering the averaged chunk before server-side compression; it must match
# the simulated one up to the relative threshold.
if rank == 0:
    before_diff = torch.chunk(a_after - a_torch,
                              size)[rank] + server_error - server_error_torch
    if torch.norm(before_diff) / torch.norm(torch.chunk(a_after,
                                                        size)[rank]) < threshold:
        print('Successfully passed the test')
    else:
        print('The difference for the tensor before allgather is {}'.format(
            torch.norm(before_diff)))
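
# A possible launch sketch (assumptions: two MPI ranks, one GPU each, and a
# reachable host named worker-0 as in init_method above; adjust the host
# file and rank count to your cluster):
#   mpirun -np 2 -hostfile /path/to/hostfile python test_server_error.py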