Commit a60bbe63 authored by Thor Johnsen

Try out different partition scheme

parent 7da28fc3
@@ -68,67 +68,234 @@ class DistributedFusedAdamV2(torch.optim.Optimizer):
self._overflow_buf = torch.cuda.IntTensor([0])
assert (len(self.param_groups) == 1), "More than one parameter group is not supported."
# Way to revert a step
# 3 -> undo kernel + double buffer (debug, print norm of difference)
# 2 -> double buffer fp32 parameters
# 1 -> undo kernel
self._revert_method = revert_method
if self._revert_method > 1:
print("revert_method -> double buffer fp32 parameters, will consume more memory")
self._last_step = False
self._overlap_reductions = overlap_reductions
self._global_scale = None
self._num_blocks = dwu_num_blocks
self._num_chunks = dwu_num_chunks
self._predivide = predivide
self._e5m2_allgather = e5m2_allgather
self._do_not_flatten_model = do_not_flatten_model
self._full_pipeline = full_pipeline
self._compute_L2_grad_norm = compute_L2_grad_norm
self._L2_grad_norm = None
self._group_size = torch.cuda.device_count() if dwu_group_size <= 0 else dwu_group_size
self._world_size = torch.distributed.get_world_size()
self._num_groups = self._world_size // self._group_size
self._rank_in_group = torch.distributed.get_rank() % self._group_size
p_offset = 0
p_i = 0
self._param_state = None
self._model_params = []
self._grads_info = []
self._grad_accs = []
for group in self.param_groups:
self._param_group = group
prev = None
for p in group['params']:
torch.distributed.broadcast(p,0)
if not p.requires_grad:
continue
self._model_params.append(p)
state = self.state[p]
if len(state) == 0:
state['step'] = 0
if self._param_state is None:
self._param_state = state
p_grads_size = p.numel()
def wrapper(param, param_i, param_grads_size, param_offset):
param_tmp = param.expand_as(param)
grad_acc = param_tmp.grad_fn.next_functions[0][0]
def allreduce_hook(*unused):
self._do_overlapped_reduction(param_i, param_grads_size, param_offset, param)
grad_acc.register_hook(allreduce_hook)
self._grad_accs.append(grad_acc)
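# Note: the hook is registered on the gradient accumulator node rather than on the parameter,
# so it fires once the .grad attribute for this parameter has been fully accumulated.
# A minimal standalone sketch of the same pattern (illustrative only, not part of this class):
#   p = torch.randn(4, requires_grad=True, device='cuda')
#   p_tmp = p.expand_as(p)                          # graph node whose input is the accumulator
#   grad_acc = p_tmp.grad_fn.next_functions[0][0]   # AccumulateGrad node for p
#   grad_acc.register_hook(lambda *unused: print("p.grad ready:", p.grad.norm()))
#   (p * 2).sum().backward()                        # hook fires after p.grad is populated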
self._grads_info.append({"param_grads_size":p_grads_size, "param_offset":p_offset})
wrapper(p, p_i, p_grads_size, p_offset)
p_offset += p_grads_size
# Only enforce 128b alignment (64 * fp16) for non-consecutive parameters
# RNN is one example of consecutive parameters:
# (weight_ih, weight_hh, bias_ih, bias_hh)
if prev is not None and (prev.data_ptr() + prev.numel() * prev.element_size() != p.data_ptr()):
p_offset = ((p_offset + 63) // 64) * 64
prev = p
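# Example of the alignment arithmetic above (hypothetical offset): a parameter ending at
# p_offset=1000 that is not memory-contiguous with the next one is padded to
# ((1000 + 63) // 64) * 64 = 1024 fp16 elements, i.e. the next 128-byte boundary, while
# consecutive RNN weights (weight_ih, weight_hh, bias_ih, bias_hh) keep packing unpadded.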
p_i += 1
self._grads_generated = [False]*len(self._grads_info)
self._flat_mt = flat_mt
self._grads = []
if self._overlap_reductions:
self._current_block = self._num_blocks
self._net_total_param_size = p_offset
self._total_param_size = p_offset
dwu_min_page_size = 256 * self._num_blocks * self._num_chunks * self._group_size
self._total_param_size = ((self._total_param_size + dwu_min_page_size - 1) // dwu_min_page_size) * dwu_min_page_size
self._block_size = self._total_param_size // self._num_blocks
self._shard_size = self._block_size // self._group_size
self._chunk_size = self._shard_size // self._num_chunks
print("self._net_total_param_size=%d, self._total_param_size=%d, dwu_min_page_size=%d, self._block_size=%d, self._shard_size=%d, self._chunk_size=%d" % (self._net_total_param_size, self._total_param_size,dwu_min_page_size,self._block_size,self._shard_size,self._chunk_size))
self._low_param_i = [0]*self._num_blocks
for block_id in range(self._num_blocks-1,-1,-1):
p_i = len(self._grads_info)-1
while p_i > 0 and self._grads_info[p_i]["param_offset"] > block_id*self._block_size:
p_i -= 1
self._low_param_i[block_id] = p_i
print(self._low_param_i)
self._flat_grads = torch.zeros([self._total_param_size], dtype=torch.float16, device='cuda')
self._new_params = torch.zeros([self._total_param_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
self._mega_shard_size = self._num_blocks * self._num_chunks * self._chunk_size
self._fp32_p = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
self._fp32_m = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
self._fp32_v = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
# FIXME: Rethink fp16 label since it's either uint8 or fp16
self._fp16_p = torch.zeros([self._mega_shard_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
self._fp16_g = torch.zeros([self._mega_shard_size], dtype=torch.float16, device='cuda')
self._individual_flat_grads = []
for p_i, (grads_info, p) in enumerate(zip(self._grads_info, self._model_params)):
self._individual_flat_grads.append(self._flat_grads[grads_info["param_offset"]:grads_info["param_offset"]+grads_info["param_grads_size"]].view_as(p))
def _flat_split(p):
def __blockify(p):
return [p[block_id*self._block_size:(block_id+1)*self._block_size] for block_id in range(self._num_blocks)]
def __shardify(p):
return [p[shard_id*self._shard_size:(shard_id+1)*self._shard_size] for shard_id in range(self._group_size)]
def __chunkify(p):
return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)]
list_of_blocks = __blockify(self._flat_grads)
list_of_list_of_shards = [__shardify(block) for block in list_of_blocks]
list_of_list_of_list_of_chunks = [[__chunkify(shard) for shard in shards] for shards in list_of_list_of_shards]
return list_of_blocks, list_of_list_of_shards, list_of_list_of_list_of_chunks
self._flat_grads_blocks, self._flat_grads_shards, self._flat_grads_chunks = _flat_split(self._flat_grads)
def _full_packed_split(p):
def __shardify(p):
return [p[mega_shard*self._mega_shard_size:(mega_shard+1)*self._mega_shard_size] for mega_shard in range(self._group_size)]
def __blockify(p):
return [p[block_id*self._num_chunks*self._chunk_size:(block_id+1)*self._num_chunks*self._chunk_size] for block_id in range(self._num_blocks)]
def __chunkify(p):
return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)]
list_of_mega_shards = __shardify(p)
list_of_list_of_mega_blocks = [__blockify(mega_shard) for mega_shard in list_of_mega_shards]
list_of_list_of_list_of_mega_chunks = [[__chunkify(mega_block) for mega_block in mega_blocks] for mega_blocks in list_of_list_of_mega_blocks]
return list_of_mega_shards, list_of_list_of_mega_blocks, list_of_list_of_list_of_mega_chunks
self._new_params_mega_shards, self._new_params_mega_blocks, self._new_params_mega_chunks = _full_packed_split(self._new_params)
def _packed_split(p):
def __packed_blockify(p):
packed_block_size = self._num_chunks*self._chunk_size
return [p[block_id*packed_block_size:(block_id+1)*packed_block_size] for block_id in range(self._num_blocks)]
def __packed_chunkify(p):
return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)]
list_of_blocks = __packed_blockify(p)
list_of_list_of_chunks = [__packed_chunkify(block) for block in list_of_blocks]
return list_of_blocks, list_of_list_of_chunks
self._fp32_p_blocks, self._fp32_p_chunks = _packed_split(self._fp32_p)
self._fp32_m_blocks, self._fp32_m_chunks = _packed_split(self._fp32_m)
self._fp32_v_blocks, self._fp32_v_chunks = _packed_split(self._fp32_v)
self._fp16_p_blocks, self._fp16_p_chunks = _packed_split(self._fp16_p)
self._fp16_g_blocks, self._fp16_g_chunks = _packed_split(self._fp16_g)
# current arrangement
#
# self._flat_grads
# self._flat_grads_blocks [x self._num_blocks, self._block_size]
# self._flat_grads_chunks [x self._num_chunks, self._chunk_size]
# self._flat_grads_shards [x self._group_size, self._shard_size]
#
# self._new_params
# self._new_params_mega_shards [x self._group_size, self._num_blocks*self._num_chunks*self._shard_size]
# self._new_params_mega_blocks [x self._num_blocks, self._num_chunks*self._shard_size]
# self._new_params_mega_chunks [x self._num_chunks, self._shard_size]
#
# self._fp32_p
# self._fp32_p_blocks [x self._num_blocks, self._num_chunks*self._shard_size]
# self._fp32_p_chunks [x self._num_chunks, self._shard_size]
# each chunk contains one shard
# same for self._fp32_m, self._fp32_v, self._fp16_p and self._fp16_g
#
# Usage:
#
# for chunk_id in range(self._num_chunks):
# works[chunk_id] = torch.distributed.reduce_scatter(self._flat_grads_chunks[block_id][chunk_id], self._fp16_g_chunks[block_id][chunk_id], ...)
#
# ----------------------------------------------------------------------------------------
#
# new arrangement
#
# NB! New equations for self._shard_size and self._chunk_size
#
# self._flat_grads
# self._flat_grads_blocks [x self._num_blocks, self._block_size]
# self._flat_grads_shards [x self._group_size, self._shard_size]
# self._flat_grads_chunks [x self._num_chunks, self._chunk_size]
#
# self._new_params
# self._new_params_mega_shards [x self._group_size, self._num_blocks*self._num_chunks*self._chunk_size]
# self._new_params_mega_blocks [x self._num_blocks, self._num_chunks*self._chunk_size]
# self._new_params_mega_chunks [x self._num_chunks, self._chunk_size]
#
# self._fp32_p
# self._fp32_p_blocks [x self._num_blocks, self._num_chunks*self._chunk_size]
# self._fp32_p_chunks [x self._num_chunks, self._chunk_size]
# same for self._fp32_m, self._fp32_v, self._fp16_p and self._fp16_g
#
# Usage:
#
# work = torch.distributed.reduce_scatter(self._flat_grads_blocks[block_id], self._fp16_g[block_id], ...)
# for chunk_id in range(self._num_chunks):
# work.wait()
# works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id], ...)
# or
# work.wait()
# works[0] = torch.distributed.all_reduce(self._fp16_g_blocks[block_id], ...)
#
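# Worked example of the new partition equations (hypothetical sizes, illustration only):
# with self._total_param_size=8192, self._num_blocks=4, self._group_size=2, self._num_chunks=2:
#   self._block_size      = 8192 // 4 = 2048      (per block)
#   self._shard_size      = 2048 // 2 = 1024      (per rank within a block)
#   self._chunk_size      = 1024 // 2 = 512       (per chunk within a shard)
#   self._mega_shard_size = 4 * 2 * 512 = 4096    (per-rank packed buffer covering all blocks)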
# This paragraph does two things:
# 1) Copy model parameters into master buffer
# 2) Create tensor lists for unpacking new parameter tensor after all-gather
self._packed_flat_to_model_params = []
for shard_id in range(self._group_size):
for block_id in range(self._num_blocks):
flat_shard_start = (block_id * self._group_size + shard_id) * self._shard_size
flat_shard_end = flat_shard_start + self._shard_size
for p, grads_info in zip(self._model_params, self._grads_info):
flat_grad_start = grads_info["param_offset"]
flat_grad_end = flat_grad_start + grads_info["param_grads_size"]
clipped_start = (lambda a,b: a if a > b else b)(flat_grad_start, flat_shard_start)
clipped_end = (lambda a,b: a if a < b else b)(flat_grad_end, flat_shard_end)
if clipped_start < clipped_end:
grad_offset = clipped_start - flat_grad_start
grad_length = clipped_end - clipped_start
shard_offset = clipped_start - flat_shard_start
model_param_fragment = p.view(-1)[grad_offset:grad_offset+grad_length]
new_param_packed_fragment = self._new_params_mega_blocks[shard_id][block_id][shard_offset:shard_offset+grad_length]
self._packed_flat_to_model_params.append( (new_param_packed_fragment, model_param_fragment) )
if shard_id == self._rank_in_group:
# copy model parameters into master buffer
master_param_fragment = self._fp32_p_blocks[block_id][shard_offset:shard_offset+grad_length]
print("model_param_fragment.size()=%s, new_param_packed_fragment.size()=%s, master_param_fragment.size()=%s" % (str(model_param_fragment.size()), str(new_param_packed_fragment.size()), str(master_param_fragment.size())))
master_param_fragment.copy_(model_param_fragment)
p_in, p_out = zip(*self._packed_flat_to_model_params)
self._packed_flat_to_model_params = [p_in, p_out]
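# Example of the clipping above (hypothetical offsets): a parameter occupying flat offsets
# [900, 1300) intersected with a shard covering [1024, 2048) gives
#   clipped_start = max(900, 1024) = 1024, clipped_end = min(1300, 2048) = 1300,
# so grad_offset = 124, grad_length = 276, shard_offset = 0; only that 276-element fragment
# of the parameter is mapped into this rank's packed buffers.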
self._distributed_weight_update = distributed_weight_update # Is this still needed?
self._num_rs_pg = dwu_num_rs_pg
self._num_ar_pg = dwu_num_ar_pg
self._num_ag_pg = dwu_num_ag_pg
if self._num_groups > 1:
self._ar_pg = []
for dev_i in range(self._group_size):
ranks = [dev_i+j*self._group_size for j in range(self._num_groups)]
@@ -136,10 +303,9 @@ class DistributedFusedAdamV2(torch.optim.Optimizer):
grp = torch.distributed.new_group(ranks=ranks)
if torch.distributed.get_rank() in ranks:
self._ar_pg.append(grp)
self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]
for ar_pg in self._ar_pg:
torch.distributed.all_reduce(self._overflow_buf,group=ar_pg)
rs_ranks = []
for group_i in range(self._num_groups):
rs_ranks.append([group_i*self._group_size+j for j in range(self._group_size)])
@@ -150,22 +316,39 @@ class DistributedFusedAdamV2(torch.optim.Optimizer):
grp = torch.distributed.new_group(ranks=ranks)
if torch.distributed.get_rank() in ranks:
self._rs_pg.append(grp)
if self._compute_L2_grad_norm and torch.distributed.get_rank() in ranks:
self._l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)
torch.distributed.all_reduce(self._overflow_buf,group=self._l2_grad_norm_pg)
self._rs_st = [torch.cuda.Stream() for _ in range(self._num_rs_pg)]
for rs_pg in self._rs_pg:
torch.distributed.all_reduce(self._overflow_buf,group=rs_pg)
if self._num_ag_pg == 0:
self._ag_pg = self._rs_pg
self._ag_st = self._rs_st
self._num_ag_pg = self._num_rs_pg
else:
self._ag_pg = []
for group_i in range(self._num_groups):
ranks = rs_ranks[group_i]
for i in range(self._num_ag_pg):
grp = torch.distributed.new_group(ranks=ranks)
if torch.distributed.get_rank() in ranks:
self._ag_pg.append(grp)
self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)]
for ag_pg in self._ag_pg:
torch.distributed.all_reduce(self._overflow_buf,group=ag_pg)
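# Sketch of the process-group layout assumed above (hypothetical 2 nodes x 4 GPUs, so
# self._group_size=4 and self._num_groups=2):
#   reduce-scatter / all-gather groups (intra-node): [0,1,2,3] and [4,5,6,7]
#   all-reduce groups (inter-node, one per local rank): [0,4], [1,5], [2,6], [3,7]
# Each layout is instantiated self._num_rs_pg / self._num_ar_pg / self._num_ag_pg times so
# different blocks and chunks can run on distinct communicators and overlap.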
self._l2_grad_norm_st = torch.cuda.Stream() if self._compute_L2_grad_norm else None
self._completion_st = torch.cuda.Stream()
self._reductions_works = [None]*self._num_blocks
self._allgather_works = [None]*self._num_blocks
import inspect
assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), "This version of c10d does not support no_copy option"
def set_last_step(self, last_step):
self._last_step = last_step
def _get_flush_block(self):
flush_block = []
@@ -187,77 +370,112 @@ class DistributedFusedAdamV2(torch.optim.Optimizer):
return flush_block
def _pipeline_block_reductions(self, block_id):
self._flatten_grad_mt(1.0/self._world_size if self._predivide else 1.0)
# Reduction within each node
# Changes gradient format from [block * chunk * shard] to [shard * block * chunk]
# The output format is the same as the fp32 master parameters
works = [None]*self._num_chunks
rs_stream = self._rs_st[block_id%self._num_rs_pg]
rs_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(rs_stream):
rs_work = torch.distributed.reduce_scatter(self._fp16_g_blocks[block_id],self._flat_grads_shards[block_id],group=self._rs_pg[block_id%self._num_rs_pg],async_op=True,no_copy=True)
for chunk_id in range(self._num_chunks):
works[chunk_id] = rs_work
# Reduction across nodes for each rank
if self._num_groups > 1:
for chunk_id in range(self._num_chunks):
glob_chunk_id = block_id * self._num_chunks + chunk_id
ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
with torch.cuda.stream(ar_stream):
rs_work.wait()
works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True)
self._reductions_works[block_id] = works
# Optionally compute L2 grad norm
if self._compute_L2_grad_norm and block_id == 0:
with torch.cuda.stream(self._l2_grad_norm_st):
for block_id in range(self._num_blocks):
for chunk_id in range(self._num_chunks):
self._reductions_works[block_id][chunk_id].wait()
# Since the packed format is contiguous after reductions, only one norm is needed
l2_grad_norm_sq = torch.empty([1], device='cuda')
l2_grad_norm_sq = self._fp16_g.norm(dtype=torch.float32, p=2)**2
torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg)
self._L2_grad_norm = l2_grad_norm_sq.sqrt()
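# Worked example of the sizes involved above (hypothetical: self._block_size=2048,
# self._group_size=4, self._num_chunks=2): the intra-node reduce_scatter consumes the four
# 512-element shards of this block from self._flat_grads_shards[block_id] and leaves this rank
# with its own reduced 512-element shard packed into self._fp16_g_blocks[block_id]; with more
# than one node, the two 256-element chunks of that shard are then all-reduced across nodes,
# each on its own stream, with the peer ranks that own the same shard.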
def __launch_step_kernel(self, p, p_copy, m, v, g):
combined_scale = self._global_scale
if self._param_group['max_grad_norm'] > 0 and math.isfinite(self.L2_grad_norm):
combined_scale = self._param_group['max_grad_norm'] / (self.L2_grad_norm / self._global_scale + 1e-6)
combined_scale = self._global_scale / min(1, combined_scale)
bias_correction = 1 if self._param_group['bias_correction'] else 0
beta1, beta2 = self._param_group['betas']
fused_adam_cuda.adam(
p, p_copy, m, v, g,
self._param_group['lr'],
beta1,
beta2,
self._param_group['eps'],
combined_scale,
self._param_state['step']+1,
self.eps_mode,
bias_correction,
self._param_group['weight_decay'])
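# Worked example of the clipping math above (hypothetical values): with self._global_scale=65536,
# max_grad_norm=1.0 and self.L2_grad_norm=262144 (an unscaled gradient norm of 262144/65536 = 4.0),
# combined_scale = 1.0 / (4.0 + 1e-6) ~ 0.25, then combined_scale = 65536 / min(1, 0.25) = 262144,
# so the kernel divides the fp16 gradients by four times the loss scale, clipping the unscaled
# gradient norm back to ~1.0.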
def _pipeline_block_step(self, block_id):
# Call step kernel once per block
ag_stream = self._ag_st[block_id%self._num_ag_pg]
with torch.cuda.stream(ag_stream):
for chunk_id in range(self._num_chunks):
self._reductions_works[block_id][chunk_id].wait()
self.__launch_step_kernel(
self._fp32_p_blocks[block_id],
self._fp16_p_blocks[block_id],
self._fp32_m_blocks[block_id],
self._fp32_v_blocks[block_id],
self._fp16_g_blocks[block_id])
# Call all-gather once per step.
# FIXME: Determine which is faster, one all-gather per block or a single all-gather at end
if block_id == 0:
for other_ag_stream in self._ag_st:
self._completion_st.wait_stream(other_ag_stream)
with torch.cuda.stream(self._completion_st):
torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)
def _pipeline_step(self):
# Call step kernel once per step
# Call all-gather once per step
with torch.cuda.stream(self._completion_st):
for block_id in range(self._num_blocks):
for chunk_id in range(self._num_chunks):
self._reductions_works[block_id][chunk_id].wait()
self.__launch_step_kernel(
self._fp32_p,
self._fp16_p,
self._fp32_m,
self._fp32_v,
self._fp16_g)
torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)
def _flatten_grad_mt(self, scale):
if self._flat_mt and len(self._grads) > 0:
self._overflow_buf.zero_()
multi_tensor_applier(
amp_C.multi_tensor_scale,
self._overflow_buf,
list(zip(*self._grads)),
scale)
self._grads = []
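# For example, if self._grads == [(g0, f0), (g1, f1)] then list(zip(*self._grads)) ==
# [(g0, g1), (f0, f1)]: a source list and a destination list, which is the layout
# amp_C.multi_tensor_scale expects for a fused scale-and-copy into the flat gradient buffer.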
def _do_overlapped_reduction(self, param_i, param_grads_size, param_offset, param):
# handle overlapped reductions
if self._flat_mt:
self._grads.append( (param.grad, self._individual_flat_grads[param_i]) )
else:
torch.div(param.grad, self._world_size if self._predivide else 1.0, out=self._individual_flat_grads[param_i])
self._grads_generated[param_i]=True
if not self._last_step:
if self._overlap_reductions:
@@ -315,130 +533,6 @@ class DistributedFusedAdamV2(torch.optim.Optimizer):
else:
return None
# Distributed weight update algorithm:
# Model parameters are kept as-is.
# Gradients are flattened during backprop.
# Reductions are done with an intra-node reduce-scatter followed by an inter-node all-reduce.
# Step function is sharded and the shards are assembled with an intra-node all-gather.
# Sharded step function needs internal fp32 buffers for p, m and v.
# To save memory, we allocate the fp32 buffers to cover only the shards local GPU will update.
# This means we have to play around with indexes, which requires knowledge of block and shard number.
# Implement a method that performs a partial update of a single shard within a single block.
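# A condensed sketch of that flow for one block (illustrative only; the process-group names
# here are hypothetical, the real groups are created in the constructor above):
#   torch.distributed.reduce(grad_block, dst=active_rank, group=intra_node_pg)
#   if num_groups > 1 and rank == active_rank:
#       torch.distributed.all_reduce(grad_block, group=inter_node_pg)
#   if rank == active_rank:
#       self._partial_step_single_shard(block_id)      # sharded Adam step on local fp32 buffers
#   torch.distributed.all_gather(new_params_blocks, new_params_blocks[rank_in_group], group=intra_node_pg)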
def _partial_step_single_shard(self, block_id, undo=False):
"""Perform step function for a single shard.
Arguments:
block_id (integer): Block index of shard [0,self._group_size>
undo (boolean, optional): If True, undo effect of previously called partial step.
"""
block_start = block_id * self._block_size
block_end = block_start + self._block_size
if self._fp32_p is None:
assert (not undo), "Tried to undo step before calling step."
# Allocate fp32 buffers on demand. Note that we don't make these part of the state
# since each rank only has partial buffers.
# To-Do:
self._fp32_p = torch.zeros([self._block_size]).float().cuda()
self._fp32_m = torch.zeros([self._block_size]).float().cuda()
self._fp32_v = torch.zeros([self._block_size]).float().cuda()
self._copy_to_fp32 = True
step = None
param_i = 0
for group in self.param_groups:
# compute combined scale factor for this group
combined_scale = self._global_scale
if group['max_grad_norm'] > 0 and math.isfinite(self.L2_grad_norm):
combined_scale = group['max_grad_norm'] / (self.L2_grad_norm / self._global_scale + 1e-6)
combined_scale = self._global_scale / min(1, combined_scale)
bias_correction = 1 if group['bias_correction'] else 0
group_start = -1
group_end = -2
for p in group['params']:
if not p.requires_grad:
continue
#if p.grad.is_sparse:
# raise RuntimeError('FusedAdam does not support sparse gradients, please consider SparseAdam instead')
state = self.state[p]
if len(state) == 0:
state['step'] = 0
if step is None:
# all we want from state at this point is state['step'], which should be the same for all p
step = state['step']
nels = p.numel()
offset = self._grads_info[param_i]['param_offset']
param_i += 1
start = offset
end = start + nels
clipped_start = start if start >= block_start else block_start
clipped_end = end if end <= block_end else block_end
# check if this parameter contributes to block
if clipped_start < clipped_end:
if group_start < 0:
group_start = clipped_start
group_end = clipped_end
if self._copy_to_fp32:
param_offset = clipped_start - block_start
param_size = clipped_end - clipped_start
buffer_start = param_offset
buffer_end = buffer_start + param_size
param_start = (clipped_start - start)
param_end = param_start + param_size
#assert (buffer_start >= 0 and buffer_end <= self._fp32_p.numel() and param_start >= 0 and param_end <= p.numel()), "Illegal copy"
self._fp32_p[buffer_start:buffer_end].copy_(p.view(-1)[param_start:param_end].float())
group_size = group_end - group_start
if group_size > 0:
assert (step is not None), "state['step'] is None for this parameter group"
group_offset = group_start - block_start
group_block_start = block_start + group_offset
group_block_end = group_block_start + group_size
group_buffer_start = group_offset
group_buffer_end = group_buffer_start + group_size
beta1, beta2 = group['betas']
if undo:
fused_adam_cuda.maybe_adam_undo(
torch.empty([0]),
self._fp32_p[group_buffer_start:group_buffer_end],
self._fp32_m[group_buffer_start:group_buffer_end],
self._fp32_v[group_buffer_start:group_buffer_end],
self._flat_grads[group_block_start:group_block_end],
group['lr'],
beta1,
beta2,
group['eps'],
combined_scale,
step+1, # FIXME: Verify this should be step+1
self.eps_mode,
bias_correction,
group['weight_decay'])
else:
fused_adam_cuda.adam(
self._fp32_p[group_buffer_start:group_buffer_end],
self._new_params[group_block_start:group_block_end],
self._fp32_m[group_buffer_start:group_buffer_end],
self._fp32_v[group_buffer_start:group_buffer_end],
self._flat_grads[group_block_start:group_block_end],
group['lr'],
beta1,
beta2,
group['eps'],
combined_scale,
step+1,
self.eps_mode,
bias_correction,
group['weight_decay'])
def complete_reductions(self): def complete_reductions(self):
"""Complete reductions if full pipeline is not selected or overlap is not allowed. """Complete reductions if full pipeline is not selected or overlap is not allowed.
""" """
@@ -455,61 +549,73 @@ class DistributedFusedAdamV2(torch.optim.Optimizer):
if self._last_step or not self._overlap_reductions:
# nothing done so far, run full pipeline after reductions
for block_id in range(self._num_blocks-1,-1,-1):
self._pipeline_block_reductions(block_id)
if self._compute_L2_grad_norm:
torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)
self._current_block = self._num_blocks
self._grads_generated = [False]*len(self._grads_info)
def revert_step(self): def revert_step(self):
"""Revert effect of previously calling partial_step. """Revert effect of previously calling partial_step.
""" """
self._partial_step_single_shard(self._rank_in_group, undo=True) # Call undo kernel once per step
combined_scale = self._global_scale
def step(self, closure=None): if self._param_group['max_grad_norm'] > 0 and math.isfinite(self.L2_grad_norm):
combined_scale = self._param_group['max_grad_norm'] / (self.L2_grad_norm / self._global_scale + 1e-6)
combined_scale = self._global_scale / min(1, combined_scale)
bias_correction = 1 if self._param_group['bias_correction'] else 0
beta1, beta2 = self._param_group['betas']
fused_adam_cuda.maybe_adam_undo(
torch.empty([0]),
self._fp32_p,
self._fp32_m,
self._fp32_v,
self._fp16_g,
self._param_group['lr'],
beta1,
beta2,
self._param_group['eps'],
combined_scale,
self._param_state['step']+1,
self.eps_mode,
bias_correction,
self._param_group['weight_decay'])
def step(self, closure=None, skip_overflow_check=False):
loss = None
if closure is not None:
loss = closure()
if self._last_step or not self._overlap_reductions or not self._full_pipeline:
self._pipeline_step()
with torch.cuda.stream(self._completion_st):
# Check for overflow
# Store state for loss scaler calculation
if skip_overflow_check:
has_overflow = False
else:
self.strided_check_finite(self._new_params, stride=self._shard_size, start=0, end=self._net_total_param_size)
has_overflow = self.peek_overflow
if has_overflow:
print("Reverting step")
self.revert_step()
else:
# Copy self._new_params to model params
for p in self._model_params: self.state[p]['step'] += 1
multi_tensor_applier(
fused_adam_cuda.maybe_cast_mt,
self._overflow_buf,
self._packed_flat_to_model_params)
torch.cuda.current_stream().wait_stream(self._completion_st)
self._reductions_works = [None]*self._num_blocks
self._allgather_works = [None]*self._num_blocks
return loss