Commit 44f54712 authored by Thor Johnsen's avatar Thor Johnsen
Browse files

Reduce CPU overhead, bigger step, all-gather

parent f0448054
......@@ -68,6 +68,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
self._overflow_buf = torch.cuda.IntTensor([0])
assert (len(self.param_groups) == 1), "More than one parameter group is not supported."
# Way to revert a step
# 3 -> undo kernel + double buffer (debug, print norm of difference)
# 2 -> double buffer fp32 parameters
......@@ -94,12 +96,21 @@ class DistributedFusedAdam(torch.optim.Optimizer):
p_offset = 0
p_i = 0
self._param_state = None
self._model_params = []
self._grads_info = []
for group in self.param_groups:
self._param_group = group
for p in group['params']:
torch.distributed.broadcast(p,0)
if not p.requires_grad:
continue
self._model_params.append(p)
state = self.state['p']
if len(state) == 0:
state['step'] = 0
if self._param_state is None:
self._param_state = state
p_grads_size = p.numel()
def wrapper(param, param_i, param_grads_size, param_offset):
def allreduce_hook(grad):
......@@ -134,12 +145,85 @@ class DistributedFusedAdam(torch.optim.Optimizer):
self._low_param_i[block_id] = p_i
print(self._low_param_i)
self._flat_grads = torch.zeros([self._total_param_size]).half().cuda()
self._new_params = None
self._fp32_p = None
self._fp32_m = None
self._fp32_v = None
self._copy_to_fp32 = False
self._flat_grads = torch.zeros([self._total_param_size], dtype=torch.float16, device='cuda')
self._new_params = torch.zeros([self._total_param_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
self._mega_shard_size = self._num_blocks * self._num_chunks * self._shard_size
self._fp32_p = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
self._fp32_m = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
self._fp32_v = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
# FIXME: Rethink fp16 label since it's either uint8 or fp16
self._fp16_p = torch.zeros([self._mega_shard_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
self._fp16_g = torch.zeros([self._mega_shard_size], dtype=torch.float16, device='cuda')
def _flat_split(p):
def __blockify(p):
return [p[block_id*self._block_size:(block_id+1)*self._block_size] for block_id in range(self._num_blocks)]
def __chunkify(p):
return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)]
def __shardify(p):
return [p[shard_id*self._shard_size:(shard_id+1)*self._shard_size] for shard_id in range(self._group_size)]
list_of_blocks = __blockify(self._flat_grads)
list_of_list_of_chunks = [__chunkify(block) for block in list_of_blocks]
list_of_list_of_list_of_shards = [[__shardify(chunk) for chunk in chunks] for chunks in list_of_list_of_chunks]
return list_of_blocks, list_of_list_of_chunks, list_of_list_of_list_of_shards
self._flat_grads_blocks, self._flat_grads_chunks, self._flat_grads_shards = _flat_split(self._flat_grads)
def _full_packed_split(p):
def __shardify(p):
return [p[mega_shard*self._mega_shard_size:(mega_shard+1)*self._mega_shard_size] for mega_shard in range(self._group_size)]
def __blockify(p):
return [p[block_id*self._num_chunks*self._shard_size:(block_id+1)*self._num_chunks*self._shard_size] for block_id in range(self._num_blocks)]
def __chunkify(p):
return [p[chunk_id*self._shard_size:(chunk_id+1)*self._shard_size] for chunk_id in range(self._num_chunks)]
list_of_mega_shards = __shardify(p)
list_of_list_of_mega_blocks = [__blockify(mega_shard) for mega_shard in list_of_mega_shards]
list_of_list_of_list_of_mega_chunks = [[__chunkify(mega_block) for mega_block in mega_blocks] for mega_blocks in list_of_list_of_mega_blocks]
return list_of_mega_shards, list_of_list_of_mega_blocks, list_of_list_of_list_of_mega_chunks
self._new_params_mega_shards, self._new_params_mega_blocks, self._new_params_mega_chunks = _full_packed_split(self._new_params)
def _packed_split(p):
def __packed_blockify(p):
packed_block_size = self._num_chunks*self._shard_size
return [p[block_id*packed_block_size:(block_id+1)*packed_block_size] for block_id in range(self._num_blocks)]
def __packed_chunkify(p):
# in the packed format, each chunk contains one shard, so packed_chunk_size == self._shard_size
return [p[chunk_id*self._shard_size:(chunk_id+1)*self._shard_size] for chunk_id in range(self._num_chunks)]
list_of_blocks = __packed_blockify(p)
list_of_list_of_chunks = [__packed_chunkify(block) for block in list_of_blocks]
return list_of_blocks, list_of_list_of_chunks
self._fp32_p_blocks, self._fp32_p_chunks = _packed_split(self._fp32_p)
self._fp32_m_blocks, self._fp32_m_chunks = _packed_split(self._fp32_m)
self._fp32_v_blocks, self._fp32_v_chunks = _packed_split(self._fp32_v)
self._fp16_p_blocks, self._fp16_p_chunks = _packed_split(self._fp16_p)
self._fp16_g_blocks, self._fp16_g_chunks = _packed_split(self._fp16_g)
# This paragraph does two things:
# 1) Copy model parameters into master buffer
# 2) Create tensor lists for unpacking new parameter tensor after all-gather
self._packed_flat_to_model_params = []
for shard_id in range(self._group_size):
for block_id in range(self._num_blocks):
for chunk_id in range(self._num_chunks):
flat_shard_start = (((block_id * self._num_chunks + chunk_id) * self._group_size) + shard_id) * self._shard_size
flat_shard_end = flat_shard_start + self._shard_size
for p, grads_info in zip(self._model_params, self._grads_info):
flat_grad_start = grads_info["param_offset"]
flat_grad_end = flat_grad_start + grads_info["param_grads_size"]
clipped_start = (lambda a,b: a if a > b else b)(flat_grad_start, flat_shard_start)
clipped_end = (lambda a,b: a if a < b else b)(flat_grad_end, flat_shard_end)
if clipped_start < clipped_end:
grad_offset = clipped_start - flat_grad_start
grad_length = clipped_end - clipped_start
shard_offset = clipped_start - flat_shard_start
model_param_fragment = p.view(-1)[grad_offset:grad_offset+grad_length]
new_param_packed_fragment = self._new_params_mega_chunks[shard_id][block_id][chunk_id][shard_offset:shard_offset+grad_length]
self._packed_flat_to_model_params.append( (new_param_packed_fragment, model_param_fragment) )
if shard_id == self._rank_in_group:
# copy model parameters into master buffer
master_param_fragment = self._fp32_p_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
print("model_param_fragment.size()=%s, new_param_packed_fragment.size()=%s, master_param_fragment.size()=%s" % (str(model_param_fragment.size()), str(new_param_packed_fragment.size()), str(master_param_fragment.size())))
master_param_fragment.copy_(model_param_fragment)
p_in, p_out = zip(*self._packed_flat_to_model_params)
self._packed_flat_to_model_params = [p_in, p_out]
self._distributed_weight_update = distributed_weight_update # Is this still needed?
self._num_rs_pg = dwu_num_rs_pg
......@@ -221,64 +305,91 @@ class DistributedFusedAdam(torch.optim.Optimizer):
def _pipeline_block_reductions(self, block_id):
self._flatten_grad_mt(1.0/self._world_size if self._predivide else 1.0)
start = block_id * self._block_size
end = start + self._block_size
grad_block = self._flat_grads[start:end]
# Reduction within each node
# Changes gradient format from [block * chunk * shard] to [shard * block * chunk]
# The output format is the same as the fp32 master parameters
works = [None]*self._num_chunks
for chunk in range(self._num_chunks):
glob_chunk = block_id * self._num_chunks + chunk
grad_chunk = grad_block[chunk*self._chunk_size:(chunk+1)*self._chunk_size]
grad_shards = [grad_chunk[i*self._shard_size:(i+1)*self._shard_size] for i in range(self._group_size)]
rs_stream = self._rs_st[glob_chunk%self._num_rs_pg]
for chunk_id in range(self._num_chunks):
glob_chunk_id = block_id * self._num_chunks + chunk_id
rs_stream = self._rs_st[glob_chunk_id%self._num_rs_pg]
rs_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(rs_stream):
work = torch.distributed.reduce_scatter(grad_shards[self._rank_in_group],grad_shards,group=self._rs_pg[glob_chunk%self._num_rs_pg],async_op=True,no_copy=True)
if self._num_groups > 1:
ar_stream = self._ar_st[glob_chunk%self._num_ar_pg]
with torch.cuda.stream(ar_stream):
work.wait()
work = torch.distributed.all_reduce(grad_shards[self._rank_in_group],group=self._ar_pg[glob_chunk%self._num_ar_pg],async_op=True)
works[chunk] = work
if self._compute_L2_grad_norm:
for chunk in range(self._num_chunks):
grad_chunk = grad_block[chunk*self._chunk_size:(chunk+1)*self._chunk_size]
grad_shards = [grad_chunk[i*self._shard_size:(i+1)*self._shard_size] for i in range(self._group_size)]
with torch.cuda.stream(self._l2_grad_norm_st):
works[chunk].wait()
l2_grad_sq = grad_shards[self._rank_in_group].norm(dtype=torch.float32,p=2)**2
if block_id+1 == self._num_blocks and chunk == 0:
self._L2_grad_norm = l2_grad_sq
else:
self._L2_grad_norm += l2_grad_sq
if block_id == 0 and chunk+1 == self._num_chunks:
torch.distributed.all_reduce(self._L2_grad_norm,group=self._l2_grad_norm_pg)
self._L2_grad_norm.sqrt_()
works[chunk_id] = torch.distributed.reduce_scatter(self._fp16_g_chunks[block_id][chunk_id],self._flat_grads_shards[block_id][chunk_id],group=self._rs_pg[glob_chunk_id%self._num_rs_pg],async_op=True,no_copy=True)
# Reduction across nodes for each rank
if self._num_groups > 1:
for chunk_id in range(self._num_chunks):
glob_chunk_id = block_id * self._num_chunks + chunk_id
ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
with torch.cuda.stream(ar_stream):
works[chunk_id].wait()
works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True)
self._reductions_works[block_id] = works
def _pipeline_block_step(self, block_id):
if self._new_params is None:
self._new_params = torch.zeros_like(self._flat_grads,dtype=torch.uint8 if self._e5m2_allgather else self._flat_grads.dtype)
start = block_id * self._block_size
end = start + self._block_size
new_params_block = self._new_params[start:end]
# Optionally compute L2 grad norm
if self._compute_L2_grad_norm and block_id == 0:
with torch.cuda.stream(self._l2_grad_norm_st):
for block_id in range(self._num_blocks):
for chunk_id in range(self._num_chunks):
self._reductions_works[block_id][chunk_id].wait()
# Since the packed format is contiguous after reductions, only one norm is needed
self._L2_grad_norm = self._fp16_g.norm(dtype=torch.float32, p=2)**2
torch.distributed.all_reduce(self._L2_grad_norm,group=self._l2_grad_norm_pg)
self._L2_grad_norm.sqrt_()
def __launch_step_kernel(self, p, p_copy, m, v, g):
combined_scale = self._global_scale
if self._param_group['max_grad_norm'] > 0 and math.isfinite(self.L2_grad_norm):
combined_scale = self._param_group['max_grad_norm'] / (self.L2_grad_norm / self._global_scale + 1e-6)
combined_scale = self._global_scale / min(1, combined_scale)
bias_correction = 1 if self._param_group['bias_correction'] else 0
beta1, beta2 = self._param_group['betas']
fused_adam_cuda.adam(
p, p_copy, m, v, g,
self._param_group['lr'],
beta1,
beta2,
self._param_group['eps'],
combined_scale,
self._param_state['step']+1,
self.eps_mode,
bias_correction,
self._param_group['weight_decay'])
works = [None]*self._num_chunks
for chunk in range(self._num_chunks):
glob_chunk = block_id * self._num_chunks + chunk
new_params_chunk = new_params_block[chunk*self._chunk_size:(chunk+1)*self._chunk_size]
new_params_shards = [new_params_chunk[i*self._shard_size:(i+1)*self._shard_size] for i in range(self._group_size)]
ag_stream = self._ag_st[glob_chunk%self._num_ag_pg]
with torch.cuda.stream(ag_stream):
self._reductions_works[block_id][chunk].wait()
self._partial_step_single_shard(block_id,chunk)
work = torch.distributed.all_gather(new_params_shards,new_params_shards[self._rank_in_group],group=self._ag_pg[glob_chunk%self._num_ag_pg],async_op=True,no_copy=True)
works[chunk] = work
self._allgather_works[block_id] = works
def _pipeline_block_step(self, block_id):
# Call step kernel once per block
ag_stream = self._ag_st[block_id%self._num_ag_pg]
with torch.cuda.stream(ag_stream):
for chunk_id in range(self._num_chunks):
self._reductions_works[block_id][chunk_id].wait()
self.__launch_step_kernel(
self._fp32_p_blocks[block_id],
self._fp16_p_blocks[block_id],
self._fp32_m_blocks[block_id],
self._fp32_v_blocks[block_id],
self._fp16_g_blocks[block_id])
# Call all-gather once per step.
# FIXME: Determine which is faster, one all-gather per block or a single all-gather at end
if block_id == 0:
for other_ag_stream in self._ag_st:
self._completion_st.wait_stream(other_ag_stream)
with torch.cuda.stream(self._completion_st):
torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)
def _pipeline_step(self):
# Call step kernel once per step
# Call all-gather once per step
with torch.cuda.stream(self._completion_st):
for block_id in range(self._num_blocks):
for chunk_id in range(self._num_chunks):
self._reductions_works[block_id][chunk_id].wait()
self.__launch_step_kernel(
self._fp32_p,
self._fp16_p,
self._fp32_m,
self._fp32_v,
self._fp16_g)
torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)
def _flatten_grad_mt(self, scale):
if self._flat_mt:
......@@ -360,145 +471,6 @@ class DistributedFusedAdam(torch.optim.Optimizer):
else:
return None
# Distributed weight update algorithm:
# Model parameters are kept as-is.
# Gradients are flattened during backprop.
# Reductions are done with an intra-node reduce-scatter followed by an inter-node all-reduce.
# Step function is sharded and the shards are assembled with an intra-node all-gather.
# Sharded step function needs internal fp32 buffers for p, m and v.
# To save memory, we allocate the fp32 buffers to cover only the shards local GPU will update.
# This means we have to play around with indexes, which requires knowledge of block and shard number.
# Implement a method that performs a partial update of a single shard within a single block.
def _partial_step_single_shard(self, block_id, chunk_id, undo=False):
"""Perform step function for a single shard.
Arguments:
block_id (integer): Block index of shard [0,self._num_blocks>
undo (boolean, optional): If True, undo effect of previously called partial step.
"""
shard_id = self._rank_in_group
shard_start = (block_id * self._num_chunks + chunk_id) * self._chunk_size + shard_id * self._shard_size
shard_end = shard_start + self._shard_size
if self._fp32_p is None:
assert (not undo), "Tried to undo step before calling step."
# Allocate fp32 buffers on demand. Note that we don't make these part of the state
# since each rank only has partial buffers.
# To-Do:
self._fp32_p = torch.zeros([self._num_blocks*self._num_chunks*self._shard_size]).float().cuda()
self._fp32_m = torch.zeros([self._num_blocks*self._num_chunks*self._shard_size]).float().cuda()
self._fp32_v = torch.zeros([self._num_blocks*self._num_chunks*self._shard_size]).float().cuda()
if self._revert_method > 1:
self._fp32_backup_p = torch.zeros([self._num_blocks*self._num_chunks*self._shard_size]).float().cuda()
self._fp32_backup_m = torch.zeros([self._num_blocks*self._num_chunks*self._shard_size]).float().cuda()
self._fp32_backup_v = torch.zeros([self._num_blocks*self._num_chunks*self._shard_size]).float().cuda()
self._copy_to_fp32 = True
step = None
param_i = 0
for group in self.param_groups:
# compute combined scale factor for this group
combined_scale = self._global_scale
if group['max_grad_norm'] > 0 and math.isfinite(self.L2_grad_norm):
combined_scale = group['max_grad_norm'] / (self.L2_grad_norm / self._global_scale + 1e-6)
combined_scale = self._global_scale / min(1, combined_scale)
bias_correction = 1 if group['bias_correction'] else 0
group_start = -1
group_end = -2
for p in group['params']:
if not p.requires_grad:
continue
#if p.grad.is_sparse:
# raise RuntimeError('FusedAdam does not support sparse gradients, please consider SparseAdam instead')
state = self.state[p]
if len(state) == 0:
state['step'] = 0
if step is None:
# all we want from state at this point is state['step'], which should be the same for all p
step = state['step']
nels = p.numel()
offset = self._grads_info[param_i]['param_offset']
param_i += 1
start = offset
end = start + nels
clipped_start = start if start >= shard_start else shard_start
clipped_end = end if end <= shard_end else shard_end
# check if this parameter contributes to shard
if clipped_start < clipped_end:
if group_start < 0:
group_start = clipped_start
group_end = clipped_end
if self._copy_to_fp32:
param_offset = clipped_start - shard_start
param_size = clipped_end - clipped_start
buffer_start = (block_id * self._num_chunks + chunk_id) * self._shard_size + param_offset
buffer_end = buffer_start + param_size
param_start = (clipped_start - start)
param_end = param_start + param_size
self._fp32_p[buffer_start:buffer_end].copy_(p.view(-1)[param_start:param_end].float())
group_size = group_end - group_start
if group_size > 0:
assert (step is not None), "state['step'] is None for this parameter group"
group_offset = group_start - shard_start
group_shard_start = shard_start + group_offset
group_shard_end = group_shard_start + group_size
group_buffer_start = (block_id * self._num_chunks + chunk_id) * self._shard_size + group_offset
group_buffer_end = group_buffer_start + group_size
beta1, beta2 = group['betas']
if undo:
if self._revert_method == 1:
fused_adam_cuda.maybe_adam_undo(
torch.empty([0]),
self._fp32_p[group_buffer_start:group_buffer_end],
self._fp32_m[group_buffer_start:group_buffer_end],
self._fp32_v[group_buffer_start:group_buffer_end],
self._flat_grads[group_shard_start:group_shard_end],
group['lr'],
beta1,
beta2,
group['eps'],
combined_scale,
step+1, # FIXME: Verify this should be step+1
self.eps_mode,
bias_correction,
group['weight_decay'])
elif self._revert_method == 2:
self._fp32_p[group_buffer_start:group_buffer_end].copy_(self._fp32_backup_p[group_buffer_start:group_buffer_end])
self._fp32_m[group_buffer_start:group_buffer_end].copy_(self._fp32_backup_m[group_buffer_start:group_buffer_end])
self._fp32_v[group_buffer_start:group_buffer_end].copy_(self._fp32_backup_v[group_buffer_start:group_buffer_end])
elif self._revert_method == 3:
raise RuntimeError('revert_step debug option not implemented yet')
else:
if self._revert_method > 1:
self._fp32_backup_p[group_buffer_start:group_buffer_end].copy_(self._fp32_p[group_buffer_start:group_buffer_end])
self._fp32_backup_m[group_buffer_start:group_buffer_end].copy_(self._fp32_m[group_buffer_start:group_buffer_end])
self._fp32_backup_v[group_buffer_start:group_buffer_end].copy_(self._fp32_v[group_buffer_start:group_buffer_end])
fused_adam_cuda.adam(
self._fp32_p[group_buffer_start:group_buffer_end],
self._new_params[group_shard_start:group_shard_end],
self._fp32_m[group_buffer_start:group_buffer_end],
self._fp32_v[group_buffer_start:group_buffer_end],
self._flat_grads[group_shard_start:group_shard_end],
group['lr'],
beta1,
beta2,
group['eps'],
combined_scale,
step+1,
self.eps_mode,
bias_correction,
group['weight_decay'])
def complete_reductions(self):
"""Complete reductions if full pipeline is not selected or overlap is not allowed.
"""
......@@ -521,17 +493,34 @@ class DistributedFusedAdam(torch.optim.Optimizer):
if self._compute_L2_grad_norm:
torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)
self._copy_to_fp32 = False
self._decomp_stats = None
self._current_block = self._num_blocks
self._grads_generated = [False]*len(self._grads_info)
def revert_step(self):
"""Revert effect of previously calling partial_step.
"""
for block_id in range(self._num_blocks):
for chunk in range(self._num_chunks):
self._partial_step_single_shard(block_id, chunk, undo=True)
# Call undo kernel once per step
combined_scale = self._global_scale
if self._param_group['max_grad_norm'] > 0 and math.isfinite(self.L2_grad_norm):
combined_scale = self._param_group['max_grad_norm'] / (self.L2_grad_norm / self._global_scale + 1e-6)
combined_scale = self._global_scale / min(1, combined_scale)
bias_correction = 1 if self._param_group['bias_correction'] else 0
beta1, beta2 = self._param_group['betas']
fused_adam_cuda.maybe_adam_undo(
torch.empty([0]),
self._fp32_p,
self._fp32_m,
self._fp32_v,
self._fp16_g,
self._param_group['lr'],
beta1,
beta2,
self._param_group['eps'],
combined_scale,
self._param_state['step']+1,
self.eps_mode,
bias_correction,
self._param_group['weight_decay'])
def step(self, closure=None, skip_overflow_check=False):
loss = None
......@@ -539,14 +528,9 @@ class DistributedFusedAdam(torch.optim.Optimizer):
loss = closure()
if self._last_step or not self._overlap_reductions or not self._full_pipeline:
for block_id in range(self._num_blocks-1,-1,-1):
self._pipeline_block_step(block_id)
self._pipeline_step()
with torch.cuda.stream(self._completion_st):
for block_id in range(self._num_blocks-1,-1,-1):
for chunk in range(self._num_chunks):
self._allgather_works[block_id][chunk].wait()
# Check for overflow
# Store state for loss scaler calculation
if skip_overflow_check:
......@@ -559,40 +543,10 @@ class DistributedFusedAdam(torch.optim.Optimizer):
self.revert_step()
else:
# Copy self._new_params to model params
if self._e5m2_allgather or self._do_not_flatten_model:
p_in = []
p_out = []
with torch.no_grad():
param_i = 0
for group in self.param_groups:
for p in group['params']:
if not p.requires_grad:
continue
state = self.state[p]
if len(state) == 0:
state['step'] = 0
state['step'] += 1
nels = p.numel()
offset = self._grads_info[param_i]['param_offset']
if self._e5m2_allgather or self._do_not_flatten_model:
p_in.append(self._new_params[offset:offset+nels].view_as(p))
p_out.append(p)
else:
p.set_(self._new_params[offset:offset+nels].view_as(p))
param_i += 1
if self._e5m2_allgather:
multi_tensor_applier(
fused_adam_cuda.maybe_cast_mt,
self._overflow_buf,
[p_in, p_out]);
elif self._do_not_flatten_model:
multi_tensor_applier(
amp_C.multi_tensor_scale,
self._overflow_buf,
[p_in, p_out],
1.0);
if not self._e5m2_allgather and not self._do_not_flatten_model:
self._new_params = None
multi_tensor_applier(
fused_adam_cuda.maybe_cast_mt,
self._overflow_buf,
self._packed_flat_to_model_params)
torch.cuda.current_stream().wait_stream(self._completion_st)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment