import argparse
import os
import random

import torch
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam
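# Run with a distributed launcher that provides the env:// rendezvous variables,
# e.g.: torchrun --nproc_per_node=<num_gpus> test_dist_adam.py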

class TestModel(torch.nn.Module):
    def __init__(self, args):
        super(TestModel, self).__init__()
        self.linear = torch.nn.Sequential(*[
            torch.nn.Linear(args.dim, args.dim)
            for _ in range(args.layers)
        ])

    def forward(self, x):
        # Sum the outputs of all the linear layers
        y = 0
        for layer in self.linear:
            y += layer(x)
        return y

def setup(args):

    # Construct models with same parameters
    ref_model = TestModel(args).float().cuda()
    dist_model = TestModel(args).float().cuda()
    with torch.no_grad():
        for ref_param, dist_param in zip(ref_model.parameters(),
                                         dist_model.parameters()):
            dist_param.data.copy_(ref_param.data)
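    # Wrap the reference model in DDP so its gradients are averaged across ranks,
    # mirroring the gradient reduction that DistributedFusedAdam performs internally.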
    ref_model = torch.nn.parallel.DistributedDataParallel(
        ref_model,
        device_ids=[args.rank],
        output_device=args.rank,
    )

    # Construct optimizers with same hyperparameters
    optim_args = { 'lr': 1e-3, 'eps': 1e-6, 'weight_decay': 0.01 }
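    # Split the parameters into two groups, one with an overridden learning rate,
    # so that per-group hyperparameters are exercised as well.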
    ref_optim = torch.optim.Adam(
        [
            {'params': list(ref_model.parameters())[1::2], 'lr': 5e-3},
            {'params': list(ref_model.parameters())[0::2]},
        ],
        **optim_args,
    )
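    # Note: bucket_cap_mb below is tiny (roughly 18 bytes), presumably to spread
    # the parameters across many communication buckets.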
    dist_optim = DistributedFusedAdam(
        [
            {'params': list(dist_model.parameters())[1::2], 'lr': 5e-3},
            {'params': list(dist_model.parameters())[0::2]},
        ],
        bucket_cap_mb=71/(4*1024*1024),
        **optim_args,
    )

    return ref_model, ref_optim, dist_model, dist_optim

def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument('--local_rank', type=int, default=-1)
    parser.add_argument('--steps', type=int, default=11)
    parser.add_argument('--batch', type=int, default=5)
    parser.add_argument('--dim', type=int, default=7)
    parser.add_argument('--layers', type=int, default=11)
    parser.add_argument('--atol', type=float, default=1e-3)
    parser.add_argument('--rtol', type=float, default=1e-3)

    args = parser.parse_args()

    return args

def setup_env(args):

    # Initialize NCCL
    local_rank = args.local_rank
    if local_rank < 0:
        local_rank = int(os.getenv('LOCAL_RANK', 0))
    torch.cuda.set_device(local_rank % torch.cuda.device_count())
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    args.rank = torch.distributed.get_rank()
    args.world_size = torch.distributed.get_world_size()

    # Initialize RNG
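    # Use a different seed per rank so each process generates different synthetic data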
    seed = 42 + args.rank
    random.seed(seed)
    torch.manual_seed(seed)

    return args

def main():
    args = parse_args()
    args = setup_env(args)
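    # Print tensors at full precision so mismatches reported below are easier to inspect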
    torch.set_printoptions(precision=16)

    def assert_allclose(ref_x, dist_x, message):
        message = (
            f'Rank {args.rank}: {message}\n'
            f'Reference Adam: {ref_x}\n'
            f'Distributed Adam: {dist_x}\n'
            f'Relative error: {torch.abs((ref_x-dist_x)/ref_x)}\n'
        )
        assert torch.allclose(ref_x, dist_x, atol=args.atol, rtol=args.rtol), message

    # Train model with data-parallelism and ZeRO
    ref_model, ref_optim, dist_model, dist_optim = setup(args)
    for step in range(args.steps):

        # Synthetic data
        x = torch.randn(args.batch, args.dim).cuda()
        dy = torch.randn_like(x).cuda()

        # Reference implementation
        ref_optim.zero_grad()
        x_ref = x.detach().clone().requires_grad_(True)
        y_ref = ref_model(x_ref)
        y_ref.backward(dy)
        ref_optim.step()

        # Distributed implementation
        dist_optim.zero_grad()
        x_dist = x.detach().clone().requires_grad_(True)
        y_dist = dist_model(x_dist)
        y_dist.backward(dy)
        dist_optim.step()

        # Check values
        torch.cuda.synchronize()
        torch.distributed.barrier()
        assert_allclose(
            y_ref,
            y_dist,
            f'inconsistent output in step {step}',
        )
        assert_allclose(
            x_ref.grad,
            x_dist.grad,
            f'inconsistent input grad in step {step}',
        )
        for i, (ref_param, dist_param) in enumerate(zip(ref_model.parameters(),
                                                        dist_model.parameters())):
            assert_allclose(
                ref_param,
                dist_param,
                f'inconsistent param {i} in step {step}',
            )

if __name__ == "__main__":
    main()