test_dist_adam.py 5.78 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import argparse
import random
import sys

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

from apex import amp
from apex.optimizers import FusedAdam
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam


class TestModel(torch.nn.Module):
    """Stack of ``args.layers`` square ``Linear(args.dim, args.dim)`` layers.

    ``args.bias`` controls whether each layer has a bias term. The layers are
    held in a ``torch.nn.Sequential`` stored as ``self.linear``.
    """

    def __init__(self, args):
        super().__init__()

        stack = []
        for _ in range(args.layers):
            stack.append(torch.nn.Linear(args.dim, args.dim, bias=args.bias))
        self.linear = torch.nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through every linear layer in order."""
        out = self.linear(x)
        return out

def setup(args):
    """Build the reference and distributed models/optimizers to compare.

    Reference path: a ``TestModel`` optimized by ``FusedAdam``, wrapped by
    apex AMP (opt_level O2, dynamic loss scale) and then by PyTorch DDP.
    Distributed path: a ``.half()`` copy of the same initial weights optimized
    by ``DistributedFusedAdam`` with the same Adam hyperparameters.
    Parameters of both models are broadcast from rank 0 so every rank starts
    from identical weights.

    Args:
        args: Namespace providing ``dim``, ``layers``, ``bias``, ``n_gpu`` and
            ``dwu_group_size``.

    Returns:
        Tuple ``(ref_model, ref_opt, dist_model, dist_opt)``.
    """
    ## Model
    ref_model = TestModel(args).cuda()
    dist_model = TestModel(args).cuda()

    # Same weights
    with torch.no_grad():
        for dp, rp in zip(dist_model.parameters(), ref_model.parameters()):
            dp.data.copy_(rp.data)

    dist_model = dist_model.half()

    ## Optimizer
    # Same Adam hyperparameters for both optimizers.
    ref_opt_args = {'lr': 1e-3, 'eps': 1e-6, 'weight_decay': 0.01}
    ref_opt = FusedAdam(ref_model.parameters(), **ref_opt_args)

    dist_opt_args = {
        **ref_opt_args,
        'overlap_reductions': False,
        'process_group_size': args.n_gpu,
        'dwu_group_size': args.dwu_group_size,
        'dwu_num_blocks': 1,
        'dwu_num_chunks': 1,
    }
    dist_opt = DistributedFusedAdam(dist_model.parameters(), **dist_opt_args)
    dist_opt.set_global_scale(1.)

    ## amp-init
    amp_args = {'loss_scale': 'dynamic', 'opt_level': 'O2'}
    ref_model, ref_opt = amp.initialize(ref_model, ref_opt, **amp_args)

    ## DDP
    # Pin DDP to the GPU this process selected via torch.cuda.set_device().
    # args.rank is the *global* rank, which is not a valid local CUDA device
    # index once the job spans more than one node.
    ref_model = DDP(ref_model, device_ids=[torch.cuda.current_device()])
    with torch.no_grad():
        for dp in dist_model.parameters():
            torch.distributed.broadcast(dp.data, src=0)
        for rp in ref_model.parameters():
            torch.distributed.broadcast(rp.data, src=0)
    torch.cuda.synchronize()
    torch.distributed.barrier()
    if get_rank() == 0:
        print(f'dist opt with {args.n_gpu} GPUs')

    return ref_model, ref_opt, dist_model, dist_opt

def parse_args(argv=None):
    """Parse the command-line options of the distributed Adam test.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]``, as before.

    Returns:
        ``argparse.Namespace`` with the parsed options.
    """
    import os

    parser = argparse.ArgumentParser()

    # torchrun exports LOCAL_RANK instead of passing --local_rank; fall back
    # to it so each process still selects its own GPU in setup_env().
    parser.add_argument('--local_rank', type=int,
                        default=int(os.environ.get('LOCAL_RANK', -1)))
    parser.add_argument('--steps', type=int, default=20)
    parser.add_argument('--batch', type=int, default=32)
    parser.add_argument('--dim', type=int, default=4)
    parser.add_argument('--layers', type=int, default=2)
    parser.add_argument('--bias', action='store_true')
    parser.add_argument('--atol', type=float, default=1e-3)
    parser.add_argument('--rtol', type=float, default=1)
    # A group size is a count of processes; parse it as int so an explicit
    # "--dwu_group_size 2" does not become the float 2.0.
    parser.add_argument('--dwu_group_size', type=int, default=1)

    return parser.parse_args(argv)

def setup_env(args):
    """Select the GPU, join the NCCL process group and seed the RNGs.

    Sets ``args.local_rank`` as the current CUDA device, initialises the
    default process group from ``env://``, stores the rank and world size on
    ``args`` (as ``rank`` / ``n_gpu``) and seeds ``random`` and ``torch`` with
    ``42 + rank``.

    Returns:
        The same ``args`` namespace, updated in place.
    """
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')

    args.rank, args.n_gpu = (
        torch.distributed.get_rank(),
        torch.distributed.get_world_size(),
    )

    # Offset by rank so each process draws different random data.
    seed = 42 + args.rank
    for seed_fn in (random.seed, torch.manual_seed):
        seed_fn(seed)

    return args

def get_rank():
    """Return this process's rank in the default process group."""
    dist = torch.distributed
    rank = dist.get_rank()
    return rank

