"tests/python/mxnet/test_nn.py" did not exist on "ca302a13ac4aac6a678794c4dc21f8f775728937"
Commit f0448054 authored by Thor Johnsen

Add alternate distributed optimizer implementation

parent 4a01ff26
import math
import torch
import importlib
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier
class DistributedFusedAdamV2(torch.optim.Optimizer):
"""Implements Adam algorithm. Currently GPU-only. Requires Apex to be installed via
``python setup.py install --cuda_ext --cpp_ext``.
It has been proposed in `Adam: A Method for Stochastic Optimization`_.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups.
lr (float, optional): learning rate. (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square. (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability. (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False) NOT SUPPORTED in FusedAdam!
eps_inside_sqrt (boolean, optional): in the 'update parameters' step,
adds eps to the bias-corrected second moment estimate before
evaluating square root instead of adding it to the square root of
second moment estimate as in the original paper. (default: False)
use_mt (boolean, optional): use multi tensor apply for lower launch
latency. (default: False)
overlap_reductions (boolean, optional): whether to overlap reductions
with bprop (default: True)
num_prestats (integer, optional): number of fp64 stats that will be
reduced during first fp16 gradient reduction block.
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
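Example (a minimal usage sketch, not part of the original documentation; it
assumes ``torch.distributed`` has been initialized with the NCCL backend,
one process per GPU, and a placeholder fp16 input batch ``inputs``)::

    model = torch.nn.Linear(1024, 1024).half().cuda()
    optimizer = DistributedFusedAdamV2(model.parameters(), lr=1e-3)
    optimizer.set_global_scale(1.0)   # external loss scale, if any
    loss = model(inputs).sum()
    loss.backward()                   # backward hooks overlap the reductions
    optimizer.complete_reductions()   # flush any remaining reduction work
    optimizer.step()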
"""
def __init__(self, params,
lr=1e-3, bias_correction = True,
betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt = False,
weight_decay=0., max_grad_norm=0., amsgrad=False, use_mt=False,
amp_scale_adjustment=1.0, overlap_reductions=True, full_pipeline=True,
compute_L2_grad_norm=False, distributed_weight_update=0,
dwu_group_size=0, dwu_num_blocks=4, dwu_num_rs_pg=1, dwu_num_ar_pg=4,
dwu_num_ag_pg=0, revert_method=1, flat_mt=False,
dwu_num_chunks=4, predivide=True, e5m2_allgather=False,
do_not_flatten_model=False):
global fused_adam_cuda
fused_adam_cuda = importlib.import_module("fused_adam_cuda")
self._amp_scale_adjustment = amp_scale_adjustment
if use_mt:
raise RuntimeError('DistributedFusedAdamV2 does not support use_mt.')
if amsgrad:
raise RuntimeError('DistributedFusedAdamV2 does not support the AMSGrad variant.')
defaults = dict(lr=lr, bias_correction=bias_correction,
betas=betas, eps=eps, weight_decay=weight_decay,
max_grad_norm=max_grad_norm)
super(DistributedFusedAdamV2, self).__init__(params, defaults)
self.eps_mode = 0 if eps_inside_sqrt else 1
self._overflow_buf = torch.cuda.IntTensor([0])
self._predivide = predivide
self._overlap_reductions = overlap_reductions
self._full_pipeline = full_pipeline
self._group_size = torch.cuda.device_count() if dwu_group_size <= 0 else dwu_group_size
self._group_id = torch.distributed.get_rank() // self._group_size
self._num_groups = torch.distributed.get_world_size() // self._group_size
self._rank = torch.distributed.get_rank()
self._rank_in_group = self._rank % self._group_size
self._world_size = torch.distributed.get_world_size()
p_offset = 0
p_i = 0
self._grads_info = []
for group in self.param_groups:
for p in group['params']:
torch.distributed.broadcast(p,0)
if not p.requires_grad:
continue
p_grads_size = p.numel()
def wrapper(param, param_i, param_grads_size, param_offset):
def allreduce_hook(grad):
self._do_overlapped_reduction(param_i, param_grads_size, param_offset, grad)
param.register_hook(allreduce_hook)
self._grads_info.append({"param_grads_size":p_grads_size, "param_offset":p_offset})
wrapper(p, p_i, p_grads_size, p_offset)
p_offset += p_grads_size
# enforce 128b alignment (64 * fp16)
p_offset = ((p_offset + 63) // 64) * 64
p_i += 1
self._grads_generated = [False]*len(self._grads_info)
self._grads = [None]*len(self._grads_info)
self._current_block = self._group_size
self._net_total_param_size = p_offset
self._total_param_size = p_offset
min_page_size = 256 * self._group_size
self._total_param_size = ((self._total_param_size + min_page_size - 1) // min_page_size) * min_page_size
self._block_size = self._total_param_size // self._group_size
print("self._net_total_param_size=%d, self._total_param_size=%d, min_page_size=%d, self._block_size=%d" % (self._net_total_param_size, self._total_param_size,min_page_size,self._block_size))
self._low_param_i = [0]*self._group_size
for block_id in range(self._group_size-1,-1,-1):
p_i = len(self._grads_info)-1
while p_i > 0 and self._grads_info[p_i]["param_offset"] > block_id*self._block_size:
p_i -= 1
self._low_param_i[block_id] = p_i
print(self._low_param_i)
self._global_scale = 1.0
self._fp32_p = None
self._new_params = torch.zeros(size=[self._total_param_size], dtype=torch.uint8).cuda()
self._flat_grads = torch.zeros(size=[self._total_param_size], dtype=torch.float16).cuda()
if self._num_groups > 1:
self._num_ar_pg = dwu_num_ar_pg
self._ar_pg = []
for dev_i in range(self._group_size):
ranks = [dev_i+j*self._group_size for j in range(self._num_groups)]
for i in range(self._num_ar_pg):
grp = torch.distributed.new_group(ranks=ranks)
if torch.distributed.get_rank() in ranks:
self._ar_pg.append(grp)
for ar_pg in self._ar_pg:
torch.distributed.all_reduce(self._overflow_buf,group=ar_pg)
self._num_rs_pg = dwu_num_rs_pg
rs_ranks = []
for group_i in range(self._num_groups):
rs_ranks.append([group_i*self._group_size+j for j in range(self._group_size)])
self._rs_pg = []
for group_i in range(self._num_groups):
ranks = rs_ranks[group_i]
for i in range(self._num_rs_pg):
grp = torch.distributed.new_group(ranks=ranks)
if torch.distributed.get_rank() in ranks:
self._rs_pg.append(grp)
for rs_pg in self._rs_pg:
torch.distributed.all_reduce(self._overflow_buf,group=rs_pg)
self._redux_st = [torch.cuda.Stream() for _ in range(self._group_size)]
self._compute_L2_grad_norm = compute_L2_grad_norm
if self._compute_L2_grad_norm:
self._L2_grad_norm = torch.zeros(size=[1],dtype=torch.float32).cuda()
self._l2_grad_norm_st = torch.cuda.Stream()
self._completion_st = torch.cuda.Stream()
self._last_step = False
def set_last_step(self, last_step):
self._last_step = last_step
def _get_flush_block(self):
flush_block = []
if self._grads_generated[self._low_param_i[self._current_block-1]]:
num_grads = len(self._grads_generated)
contiguous_idx = num_grads
while contiguous_idx > 0 and self._grads_generated[contiguous_idx-1]:
contiguous_idx -= 1
if contiguous_idx < num_grads and self._grads_info[contiguous_idx]["param_offset"] <= (self._current_block-1)*self._block_size:
self._current_block -= 1
start = self._current_block * self._block_size
end = (self._current_block+1) * self._block_size
flush_block = [start, end]
if self._current_block == 0:
# reset
self._grads_generated = [False]*len(self._grads_info)
return flush_block
def _pipeline_block_reductions(self, block_id):
self._flatten_grad_mt(1.0/self._world_size if self._predivide else 1.0)
start = block_id * self._block_size
end = start + self._block_size
grad_block = self._flat_grads[start:end]
active_rank = self._group_id*self._group_size+block_id
redux_stream = self._redux_st[block_id]
redux_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(redux_stream):
work = torch.distributed.reduce(grad_block,active_rank,group=self._rs_pg[block_id%self._num_rs_pg],async_op=True)
if self._num_groups > 1 and self._rank == active_rank:
work.wait()
work = torch.distributed.all_reduce(grad_block,group=self._ar_pg[block_id%self._num_ar_pg],async_op=True)
if self._compute_L2_grad_norm:
if self._rank == active_rank:
with torch.cuda.stream(self._l2_grad_norm_st):
work.wait()
self._L2_grad_norm = grad_block.norm(dtype=torch.float32,p=2)**2
if block_id == 0:
with torch.cuda.stream(self._l2_grad_norm_st):
torch.distributed.all_reduce(self._L2_grad_norm,group=self._rs_pg[self._num_rs_pg-1])
self._L2_grad_norm.sqrt_()
# FIXME: Does completion stream need to wait for L2 grad norm to finish?
self._completion_st.wait_stream(self._l2_grad_norm_st)
with torch.cuda.stream(redux_stream):
work.wait()
def _pipeline_block_step(self, block_id):
active_rank = self._group_id*self._group_size+block_id
if self._rank == active_rank:
redux_stream = self._redux_st[block_id]
with torch.cuda.stream(redux_stream):
self._partial_step_single_shard(block_id)
if block_id == 0:
new_params_blocks = [self._new_params[block*self._block_size:(block+1)*self._block_size] for block in range(self._group_size)]
for redux_stream in self._redux_st:
self._completion_st.wait_stream(redux_stream)
with torch.cuda.stream(self._completion_st):
torch.distributed.all_gather(new_params_blocks,new_params_blocks[self._rank_in_group],group=self._rs_pg[self._num_rs_pg-1],no_copy=True)
def _flatten_grad_mt(self, scale):
grads = []
flat_grads = []
for p_i, (grads_info, grad) in enumerate(zip(self._grads_info, self._grads)):
if grad is not None:
grads.append(grad)
flat_grads.append( self._flat_grads[grads_info["param_offset"]:grads_info["param_offset"]+grads_info["param_grads_size"]] )
self._grads = [None]*len(self._grads_info)
if len(grads) > 0:
self._overflow_buf.zero_()
multi_tensor_applier(
amp_C.multi_tensor_scale,
self._overflow_buf,
[grads, flat_grads],
scale)
def _do_overlapped_reduction(self, param_i, param_grads_size, param_offset, grad):
# handle overlapped reductions
self._grads[param_i] = grad.view(-1)
self._grads_generated[param_i]=True
if not self._last_step:
if self._overlap_reductions:
flush_block = self._get_flush_block()
while flush_block:
block_id = flush_block[0] // self._block_size
self._pipeline_block_reductions(block_id)
if self._full_pipeline:
self._pipeline_block_step(block_id)
flush_block = self._get_flush_block()
def set_global_scale(self, global_scale):
"""Set global scale.
"""
self._global_scale = global_scale
@property
def global_scale(self):
return self._global_scale
@property
def has_overflow(self):
"""Check if overflows were detected by any call to step(...) method.
Clears the overflow flag.
"""
has_overflow = self._overflow_buf.item()
self._overflow_buf.zero_()
return has_overflow
@property
def peek_overflow(self):
"""Check if overflows were detected by any call to step(...) method.
Does not clear overflow flag.
"""
return self._overflow_buf.item()
def strided_check_finite(self, output_params, stride=1, start=-1, end=-1, clear=True):
"""Strided check for overflow.
You can get status by calling has_overflow.
"""
if start >= 0 and start < end:
out_p = output_params[start:end]
else:
out_p = output_params
fused_adam_cuda.strided_check_finite(self._overflow_buf,
out_p,
stride,
1 if clear else 0)
@property
def L2_grad_norm(self):
if self._compute_L2_grad_norm:
torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)
return self._L2_grad_norm
else:
return None
# Distributed weight update algorithm:
# Model parameters are kept as-is.
# Gradients are flattened during backprop.
# Reductions are done with an intra-node reduce-scatter followed by an inter-node all-reduce.
# Step function is sharded and the shards are assembled with an intra-node all-gather.
# Sharded step function needs internal fp32 buffers for p, m and v.
# To save memory, we allocate the fp32 buffers to cover only the shards the local GPU will update.
# This means we have to play around with indexes, which requires knowledge of block and shard number.
# Implement a method that performs a partial update of a single shard within a single block.
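# Worked example (hypothetical sizes, for illustration only): with
# world_size=8 and group_size=4, ranks 0-3 form group 0 and ranks 4-7 form
# group 1. If total_param_size=4096 then block_size=4096//4=1024, and block b
# covers flat indices [b*1024, (b+1)*1024). Within group 0 the reduction of
# block b is rooted at rank b; that rank then (if num_groups > 1) all-reduces
# block b with rank b+4, runs the sharded Adam step on its local fp32 copy of
# the block, and the updated blocks are reassembled with an intra-group
# all-gather.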
def _partial_step_single_shard(self, block_id, undo=False):
"""Perform step function for a single shard.
Arguments:
block_id (integer): Block index of shard, in the range [0, self._group_size).
undo (boolean, optional): If True, undo effect of previously called partial step.
"""
block_start = block_id * self._block_size
block_end = block_start + self._block_size
if self._fp32_p is None:
assert (not undo), "Tried to undo step before calling step."
# Allocate fp32 buffers on demand. Note that we don't make these part of the state
# since each rank only has partial buffers.
# To-Do:
self._fp32_p = torch.zeros([self._block_size]).float().cuda()
self._fp32_m = torch.zeros([self._block_size]).float().cuda()
self._fp32_v = torch.zeros([self._block_size]).float().cuda()
self._copy_to_fp32 = True
step = None
param_i = 0
for group in self.param_groups:
# compute combined scale factor for this group
combined_scale = self._global_scale
if group['max_grad_norm'] > 0 and math.isfinite(self.L2_grad_norm):
combined_scale = group['max_grad_norm'] / (self.L2_grad_norm / self._global_scale + 1e-6)
combined_scale = self._global_scale / min(1, combined_scale)
bias_correction = 1 if group['bias_correction'] else 0
group_start = -1
group_end = -2
for p in group['params']:
if not p.requires_grad:
continue
#if p.grad.is_sparse:
# raise RuntimeError('FusedAdam does not support sparse gradients, please consider SparseAdam instead')
state = self.state[p]
if len(state) == 0:
state['step'] = 0
if step is None:
# all we want from state at this point is state['step'], which should be the same for all p
step = state['step']
nels = p.numel()
offset = self._grads_info[param_i]['param_offset']
param_i += 1
start = offset
end = start + nels
clipped_start = start if start >= block_start else block_start
clipped_end = end if end <= block_end else block_end
# check if this parameter contributes to block
if clipped_start < clipped_end:
if group_start < 0:
group_start = clipped_start
group_end = clipped_end
if self._copy_to_fp32:
param_offset = clipped_start - block_start
param_size = clipped_end - clipped_start
buffer_start = param_offset
buffer_end = buffer_start + param_size
param_start = (clipped_start - start)
param_end = param_start + param_size
#assert (buffer_start >= 0 and buffer_end <= self._fp32_p.numel() and param_start >= 0 and param_end <= p.numel()), "Illegal copy"
self._fp32_p[buffer_start:buffer_end].copy_(p.view(-1)[param_start:param_end].float())
group_size = group_end - group_start
if group_size > 0:
assert (step is not None), "state['step'] is None for this parameter group"
group_offset = group_start - block_start
group_block_start = block_start + group_offset
group_block_end = group_block_start + group_size
group_buffer_start = group_offset
group_buffer_end = group_buffer_start + group_size
beta1, beta2 = group['betas']
if undo:
fused_adam_cuda.maybe_adam_undo(
torch.empty([0]),
self._fp32_p[group_buffer_start:group_buffer_end],
self._fp32_m[group_buffer_start:group_buffer_end],
self._fp32_v[group_buffer_start:group_buffer_end],
self._flat_grads[group_block_start:group_block_end],
group['lr'],
beta1,
beta2,
group['eps'],
combined_scale,
step+1, # FIXME: Verify this should be step+1
self.eps_mode,
bias_correction,
group['weight_decay'])
else:
fused_adam_cuda.adam(
self._fp32_p[group_buffer_start:group_buffer_end],
self._new_params[group_block_start:group_block_end],
self._fp32_m[group_buffer_start:group_buffer_end],
self._fp32_v[group_buffer_start:group_buffer_end],
self._flat_grads[group_block_start:group_block_end],
group['lr'],
beta1,
beta2,
group['eps'],
combined_scale,
step+1,
self.eps_mode,
bias_correction,
group['weight_decay'])
def complete_reductions(self):
"""Complete reductions if full pipeline is not selected or overlap is not allowed.
"""
if self._last_step:
# zero out gradients that have not been completed yet
for param_i, grad_generated in enumerate(self._grads_generated):
if not grad_generated:
grad_info = self._grads_info[param_i]
param_offset = grad_info["param_offset"]
param_size = grad_info["param_grads_size"]
self._flat_grads[param_offset:param_offset+param_size].zero_()
self._grads_generated[param_i] = True
if self._last_step or not self._overlap_reductions:
# nothing done so far, run full pipeline after reductions
for block_id in range(self._group_size-1,-1,-1):
self._pipeline_block_reductions(block_id)
self._copy_to_fp32 = False
self._current_block = self._group_size
self._grads_generated = [False]*len(self._grads_info)
def revert_step(self):
"""Revert effect of previously calling partial_step.
"""
self._partial_step_single_shard(self._rank_in_group, undo=True)
def step(self, closure=None):
loss = None
if closure is not None:
loss = closure()
if self._last_step or not self._overlap_reductions or not self._full_pipeline:
for block_id in range(self._group_size-1,-1,-1):
self._pipeline_block_step(block_id)
with torch.cuda.stream(self._completion_st):
# Check for overflow
# Store state for loss scaler calculation
self.strided_check_finite(self._new_params, stride=self._block_size, start=0, end=self._net_total_param_size)
has_overflow = self.peek_overflow
if has_overflow:
print("Reverting step")
self.revert_step()
else:
# Copy self._new_params to model params
p_in = []
p_out = []
with torch.no_grad():
param_i = 0
for group in self.param_groups:
for p in group['params']:
if not p.requires_grad:
continue
state = self.state[p]
if len(state) == 0:
state['step'] = 0
state['step'] += 1
nels = p.numel()
offset = self._grads_info[param_i]['param_offset']
p_in.append(self._new_params[offset:offset+nels].view_as(p))
p_out.append(p)
param_i += 1
multi_tensor_applier(
fused_adam_cuda.maybe_cast_mt,
self._overflow_buf,
[p_in, p_out])
torch.cuda.current_stream().wait_stream(self._completion_st)
return loss