"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "fef8d2f726d99c8c0df4c60e283971867a295b59"
Commit 5ffb22d0, authored by Thor Johnsen, committed by GitHub

Merge pull request #1401 from timmoon10/dist-adam-zero

ZeRO-2 support in DistributedFusedAdam
parents 265b451d 846f7f8a
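
The DistributedFusedAdamV3 variant added below keeps flat fp16 gradient and parameter buffers, reduces gradients into them as backprop runs, applies Adam only to this rank's shard of the fp32 master weights and moments, and then all-gathers the updated fp16 shards. The standalone sketch below (illustrative names only, not the apex implementation or its API; it assumes the process group is initialized and the flat buffers are already laid out and padded to a multiple of the world size) shows that ZeRO-style update pattern with plain torch.distributed primitives:

# Standalone sketch, not part of the apex sources.
import torch
import torch.distributed as dist

def sharded_adam_step(flat_grads, flat_params, fp32_p, fp32_m, fp32_v,
                      lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, step=1):
    # flat_grads/flat_params: full-size fp16 buffers present on every rank.
    # fp32_p/fp32_m/fp32_v: this rank's shard of master weights and Adam moments.
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    shard_size = flat_params.numel() // world_size
    shards = [flat_params[i*shard_size:(i+1)*shard_size] for i in range(world_size)]

    # 1. Reduce gradients across ranks (the optimizer below overlaps this with bprop).
    dist.all_reduce(flat_grads)
    g = flat_grads[rank*shard_size:(rank+1)*shard_size].float() / world_size

    # 2. Adam update on this rank's fp32 shard only (sharded optimizer state).
    fp32_m.mul_(beta1).add_(g, alpha=1-beta1)
    fp32_v.mul_(beta2).addcmul_(g, g, value=1-beta2)
    denom = (fp32_v / (1 - beta2**step)).sqrt_().add_(eps)
    fp32_p.addcdiv_(fp32_m / (1 - beta1**step), denom, value=-lr)
    shards[rank].copy_(fp32_p)

    # 3. All-gather the updated fp16 shards so every rank has the new parameters.
    dist.all_gather(shards, shards[rank])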
import math
import torch
import importlib
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier
class DistributedFusedAdamV3(torch.optim.Optimizer):
"""Implements Adam algorithm. Currently GPU-only. Requires Apex to be installed via
``python setup.py install --cuda_ext --cpp_ext``.
It has been proposed in `Adam: A Method for Stochastic Optimization`_.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups.
lr (float, optional): learning rate. (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square. (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability. (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False) NOT SUPPORTED in DistributedFusedAdamV3!
eps_inside_sqrt (boolean, optional): in the 'update parameters' step,
adds eps to the bias-corrected second moment estimate before
evaluating square root instead of adding it to the square root of
second moment estimate as in the original paper. (default: False)
use_mt (boolean, optional): use multi tensor apply for lower launch
latency. (default: False)
overlap_reductions(boolean, optional): whether to overlap reductions
with bprop (default: True)
num_prestats (integer, optional): number of fp64 stats that will be
reduced during first fp16 gradient reduction block.
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
def __init__(self, params,
lr=1e-3, bias_correction = True,
betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt = False,
weight_decay=0., max_grad_norm=0., amsgrad=False, use_mt=False,
amp_scale_adjustment=1.0, overlap_reductions=True, full_pipeline=True,
compute_L2_grad_norm=False, distributed_weight_update=0,
dwu_group_size=0, dwu_num_blocks=4, dwu_num_rs_pg=1, dwu_num_ar_pg=4,
dwu_num_ag_pg=0, revert_method=1, flat_mt=False,
dwu_num_chunks=4, predivide=True, e5m2_allgather=False,
do_not_flatten_model=False):
global fused_adam_cuda
fused_adam_cuda = importlib.import_module("fused_adam_cuda")
self._amp_scale_adjustment = amp_scale_adjustment
if use_mt:
raise RuntimeError('DistributedFusedAdam does not support use_mt.')
if amsgrad:
raise RuntimeError('DistributedFusedAdam does not support the AMSGrad variant.')
defaults = dict(lr=lr, bias_correction=bias_correction,
betas=betas, eps=eps, weight_decay=weight_decay,
max_grad_norm=max_grad_norm)
super(DistributedFusedAdamV3, self).__init__(params, defaults)
self.eps_mode = 0 if eps_inside_sqrt else 1
self._overflow_buf = torch.cuda.IntTensor([0])
assert (len(self.param_groups) == 1), "More than one parameter group is not supported."
# Way to revert a step
# 3 -> undo kernel + double buffer (debug, print norm of difference)
# 2 -> double buffer fp32 parameters
# 1 -> undo kernel
self._revert_method = revert_method
if self._revert_method > 1:
print("revert_method -> double buffer fp32 parameters, will consume more memory")
self._last_step = False
self._overlap_reductions = overlap_reductions
self._global_scale = None
self._num_blocks = dwu_num_blocks
self._predivide = predivide
self._e5m2_allgather = e5m2_allgather
self._do_not_flatten_model = do_not_flatten_model
self._full_pipeline = full_pipeline
self._L2_grad_norm = None
self._group_size = torch.cuda.device_count() if dwu_group_size <= 0 else dwu_group_size
self._world_size = torch.distributed.get_world_size()
self._num_groups = self._world_size // self._group_size
self._rank_in_group = torch.distributed.get_rank() % self._group_size
p_offset = 0
p_i = 0
self._param_state = None
self._model_params = []
self._grads_info = []
self._grad_accs = []
for group in self.param_groups:
self._param_group = group
prev = None
for p in group['params']:
torch.distributed.broadcast(p,0)
if not p.requires_grad:
continue
self._model_params.append(p)
state = self.state[p]
if len(state) == 0:
state['step'] = 0
if self._param_state is None:
self._param_state = state
p_grads_size = p.numel()
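                # Register a hook on each parameter's gradient accumulation node
                # (obtained via expand_as().grad_fn) so that, as soon as a gradient
                # is produced during backprop, _do_overlapped_reduction can copy it
                # into the flat buffer and overlap block reductions with bprop.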
def wrapper(param, param_i, param_grads_size, param_offset):
param_tmp = param.expand_as(param)
grad_acc = param_tmp.grad_fn.next_functions[0][0]
def allreduce_hook(*unused):
self._do_overlapped_reduction(param_i, param_grads_size, param_offset, param)
grad_acc.register_hook(allreduce_hook)
self._grad_accs.append(grad_acc)
self._grads_info.append({"param_grads_size":p_grads_size, "param_offset":p_offset})
wrapper(p, p_i, p_grads_size, p_offset)
p_offset += p_grads_size
# Only enforce 128b alignment (64 * fp16) for non-consecutive parameters
# RNN is one example of consecutive parameters:
# (weight_ih, weight_hh, bias_ih, bias_hh)
if prev is not None and (prev.data_ptr() + prev.numel() * prev.element_size() != p.data_ptr()):
p_offset = ((p_offset + 63) // 64) * 64
prev = p
p_i += 1
self._grads_generated = [False]*len(self._grads_info)
self._flat_mt = flat_mt
self._grads = []
self._current_block = self._num_blocks
self._net_total_param_size = p_offset
self._total_param_size = p_offset
dwu_min_page_size = 256 * self._num_blocks * self._group_size
self._total_param_size = ((self._total_param_size + dwu_min_page_size - 1) // dwu_min_page_size) * dwu_min_page_size
self._block_size = self._total_param_size // self._num_blocks
self._shard_size = self._total_param_size // self._group_size
print("self._net_total_param_size=%d, self._total_param_size=%d, dwu_min_page_size=%d, self._block_size=%d, self._shard_size=%d" % (self._net_total_param_size, self._total_param_size,dwu_min_page_size,self._block_size,self._shard_size))
self._low_param_i = [0]*self._num_blocks
for block_id in range(self._num_blocks-1,-1,-1):
p_i = len(self._grads_info)-1
while p_i > 0 and self._grads_info[p_i]["param_offset"] > block_id*self._block_size:
p_i -= 1
self._low_param_i[block_id] = p_i
print(self._low_param_i)
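        # Flat fp16 buffers for gradients and parameters. Each buffer is viewed two
        # ways: as self._num_blocks blocks (the granularity at which gradient
        # reductions are overlapped with bprop) and as self._group_size shards
        # (one per rank, the granularity of the ZeRO-style optimizer-state partition).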
self._flat_grads = torch.zeros([self._total_param_size], dtype=torch.float16, device='cuda')
self._flat_params = torch.zeros_like(self._flat_grads)
def _flat_split(flat):
def __flat_blockify(flat):
return [flat[block_id*self._block_size:(block_id+1)*self._block_size] for block_id in range(self._num_blocks)]
def __flat_shardify(flat):
return [flat[shard_id*self._shard_size:(shard_id+1)*self._shard_size] for shard_id in range(self._group_size)]
return __flat_blockify(flat), __flat_shardify(flat)
self._flat_grads_blocks, self._flat_grads_shards = _flat_split(self._flat_grads)
self._flat_params_blocks, self._flat_params_shards = _flat_split(self._flat_params)
# master params
self._fp32_p = torch.zeros([self._shard_size], dtype=torch.float32, device='cuda')
self._fp32_m = torch.zeros([self._shard_size], dtype=torch.float32, device='cuda')
self._fp32_v = torch.zeros([self._shard_size], dtype=torch.float32, device='cuda')
        # Copy model params into flat_params, then re-point each model param at its flat_params view.
self._individual_flat_grads = []
with torch.no_grad():
for p, grads_info in zip(self._model_params, self._grads_info):
start = grads_info["param_offset"]
end = start + grads_info["param_grads_size"]
flat_p = self._flat_params[start:end].view_as(p)
flat_p.copy_(p)
p.set_(flat_p)
flat_grad = self._flat_grads[start:end]
self._individual_flat_grads.append(flat_grad)
self._fp32_p.copy_(self._flat_params_shards[self._rank_in_group].float())
self._dwu_st = torch.cuda.Stream()
self._l2_grad_norm_st = torch.cuda.Stream()
for group_i in range(self._num_groups):
ranks = [group_i*self._group_size+local_rank for local_rank in range(self._group_size)]
pg = torch.distributed.new_group(ranks=ranks)
if torch.distributed.get_rank() in ranks:
self._ag_pg = pg
torch.distributed.all_reduce(self._overflow_buf, group=self._ag_pg)
import inspect
assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), "This version of c10d does not support no_copy option"
@property
def has_overflow(self):
        return self.L2_grad_norm is not None and not math.isfinite(self.L2_grad_norm)
def set_last_step(self, last_step):
self._last_step = last_step
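    # Return the [start, end) range of the next gradient block that is ready to be
    # reduced, or an empty list if the trailing run of generated gradients does not
    # yet cover it. Blocks are flushed from the end of the flat buffer towards
    # block 0 as backprop produces gradients.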
def _get_flush_block(self):
flush_block = []
if self._current_block > 0 and self._grads_generated[self._low_param_i[self._current_block-1]]:
num_grads = len(self._grads_generated)
contiguous_idx = num_grads
while contiguous_idx > 0 and self._grads_generated[contiguous_idx-1]:
contiguous_idx -= 1
if contiguous_idx < num_grads and self._grads_info[contiguous_idx]["param_offset"] <= (self._current_block-1)*self._block_size:
self._current_block -= 1
start = self._current_block * self._block_size
end = (self._current_block+1) * self._block_size
flush_block = [start, end]
return flush_block
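    # Launch the fused Adam kernel on this rank's shard: fold the loss scale and,
    # when max_grad_norm > 0, gradient clipping into a single combined scale, update
    # the fp32 master weights and moments in place, and write the new fp16 values
    # into p_copy (this rank's shard of the flat parameter buffer).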
def __launch_step_kernel(self, p, p_copy, m, v, g):
combined_scale = self._global_scale
if self._param_group['max_grad_norm'] > 0 and math.isfinite(self.L2_grad_norm):
combined_scale = self._param_group['max_grad_norm'] / (self.L2_grad_norm / self._global_scale + 1e-6)
combined_scale = self._global_scale / min(1, combined_scale)
bias_correction = 1 if self._param_group['bias_correction'] else 0
beta1, beta2 = self._param_group['betas']
fused_adam_cuda.reversible_adam(
p, p_copy, m, v, g,
self._param_group['lr'],
beta1,
beta2,
self._param_group['eps'],
combined_scale,
self._param_state['step']+1,
self.eps_mode,
bias_correction,
self._param_group['weight_decay'])
def _flatten_grad_mt(self, scale):
if self._flat_mt and len(self._grads) > 0:
self._overflow_buf.zero_()
multi_tensor_applier(
amp_C.multi_tensor_scale,
self._overflow_buf,
list(zip(*self._grads)),
scale)
self._grads = []
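    # Gradient-hook callback: stage the parameter's gradient for copy into its slot
    # of the flat fp16 buffer (directly, or via multi_tensor_scale when flat_mt is
    # set), and whenever the trailing block of gradients is complete, all-reduce that
    # block on the side stream; once block 0 is reduced, compute the global L2 grad
    # norm on a separate stream.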
def _do_overlapped_reduction(self, param_i, param_grads_size, param_offset, param):
# handle overlapped reductions
if self._flat_mt:
self._grads.append( (param.grad, self._individual_flat_grads[param_i]) )
else:
torch.div(param.grad, self._world_size if self._predivide else 1.0, out=self._individual_flat_grads[param_i])
self._grads_generated[param_i]=True
if not self._last_step and self._overlap_reductions:
flush_block = self._get_flush_block()
while flush_block:
block_id = flush_block[0] // self._block_size
self._dwu_st.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self._dwu_st):
self._flatten_grad_mt(1.0/self._world_size if self._predivide else 1.0)
torch.distributed.all_reduce(self._flat_grads_blocks[block_id])
if block_id == 0:
self._l2_grad_norm_st.wait_stream(self._dwu_st)
with torch.cuda.stream(self._l2_grad_norm_st):
self._L2_grad_norm = self._flat_grads.norm(dtype=torch.float32, p=2).item()
flush_block = self._get_flush_block()
def set_global_scale(self, global_scale):
"""Set global scale.
"""
self._global_scale = global_scale
@property
def global_scale(self):
return self._global_scale
@property
def L2_grad_norm(self):
torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)
return self._L2_grad_norm
def complete_reductions(self):
"""Complete reductions if full pipeline is not selected or overlap is not allowed.
"""
if self._last_step:
# zero out gradients that have not been completed yet
for param_i, flat_grad in enumerate(self._individual_flat_grads):
if not self._grads_generated[param_i]:
flat_grad.zero_()
self._grads_generated[param_i] = True
if self._last_step or not self._overlap_reductions:
# nothing done so far, run full pipeline after reductions
self._dwu_st.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self._dwu_st):
self._flatten_grad_mt(1.0/self._world_size if self._predivide else 1.0)
torch.distributed.all_reduce(self._flat_grads)
self._l2_grad_norm_st.wait_stream(self._dwu_st)
with torch.cuda.stream(self._l2_grad_norm_st):
self._L2_grad_norm = self._flat_grads.norm(dtype=torch.float32, p=2).item()
self._current_block = self._num_blocks
self._grads_generated = [False]*len(self._grads_info)
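    # Optimizer step: run the fused Adam update on this rank's shard of the fp32
    # master parameters, then all-gather the updated fp16 shards within the group so
    # every rank ends up with the full flat parameter buffer (which the model
    # parameters alias).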
def step(self, closure=None, skip_overflow_check=False):
loss = None
if closure is not None:
loss = closure()
with torch.cuda.stream(self._dwu_st):
self.__launch_step_kernel(
self._fp32_p,
self._flat_params_shards[self._rank_in_group],
self._fp32_m,
self._fp32_v,
self._flat_grads_shards[self._rank_in_group])
torch.distributed.all_gather(self._flat_params_shards, self._flat_params_shards[self._rank_in_group], group=self._ag_pg, no_copy=True)
for p in self._model_params: self.state[p]['step'] += 1
torch.cuda.current_stream().wait_stream(self._dwu_st)
return loss
The unit test is rewritten to compare DistributedFusedAdam against a torch.optim.AdamW reference wrapped in DistributedDataParallel; the updated test follows.

import argparse
import os
import random

import torch
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam

class TestModel(torch.nn.Module):
    def __init__(self, args):
        super(TestModel, self).__init__()
        self.linear = torch.nn.Sequential(*[
            torch.nn.Linear(args.dim, args.dim)
            for _ in range(args.layers)
        ])

    def forward(self, x):
        y = 0
        for i, l in enumerate(self.linear):
            y += (i+1) * l(x)
        return y

def setup(args):

    # Construct models with same parameters
    ref_model = TestModel(args).float().cuda()
    dist_model = TestModel(args).float().cuda()
    with torch.no_grad():
        for ref_param, dist_param in zip(dist_model.parameters(),
                                         ref_model.parameters()):
            dist_param.data.copy_(ref_param.data)
    ref_model = torch.nn.parallel.DistributedDataParallel(
        ref_model,
        device_ids=[args.rank],
        output_device=args.rank,
    )

    # Construct optimizers with same hyperparameters
    optim_args = { 'lr': 1, 'betas': (0.5,0.75), 'eps': 0.1, 'weight_decay': 0.1 }
    ref_optim = torch.optim.AdamW(
        [
            {'params': list(ref_model.parameters())[1::2], 'lr': 0.5},
            {'params': list(ref_model.parameters())[0::2]},
        ],
        **optim_args,
    )
    dist_optim = DistributedFusedAdam(
        [
            {'params': list(dist_model.parameters())[1::2], 'lr': 0.5},
            {'params': list(dist_model.parameters())[0::2]},
        ],
        bucket_cap_mb=71/(4*1024*1024),
        **optim_args,
    )

    return ref_model, ref_optim, dist_model, dist_optim

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int, default=-1)
    parser.add_argument('--steps', type=int, default=3)
    parser.add_argument('--batch', type=int, default=5)
    parser.add_argument('--dim', type=int, default=7)
    parser.add_argument('--layers', type=int, default=11)
    parser.add_argument('--atol', type=float, default=1e-5)
    parser.add_argument('--rtol', type=float, default=1e-5)
    args = parser.parse_args()
    return args

def setup_env(args):

    # Initialize NCCL
    local_rank = args.local_rank
    if local_rank < 0:
        local_rank = int(os.getenv('LOCAL_RANK', 0))
    torch.cuda.set_device(local_rank % torch.cuda.device_count())
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    args.rank = torch.distributed.get_rank()
    args.world_size = torch.distributed.get_world_size()

    # Initialize RNG
    seed = 42 + args.rank
    random.seed(seed)
    torch.manual_seed(seed)

    return args

def main():
    args = parse_args()
    args = setup_env(args)
    torch.set_printoptions(precision=16)

    def assert_allclose(ref_x, dist_x, message):
        message = (
            f'Rank {args.rank}: {message}\n'
            f'Reference Adam: {ref_x}\n'
            f'Distributed Adam: {dist_x}\n'
            f'Relative error: {torch.abs((ref_x-dist_x)/ref_x)}\n'
        )
        assert torch.allclose(ref_x, dist_x, atol=args.atol, rtol=args.rtol), message

    # Train model with data-parallelism and ZeRO
    ref_model, ref_optim, dist_model, dist_optim = setup(args)
    for step in range(args.steps):

        # Synthetic data
        x = torch.randn(args.batch, args.dim).cuda()
        dy = torch.randn_like(x).cuda()

        # Reference implementation
        ref_optim.zero_grad()
        x_ref = x.detach().clone().requires_grad_(True)
        y_ref = ref_model(x_ref)
        y_ref.backward(dy)
        ref_optim.step()

        # Distributed implementation
        dist_optim.zero_grad()
        x_dist = x.detach().clone().requires_grad_(True)
        y_dist = dist_model(x_dist)
        y_dist.backward(dy)
        dist_optim.step()

        # Check values
        torch.cuda.synchronize()
        torch.distributed.barrier()
        assert_allclose(
            y_ref,
            y_dist,
            f'inconsistent output in step {step}',
        )
        assert_allclose(
            x_ref.grad,
            x_dist.grad,
            f'inconsistent input grad in step {step}',
        )
        for i, (ref_param, dist_param) in enumerate(zip(ref_model.parameters(),
                                                        dist_model.parameters())):
            assert_allclose(
                ref_param,
                dist_param,
                f'inconsistent param {i} in step {step}',
            )

if __name__ == "__main__":
    main()
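
The test requires an NCCL-capable multi-GPU environment and reads LOCAL_RANK from the environment when --local_rank is not given, so any standard distributed launcher works; for example (the file name is illustrative):

torchrun --nproc_per_node=2 test_dist_adam.py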