def main():
    """Compare FusedAdam (AMP + DDP) against DistributedFusedAdam step by step.

    Each step feeds the same random half-precision input to both models and
    asserts that inputs, outputs and input gradients agree within
    ``--atol``/``--rtol``. It then steps both optimizers and compares every
    parameter. On the first parameter mismatch every rank exits with a
    non-zero status, so the failure is visible to the launcher and no rank is
    left waiting in a collective.
    """
    args = parse_args()
    args = setup_env(args)
    tol_args = {'atol': args.atol, 'rtol': args.rtol}

    torch.set_printoptions(precision=16)

    ref_model, ref_opt, dist_model, dist_opt = setup(args)

    # lazy_init not called yet, initialize stash
    stash = ref_opt._amp_stash
    stash.all_fp16_params, stash.all_fp32_from_fp16_params = [], []

    # make sure everything from _first_step_init_ is ready before training
    # e.g. registering allreduce_hook
    # so that gradients are copied/reduced when necessary
    dist_opt._init_everything()

    for step in range(args.steps):
        x_ref = torch.randn(args.batch, args.dim, dtype=torch.half).cuda().requires_grad_(True)
        x_dist = x_ref.clone().detach().requires_grad_(True)

        if get_rank() == 0:
            print(f'[{step}] Checking input')
        assert torch.allclose(x_ref, x_dist, **tol_args)

        y_ref = ref_model(x_ref).half()
        y_dist = dist_model(x_dist)

        if get_rank() == 0:
            print(f'[{step}] Checking output')
        assert torch.allclose(y_ref, y_dist, **tol_args)

        dy = torch.randn_like(y_ref)

        y_ref.backward(dy)
        y_dist.backward(dy)

        if get_rank() == 0:
            print(f'[{step}] Checking gradients')
        torch.distributed.barrier()
        torch.cuda.synchronize()
        assert torch.allclose(x_ref.grad, x_dist.grad, **tol_args)

        # gradient all-reduce within distributed optimizer
        dist_opt.complete_reductions()

        if get_rank() == 0:
            print(f'[{step}] Stepping')
        ref_opt.step()
        dist_opt.step()

        torch.cuda.synchronize()
        torch.distributed.barrier()
        if get_rank() == 0:
            print('Checking new weights')
            print("ref param:", ref_model.module.linear[0].weight)
            print("dist param:", dist_model.linear[0].weight)

        # Use a distinct index name so the outer step counter is not rebound.
        mismatch = False
        params = zip(ref_model.parameters(), dist_model.parameters())
        for p_idx, (rp, dp) in enumerate(params):
            if torch.allclose(rp, dp, **tol_args):
                continue
            mismatch = True
            if get_rank() == 0:
                print(f'Rank: {get_rank()}, Param: {p_idx}')
                print(f'ref: {rp.sum().item()}, dist: {dp.sum().item()}')
                print(rp)
                print(dp)

                print(torch.abs(rp - dp) > tol_args['atol'])
            break

        # Agree on the verdict across all ranks before exiting. If only some
        # ranks left, the others would block in the next barrier. Exit non-zero
        # so a mismatch is reported as a failure, not a success.
        failed = torch.tensor([int(mismatch)], dtype=torch.int32, device='cuda')
        torch.distributed.all_reduce(failed, op=torch.distributed.ReduceOp.MAX)
        if failed.item():
            sys.exit(1)

        # zero grads
        for rp, dp in zip(ref_model.parameters(), dist_model.parameters()):
            rp.grad = None
            dp.grad = None


# Script entry point. setup_env() initialises the process group with
# init_method='env://', so launch through a tool that sets the rendezvous
# environment variables (e.g. torchrun) — TODO confirm the intended launcher.
if __name__ == "__main__":
    main()