ddp_race_condition_test.py
import argparse

import torch
import torch.distributed as dist
from torch.nn import Module, Parameter

from apex.parallel import DistributedDataParallel as DDP


parser = argparse.ArgumentParser(description='allreduce hook example')
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
                    help='url used to set up distributed training')
parser.add_argument('--dist-backend', default='nccl', type=str,
                    help='distributed backend')
parser.add_argument('--world-size', default=1, type=int,
                    help='Number of GPUs to use. Can either be manually set ' +
                    'or automatically set by using \'python -m multiproc\'.')
parser.add_argument('--rank', default=0, type=int,
                    help='Used for multi-process training. Can either be manually set ' +
                    'or automatically set by using \'python -m multiproc\'.')

args = parser.parse_args()

args.distributed = args.world_size > 1

rank = 0  # default rank for the single-process (world_size == 1) case
if args.distributed:
    torch.cuda.set_device(args.rank % torch.cuda.device_count())
    dist.init_process_group(args.dist_backend, init_method=args.dist_url,
                            world_size=args.world_size, rank=args.rank)
    rank = dist.get_rank()
torch.set_printoptions(precision=10)

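# The model holds a single large parameter (1 x 4096*4096 floats), so its
# gradient allreduce is large enough for the overlap between the backward
# pass and apex DDP's communication to show up clearly in a profile.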
class Model(Module):
    def __init__(self):
        super(Model, self).__init__()
        self.x = Parameter(torch.cuda.FloatTensor(1,4096*4096).fill_(1.0))
    def forward(self, input):
        return self.x*input
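# message_size=1 makes apex DDP launch an allreduce as soon as any gradient
# elements are ready, instead of accumulating a large bucket first, so
# communication overlaps with backward as aggressively as possible.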
model = DDP(Model(), message_size=1)

z = torch.cuda.FloatTensor(4096*4096)

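# Each rank fills z with (i + rank), so with gradient averaging across ranks
# the expected per-element gradient is i + (world_size - 1) / 2. The check
# below hard-codes world_size == 2: expected sum = 4096*4096 * (2*i + 1) / 2.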
for i in range(10):
    z.fill_(i + rank) # fill z with new values every iteration for sanity
    model.zero_grad()
    out = model(z)
    loss = out.sum()
    torch.cuda.nvtx.range_push("backward")
    loss.backward()
    torch.cuda.nvtx.range_pop()
    
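    # Wait for all outstanding CUDA work, including DDP's allreduce stream,
    # before reading the gradients. If the overlapped allreduce raced with
    # backward, the sum printed below would not match the expected value.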
    torch.cuda.nvtx.range_push("synchronize() + sum")
    torch.cuda.synchronize()
    for param in model.parameters():
        print("i = {},\n"
              "param.grad.data_ptr() = {}\n"
              "expected {},\n" 
              "     got {}\n"
              .format(i,
                      param.grad.data_ptr(),
                      4096*4096*(2.*i+1)/2.,
                      param.grad.data.sum().item()))
    torch.cuda.nvtx.range_pop()