"...text-generation-inference.git" did not exist on "e943a294bca239e26828732dd6ab5b6f95dadd0a"
Commit e2af089c authored by Tim Moon

Update dist Adam test to use updated API

parent 6e412916

 import argparse
+import os
 import random
-import sys
 import torch
-from torch.nn.parallel import DistributedDataParallel as DDP
-from apex import amp
-from apex.optimizers import FusedAdam
 from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam

 class TestModel(torch.nn.Module):
     def __init__(self, args):
         super(TestModel, self).__init__()
-        self.linear = torch.nn.Sequential(*[torch.nn.Linear(args.dim, args.dim, bias=args.bias) for _ in range(args.layers)])
+        self.linear = torch.nn.Sequential(*[
+            torch.nn.Linear(args.dim, args.dim)
+            for _ in range(args.layers)
+        ])

     def forward(self, x):
-        return self.linear(x)
+        y = 0
+        for l in self.linear:
+            y += l(x)
+        return y

 def setup(args):
-    ## Model
-    ref_model = TestModel(args).cuda()
-    dist_model = TestModel(args).cuda()
-    # Same weights
+    # Construct models with same parameters
+    ref_model = TestModel(args).float().cuda()
+    dist_model = TestModel(args).float().cuda()
     with torch.no_grad():
-        for dp, rp in zip(dist_model.parameters(), ref_model.parameters()):
-            dp.data.copy_(rp.data)
-    dist_model = dist_model.half()
-    ## Optimizer
-    # same hyperparameters
-    ref_opt_args = { 'lr': 1e-3, 'eps': 1e-6, 'weight_decay': 0.01 }
-    ref_opt = FusedAdam(ref_model.parameters(), **ref_opt_args)
-    dist_opt_args = ref_opt_args.copy()
-    dist_opt_args.update( {'overlap_reductions' : False} )
-    dist_opt_args.update( {'process_group_size' : args.n_gpu} )
-    dist_opt_args.update( {'dwu_group_size' : args.dwu_group_size} )
-    dist_opt_args.update( {'dwu_num_blocks' : 1} )
-    dist_opt_args.update( {'dwu_num_chunks' : 1} )
-    dist_opt = DistributedFusedAdam(dist_model.parameters(), **dist_opt_args)
-    dist_opt.set_global_scale(1.)
-    ## amp-init
-    amp_args = { 'loss_scale' : 'dynamic' , 'opt_level' : 'O2'}
-    ref_model, ref_opt = amp.initialize(ref_model, ref_opt, **amp_args)
-    ## DDP
-    ref_model = DDP(ref_model, device_ids=[args.rank])
-    with torch.no_grad():
-        for dp in dist_model.parameters():
-            torch.distributed.broadcast(dp.data, src=0)
-        for rp in ref_model.parameters():
-            torch.distributed.broadcast(rp.data, src=0)
-    torch.cuda.synchronize()
-    torch.distributed.barrier()
-    if get_rank() == 0:
-        print(f'dist opt with {args.n_gpu} GPUs')
-    return ref_model, ref_opt, dist_model, dist_opt
+        for ref_param, dist_param in zip(dist_model.parameters(),
+                                         ref_model.parameters()):
+            dist_param.data.copy_(ref_param.data)
+    ref_model = torch.nn.parallel.DistributedDataParallel(
+        ref_model,
+        device_ids=[args.rank],
+        output_device=args.rank,
+    )
+    # Construct optimizers with same hyperparameters
+    optim_args = { 'lr': 1e-3, 'eps': 1e-6, 'weight_decay': 0.01 }
+    ref_optim = torch.optim.Adam(
+        [
+            {'params': list(ref_model.parameters())[1::2], 'lr': 5e-3},
+            {'params': list(ref_model.parameters())[0::2]},
+        ],
+        **optim_args,
+    )
+    dist_optim = DistributedFusedAdam(
+        [
+            {'params': list(dist_model.parameters())[1::2], 'lr': 5e-3},
+            {'params': list(dist_model.parameters())[0::2]},
+        ],
+        bucket_cap_mb=71/(4*1024*1024),
+        **optim_args,
+    )
+    return ref_model, ref_optim, dist_model, dist_optim

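For context on the API change in setup(): the updated DistributedFusedAdam is constructed like a standard torch.optim optimizer, taking parameters or parameter groups plus Adam hyperparameters, with none of the old dwu_* options or the set_global_scale call. A minimal usage sketch, assuming apex is installed, one process per GPU, and an already-initialized NCCL process group; the model and tensor shapes are illustrative, not part of this test:

import torch
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam

# Illustrative per-rank model (cross-rank parameter synchronization is omitted here).
model = torch.nn.Linear(16, 16).cuda()

# Constructed like torch.optim.Adam: parameters (or param groups) plus hyperparameters.
optimizer = DistributedFusedAdam(model.parameters(), lr=1e-3, eps=1e-6, weight_decay=0.01)

# One training step reads like ordinary single-GPU code; gradient reduction and the
# distributed (ZeRO-style) optimizer state are handled inside the optimizer, which is
# what main() below exercises with plain zero_grad/backward/step calls.
x = torch.randn(8, 16).cuda()
optimizer.zero_grad()
model(x).sum().backward()
optimizer.step()
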
 def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--local_rank', type=int, default=-1)
-    parser.add_argument('--steps', type=int, default=20)
-    parser.add_argument('--batch', type=int, default=32)
-    parser.add_argument('--dim', type=int, default=4)
-    parser.add_argument('--layers', type=int, default=2)
-    parser.add_argument('--bias', action='store_true')
+    parser.add_argument('--steps', type=int, default=11)
+    parser.add_argument('--batch', type=int, default=5)
+    parser.add_argument('--dim', type=int, default=7)
+    parser.add_argument('--layers', type=int, default=11)
     parser.add_argument('--atol', type=float, default=1e-3)
-    parser.add_argument('--rtol', type=float, default=1)
-    parser.add_argument('--dwu_group_size', type=float, default=1)
+    parser.add_argument('--rtol', type=float, default=1e-3)
     args = parser.parse_args()
     return args

 def setup_env(args):
-    torch.cuda.set_device(args.local_rank)
+    # Initialize NCCL
+    local_rank = args.local_rank
+    if local_rank < 0:
+        local_rank = int(os.getenv('LOCAL_RANK', 0))
+    torch.cuda.set_device(local_rank % torch.cuda.device_count())
     torch.distributed.init_process_group(backend='nccl', init_method='env://')
     args.rank = torch.distributed.get_rank()
-    args.n_gpu = torch.distributed.get_world_size()
-    seed = 42 + get_rank()
+    args.world_size = torch.distributed.get_world_size()
+    # Initialize RNG
+    seed = 42 + args.rank
     random.seed(seed)
     torch.manual_seed(seed)
     return args
-def get_rank():
-    return torch.distributed.get_rank()

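A note on launching: setup_env relies on init_method='env://', which reads MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE from the environment, and falls back to the LOCAL_RANK variable that launchers such as torchrun export when --local_rank is not passed. A minimal single-node, single-process sketch of that rendezvous; the address and port values are placeholders:

import os
import torch

# Normally exported by the launcher (e.g. torchrun); set explicitly here only to make
# the env:// rendezvous visible for a standalone run.
os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')
os.environ.setdefault('RANK', '0')
os.environ.setdefault('WORLD_SIZE', '1')
os.environ.setdefault('LOCAL_RANK', '0')

local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank % torch.cuda.device_count())
torch.distributed.init_process_group(backend='nccl', init_method='env://')
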
 def main():
     args = parse_args()
     args = setup_env(args)
-    tol_args = { 'atol' : args.atol, 'rtol' : args.rtol }
     torch.set_printoptions(precision=16)
-    ref_model, ref_opt, dist_model, dist_opt = setup(args)
-    # lazy_init not called yet, initialize stash
-    stash = ref_opt._amp_stash
-    stash.all_fp16_params, stash.all_fp32_from_fp16_params = [], []
-    # make sure everything from _first_step_init_ is ready before training
-    # e.g. registering allreduce_hook
-    # so that gradients are copied/reduced when necessary
-    dist_opt._init_everything()
-    for i in range(args.steps):
-        x_ref = torch.randn(args.batch, args.dim, dtype=torch.half).cuda().requires_grad_(True)
-        x_dist = x_ref.clone().detach().requires_grad_(True)
-        if get_rank() == 0:
-            print(f'[{i}] Checking input')
-        #print("x_ref:", x_ref.flatten()[:10])
-        #print("x_dist:", x_dist.flatten()[:10])
-        assert(torch.allclose(x_ref, x_dist, **tol_args))
-        y_ref = ref_model(x_ref).half()
+    def assert_allclose(ref_x, dist_x, message):
+        message = (
+            f'Rank {args.rank}: {message}\n'
+            f'Reference Adam: {ref_x}\n'
+            f'Distributed Adam: {dist_x}\n'
+            f'Relative error: {torch.abs((ref_x-dist_x)/ref_x)}\n'
+        )
+        assert torch.allclose(ref_x, dist_x, atol=args.atol, rtol=args.rtol), message
+    # Train model with data-parallelism and ZeRO
+    ref_model, ref_optim, dist_model, dist_optim = setup(args)
+    for step in range(args.steps):
+        # Synthetic data
+        x = torch.randn(args.batch, args.dim).cuda()
+        dy = torch.randn_like(x).cuda()
+        # Reference implementation
+        ref_optim.zero_grad()
+        x_ref = x.detach().clone().requires_grad_(True)
+        y_ref = ref_model(x_ref)
+        y_ref.backward(dy)
+        ref_optim.step()
+        # Distributed implementation
+        dist_optim.zero_grad()
+        x_dist = x.detach().clone().requires_grad_(True)
         y_dist = dist_model(x_dist)
-        if get_rank() == 0:
-            print(f'[{i}] Checking output')
-        #print("y_ref:", y_ref.flatten()[:10])
-        #print("y_dist:", y_dist.flatten()[:10])
-        assert(torch.allclose(y_ref, y_dist, **tol_args))
-        dy = torch.randn_like(y_ref)
-        y_ref.backward(dy)
         y_dist.backward(dy)
-        if get_rank() == 0:
-            print(f'[{i}] Checking gradients')
-        torch.distributed.barrier()
-        torch.cuda.synchronize()
-        assert(torch.allclose(x_ref.grad, x_dist.grad, **tol_args))
-        # gradient all-reduce within distributed optimizer
-        dist_opt.complete_reductions()
-        if get_rank() == 0:
-            print(f'[{i}] Stepping')
-        ref_opt.step()
-        dist_opt.step()
+        dist_optim.step()
+        # Check values
         torch.cuda.synchronize()
         torch.distributed.barrier()
-        print('Checking new weights')
-        if get_rank() == 0:
-            print("ref param:", ref_model.module.linear[0].weight)
-            print("dist param:", dist_model.linear[0].weight)
-        for i, (rp, dp) in enumerate(zip(ref_model.parameters(), dist_model.parameters())):
-            if not torch.allclose(rp, dp, **tol_args):
-                if get_rank() == 0:
-                    print(f'Rank: {get_rank()}, Param: {i}')
-                    print(f'ref: {rp.sum().item()}, dist: {dp.sum().item()}')
-                    print(rp)
-                    print(dp)
-                    print(torch.abs(rp-dp) > tol_args['atol'])
-                sys.exit(0)
-        # zero grads
-        for rp, dp in zip(ref_model.parameters(), dist_model.parameters()):
-            rp.grad = None
-            dp.grad = None
+        assert_allclose(
+            y_ref,
+            y_dist,
+            f'inconsistent output in step {step}',
+        )
+        assert_allclose(
+            x_ref.grad,
+            x_dist.grad,
+            f'inconsistent input grad in step {step}',
+        )
+        for i, (ref_param, dist_param) in enumerate(zip(ref_model.parameters(),
+                                                        dist_model.parameters())):
+            assert_allclose(
+                ref_param,
+                dist_param,
+                f'inconsistent param {i} in step {step}',
+            )

 if __name__ == "__main__":
     main()
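For reference, the assert_allclose helper introduced in main() wraps torch.allclose, which passes when |ref - dist| <= atol + rtol * |dist| holds elementwise. A standalone equivalent with a small error report; the function name and the epsilon guard on the relative error are illustrative, not part of this commit:

import torch

def check_close(ref_x, dist_x, atol=1e-3, rtol=1e-3):
    # Same criterion as torch.allclose: |ref - dist| <= atol + rtol * |dist|
    ok = torch.allclose(ref_x, dist_x, atol=atol, rtol=rtol)
    if not ok:
        abs_err = torch.abs(ref_x - dist_x)
        rel_err = abs_err / ref_x.abs().clamp_min(1e-12)
        print(f'max abs err: {abs_err.max().item():.3e}, '
              f'max rel err: {rel_err.max().item():.3e}')
    return ok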