Commit e2af089c authored by Tim Moon

Update dist Adam test to use updated API

parent 6e412916
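
"""Functional test for the distributed Adam optimizer.

Compares a model trained with
apex.contrib.optimizers.distributed_fused_adam.DistributedFusedAdam
against a reference model trained with torch.optim.Adam wrapped in
DistributedDataParallel, checking outputs, input gradients, and
parameters after every step.

Intended to be launched with one process per GPU (e.g. via torchrun or
python -m torch.distributed.launch), since NCCL is initialized with
init_method='env://' and the local rank is read from LOCAL_RANK.
"""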
import argparse
import os
import random

import torch

from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam


class TestModel(torch.nn.Module):
    def __init__(self, args):
        super(TestModel, self).__init__()
        self.linear = torch.nn.Sequential(*[
            torch.nn.Linear(args.dim, args.dim)
            for _ in range(args.layers)
        ])
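
    # Note: forward applies each layer to the same input and sums the
    # outputs, so every parameter receives a gradient at every step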
    def forward(self, x):
        y = 0
        for l in self.linear:
            y += l(x)
        return y


def setup(args):
    # Construct models with same parameters
    ref_model = TestModel(args).float().cuda()
    dist_model = TestModel(args).float().cuda()
    for ref_param, dist_param in zip(dist_model.parameters(),
                                     ref_model.parameters()):
        dist_param.data.copy_(ref_param.data)
    ref_model = torch.nn.parallel.DistributedDataParallel(
        ref_model,
        device_ids=[args.rank],
        output_device=args.rank,
    )

    # Construct optimizers with same hyperparameters
    optim_args = { 'lr': 1e-3, 'eps': 1e-6, 'weight_decay': 0.01 }
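    # Parameters are split into two groups with different learning rates,
    # presumably to exercise per-group hyperparameters in both optimizers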
    ref_optim = torch.optim.Adam(
        [
            {'params': list(ref_model.parameters())[1::2], 'lr': 5e-3},
            {'params': list(ref_model.parameters())[0::2]},
        ],
        **optim_args,
    )
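    # A deliberately tiny bucket_cap_mb, presumably so that the parameters
    # are split across many gradient buckets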
    dist_optim = DistributedFusedAdam(
        [
            {'params': list(dist_model.parameters())[1::2], 'lr': 5e-3},
            {'params': list(dist_model.parameters())[0::2]},
        ],
        bucket_cap_mb=71/(4*1024*1024),
        **optim_args,
    )

    return ref_model, ref_optim, dist_model, dist_optim


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int, default=-1)
    parser.add_argument('--steps', type=int, default=11)
    parser.add_argument('--batch', type=int, default=5)
    parser.add_argument('--dim', type=int, default=7)
    parser.add_argument('--layers', type=int, default=11)
    parser.add_argument('--atol', type=float, default=1e-3)
    parser.add_argument('--rtol', type=float, default=1e-3)
    args = parser.parse_args()
    return args


def setup_env(args):
    # Initialize NCCL
    local_rank = args.local_rank
    if local_rank < 0:
        local_rank = int(os.getenv('LOCAL_RANK', 0))
    torch.cuda.set_device(local_rank % torch.cuda.device_count())
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    args.rank = torch.distributed.get_rank()
    args.world_size = torch.distributed.get_world_size()

    # Initialize RNG with a different seed on each rank
    seed = 42 + args.rank
    random.seed(seed)
    torch.manual_seed(seed)

    return args


def main():
    args = parse_args()
    args = setup_env(args)

    def assert_allclose(ref_x, dist_x, message):
        message = (
            f'Rank {args.rank}: {message}\n'
            f'Reference Adam: {ref_x}\n'
            f'Distributed Adam: {dist_x}\n'
            f'Relative error: {torch.abs((ref_x-dist_x)/ref_x)}\n'
        )
        assert torch.allclose(ref_x, dist_x, atol=args.atol, rtol=args.rtol), message

    # Train model with data-parallelism and ZeRO
    ref_model, ref_optim, dist_model, dist_optim = setup(args)
    for step in range(args.steps):

        # Synthetic data
        x = torch.randn(args.batch, args.dim).cuda()
        dy = torch.randn_like(x).cuda()

        # Reference implementation
        ref_optim.zero_grad()
        x_ref = x.detach().clone().requires_grad_(True)
        y_ref = ref_model(x_ref)
        y_ref.backward(dy)
        ref_optim.step()

        # Distributed implementation
        dist_optim.zero_grad()
        x_dist = x.detach().clone().requires_grad_(True)
        y_dist = dist_model(x_dist)
        y_dist.backward(dy)
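        # Note: no explicit complete_reductions() call as in the old API;
        # the updated DistributedFusedAdam is expected to handle gradient
        # reductions as part of backward and step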
        dist_optim.step()

        # Check values
        assert_allclose(
            y_ref,
            y_dist,
            f'inconsistent output in step {step}',
        )
        assert_allclose(
            x_ref.grad,
            x_dist.grad,
            f'inconsistent input grad in step {step}',
        )
        for i, (ref_param, dist_param) in enumerate(zip(ref_model.parameters(),
                                                        dist_model.parameters())):
            assert_allclose(
                ref_param,
                dist_param,
                f'inconsistent param {i} in step {step}',
            )


if __name__ == "__main__":
    main()