"""Race-condition test for ``apex.parallel.DistributedDataParallel``.

Originally ``ddp_race_condition_test.py``. Each rank fills its input with a
rank-dependent value, runs backward through a DDP-wrapped model and compares
the resulting gradient sums against analytically expected values.

Expects one process per GPU started by a launcher that exports ``WORLD_SIZE``
and the ``env://`` rendezvous variables (e.g. ``torch.distributed.launch``).
"""
import argparse
import os

import torch
import torch.distributed as dist
from torch.nn import Module
from torch.nn import Parameter
from apex.parallel import DistributedDataParallel as DDP


parser = argparse.ArgumentParser(description='allreduce hook example')
# torch.distributed.launch passes --local_rank explicitly; env-based launchers
# export LOCAL_RANK instead. The explicit flag still takes precedence.
parser.add_argument("--local_rank",
                    default=int(os.environ.get("LOCAL_RANK", 0)),
                    type=int)
args = parser.parse_args()

# Only run distributed when the launcher reports more than one process.
args.distributed = False
if 'WORLD_SIZE' in os.environ:
    args.distributed = int(os.environ['WORLD_SIZE']) > 1

# Single-process defaults, so later code can read these unconditionally.
args.gpu = 0
args.world_size = 1

if args.distributed:
    # One process per GPU: map the local rank onto a visible device.
    args.gpu = args.local_rank % torch.cuda.device_count()
    torch.cuda.set_device(args.gpu)
    dist.init_process_group(backend='nccl', init_method='env://')
    args.world_size = dist.get_world_size()

torch.set_printoptions(precision=10)
# Seed the RNG per local rank.
torch.manual_seed(args.local_rank)

class Model(Module):
    """Toy model computing ``(input * a) * b`` elementwise.

    ``a`` is initialized to 1.0 and ``b`` to 2.0. For ``loss = out.sum()``,
    the gradient w.r.t. ``a`` is ``input * b`` (i.e. ``2 * input``) and the
    gradient w.r.t. ``b`` is ``input * a`` (i.e. ``1 * input``).

    Args:
        numel: number of elements in each 1-D parameter
            (default ``4096 * 4096``).
        device: device the parameters are allocated on (default ``"cuda"``,
            i.e. the current CUDA device).
    """

    def __init__(self, numel=4096 * 4096, device="cuda"):
        super(Model, self).__init__()
        self.a = Parameter(torch.full((numel,), 1.0, dtype=torch.float32,
                                      device=device))
        self.b = Parameter(torch.full((numel,), 2.0, dtype=torch.float32,
                                      device=device))

    def forward(self, input):
        """Return ``(input * a) * b`` (broadcasting per torch rules)."""
        return (input * self.a) * self.b

# Number of elements in the input and in each parameter of Model().
NUMEL = 4096 * 4096

model = Model()
# Alternative DDP configurations exercising other allreduce paths:
# model = DDP(model, message_size=1, gradient_average_split_factor=2.0)
# model = DDP(model, delay_allreduce=True)
# [model.b] is evaluated before `model` is rebound, so it names the
# unwrapped Model's parameter.
model = DDP(model, message_size=1, allreduce_trigger_params=[model.b])

x = torch.empty(NUMEL, dtype=torch.float32, device="cuda")

world_size = args.world_size if args.distributed else 1


def info(name, param, val, i, world_size=world_size):
    """Print and check the gradient sum of ``param`` for iteration ``i``.

    Rank ``r`` fills ``x`` with ``i + r``, so its local per-element gradient
    is ``val * (i + r)``, where ``val`` is the other parameter's value.
    Averaged over ranks ``0..world_size-1`` this becomes
    ``val * (i + (world_size - 1) / 2)``. With world_size == 2 this reduces
    to the original ``(2i + 1) / 2`` factor.

    Returns:
        True if the observed gradient sum equals the expected one exactly.
    """
    expected = val * NUMEL * (2. * i + world_size - 1) / 2.
    actual = param.grad.data.sum().item()
    print(name+": grad.data_ptr() = {}, expected sum {}, got {}".format(
          param.grad.data_ptr(), expected, actual))
    # Exact comparison on purpose: for small world sizes the expected sums are
    # exactly representable, and a tolerance could hide a race that corrupts
    # only a few elements. NOTE(review): may need revisiting for world sizes
    # whose average is not exact in float32.
    return expected == actual


passed = True
torch.cuda.cudart().cudaProfilerStart()
try:
    for i in range(10):
        # New rank-dependent values every iteration, so stale gradients show.
        x.fill_(i + args.local_rank)
        model.zero_grad()
        out = model(x)
        loss = out.sum()
        torch.cuda.nvtx.range_push("backward")
        loss.backward()
        torch.cuda.nvtx.range_pop()

        print("i = {}".format(i))
        # Evaluate both checks every iteration so both always print.
        if not info("model.a", model.module.a, 2., i):
            passed = False
        if not info("model.b", model.module.b, 1., i):
            passed = False
finally:
    # Stop profiling even if an iteration raises.
    torch.cuda.cudart().cudaProfilerStop()

print("passed = ", passed)