"git@developer.sourcefind.cn:ox696c/ktransformers.git" did not exist on "18c42e67df28bb6c7f5dc847595637327919e5ea"
Commit 3f717d95 authored by Thor Johnsen

Bug fix in internal pipelining

parent 17160f34
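This commit adds a predivide option to the constructor, makes the all-gather and pipelined step synchronous, and reworks the multi-group gradient reduction: the new code waits for the reduce-scatter before launching the per-chunk all-reduces, and explicitly waits on the returned handles instead of accumulating them in self._works. A minimal sketch of that chunked all-reduce pattern, assuming torch.distributed is already initialized; the names shard, num_chunks, and ar_groups are illustrative, not part of the optimizer's API:

import torch
import torch.distributed as dist

def chunked_all_reduce(shard, num_chunks, ar_groups):
    # Reduce a flat 1-D shard in num_chunks slices, spreading the collectives
    # over several process groups, then wait on every handle before the result
    # is consumed (mirrors the corrected _pipeline_block_reductions pattern).
    chunk_size = shard.numel() // num_chunks
    assert chunk_size * num_chunks == shard.numel(), "shard size must be a multiple of num_chunks"
    works = []
    for i in range(num_chunks):
        chunk = shard[i * chunk_size:(i + 1) * chunk_size]  # view into the shard
        work = dist.all_reduce(chunk, group=ar_groups[i % len(ar_groups)], async_op=True)
        works.append(work)
    for work in works:
        work.wait()
    return shard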
@@ -46,7 +46,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
compute_L2_grad_norm=False, distributed_weight_update=0,
dwu_group_size=0, dwu_num_blocks=4, dwu_num_rs_pg=1, dwu_num_ar_pg=4,
dwu_num_ag_pg=0, dwu_num_blk_st=1, revert_method=1, flat_mt=False,
dwu_num_chunks=4):
dwu_num_chunks=4, predivide=True):
global fused_adam_cuda
fused_adam_cuda = importlib.import_module("fused_adam_cuda")
@@ -78,6 +78,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
self._global_scale = None
self._num_blocks = dwu_num_blocks
self._num_chunks = dwu_num_chunks
self._predivide = predivide
self._full_pipeline = full_pipeline
self._compute_L2_grad_norm = compute_L2_grad_norm
self._L2_grad_norm = torch.zeros([]).cuda() if self._compute_L2_grad_norm else None
@@ -160,7 +161,6 @@ class DistributedFusedAdam(torch.optim.Optimizer):
self._blk_st = []
for i in range(self._num_blk_st):
self._blk_st.append(torch.cuda.Stream())
self._works = []
import inspect
if 'no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args:
@@ -197,19 +197,20 @@ class DistributedFusedAdam(torch.optim.Optimizer):
end = start + self._block_size
grad_block = flat_grads[start:end]
grad_shards = [grad_block[i*self._shard_size:(i+1)*self._shard_size] for i in range(self._group_size)]
if self._pg_supports_no_copy:
work = torch.distributed.reduce_scatter(grad_shards[self._rank_in_group],grad_shards,group=self._rs_pg[block_id%len(self._rs_pg)],async_op=True,no_copy=True)
else:
work = torch.distributed.reduce_scatter(grad_shards[self._rank_in_group],grad_shards,group=self._rs_pg[block_id%len(self._rs_pg)],async_op=True)
works = [work]
if self._num_groups > 1:
sliver_size = self._shard_size // self._num_chunks
assert ((sliver_size*self._num_chunks) == self._shard_size), "Shard size not a multiple of dwu_num_chunks"
works = []
work.wait()
works = []
chunk_size = self._shard_size // self._num_chunks
for i in range(self._num_chunks):
works.append( torch.distributed.all_reduce(grad_shards[self._rank_in_group][i*sliver_size:(i+1)*sliver_size],group=self._ar_pg[i%len(self._ar_pg)],async_op=True) )
chunks = [grad_shards[j][i*chunk_size:(i+1)*chunk_size] for j in range(self._group_size)]
work = torch.distributed.all_reduce(chunks[self._rank_in_group],group=self._ar_pg[i%len(self._ar_pg)],async_op=True)
works.append(work)
if self._compute_L2_grad_norm:
with torch.cuda.stream(self._blk_st[0]):
@@ -224,7 +225,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
torch.distributed.all_reduce(self._L2_grad_norm,group=self._rs_pg[0])
self._L2_grad_norm.sqrt_()
return works
for work in works:
work.wait()
# NB!
# self._global_scale is used by this method.
@@ -234,21 +236,17 @@ class DistributedFusedAdam(torch.optim.Optimizer):
new_params_shards = [new_params[start+shard_i*self._shard_size:start+(shard_i+1)*self._shard_size] for shard_i in range(self._group_size)]
self._partial_step_single_shard(block_id)
if self._pg_supports_no_copy:
work = torch.distributed.all_gather(new_params_shards,new_params_shards[self._rank_in_group],group=self._ag_pg[block_id%len(self._ag_pg)],async_op=True,no_copy=True)
torch.distributed.all_gather(new_params_shards,new_params_shards[self._rank_in_group],group=self._ag_pg[block_id%len(self._ag_pg)],no_copy=True)
else:
work = torch.distributed.all_gather(new_params_shards,new_params_shards[self._rank_in_group],group=self._ag_pg[block_id%len(self._ag_pg)],async_op=True)
return work
torch.distributed.all_gather(new_params_shards,new_params_shards[self._rank_in_group],group=self._ag_pg[block_id%len(self._ag_pg)])
def _pipeline_block(self, block_id, flat_grads, new_params):
works = self._pipeline_block_reductions(block_id, flat_grads)
for work in works:
if work is not None:
work.wait()
return self._pipeline_block_step(block_id, flat_grads, new_params)
self._pipeline_block_reductions(block_id, flat_grads)
self._pipeline_block_step(block_id, flat_grads, new_params)
def _do_overlapped_reduction(self, param_i, param_grads_size, param_offset, grad):
# handle overlapped reductions
torch.div(grad.view(-1), self._world_size, out=self._flat_grads[param_offset:param_offset+param_grads_size])
torch.div(grad.view(-1), self._world_size if self._predivide else 1.0, out=self._flat_grads[param_offset:param_offset+param_grads_size])
self._grads_generated[param_i]=True
if not self._last_step:
if self._overlap_reductions:
@@ -260,20 +258,12 @@ class DistributedFusedAdam(torch.optim.Optimizer):
if self._full_pipeline:
if self._new_params is None:
self._new_params = torch.zeros_like(self._flat_grads)
work = self._pipeline_block(block_id, self._flat_grads, self._new_params)
self._works.append(work)
self._pipeline_block(block_id, self._flat_grads, self._new_params)
else:
works = self._pipeline_block_reductions(block_id, self._flat_grads)
self._works += works
self._pipeline_block_reductions(block_id, self._flat_grads)
flush_block = self._get_flush_block()
def _wait_works(self):
for work in self._works:
if work is not None:
work.wait()
self._works = []
def set_global_scale(self, global_scale):
"""Set global scale.
"""
@@ -457,7 +447,6 @@ class DistributedFusedAdam(torch.optim.Optimizer):
def complete_reductions(self):
"""Complete reductions if full pipeline is not selected or overlap is not allowed.
"""
self._wait_works()
if self._last_step:
# zero out gradients that have not been completed yet
@@ -475,8 +464,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
block_id = self._num_blocks - inv_block_id - 1
self._blk_st[block_id%len(self._blk_st)].wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self._blk_st[block_id%len(self._blk_st)]):
works = self._pipeline_block_reductions(block_id, self._flat_grads)
self._works += works
self._pipeline_block_reductions(block_id, self._flat_grads)
self._copy_to_fp32 = False
self._decomp_stats = None
@@ -486,7 +474,6 @@ class DistributedFusedAdam(torch.optim.Optimizer):
def revert_step(self):
"""Revert effect of previously calling partial_step.
"""
self._wait_works()
for block_id in range(self._num_blocks):
self._partial_step_single_shard(block_id, undo=True)
@@ -500,37 +487,38 @@ class DistributedFusedAdam(torch.optim.Optimizer):
self._new_params = torch.zeros_like(self._flat_grads)
for inv_block_id in range(self._num_blocks):
block_id = self._num_blocks - inv_block_id - 1
self._blk_st[block_id%len(self._blk_st)].wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self._blk_st[block_id%len(self._blk_st)]):
work = self._pipeline_block_step(block_id, self._flat_grads, self._new_params)
self._works.append(work)
self._wait_works()
# Check for overflow
# Store state for loss scaler calculation
self.strided_check_finite(self._new_params, stride=self._shard_size, start=0, end=self._net_total_param_size)
if self.peek_overflow:
print("Reverting step")
self.revert_step()
else:
# Copy self._new_params to model params
with torch.no_grad():
param_i = 0
for group in self.param_groups:
for p in group['params']:
if not p.requires_grad:
continue
state = self.state[p]
if len(state) == 0:
state['step'] = 0
state['step'] += 1
nels = p.numel()
offset = self._grads_info[param_i]['param_offset']
p.set_(self._new_params[offset:offset+nels].view_as(p))
param_i += 1
self._new_params = None
self._pipeline_block_step(block_id, self._flat_grads, self._new_params)
with torch.cuda.stream(self._blk_st[0]):
for i, blk_st in enumerate(self._blk_st):
torch.cuda.current_stream().wait_stream(blk_st)
# Check for overflow
# Store state for loss scaler calculation
self.strided_check_finite(self._new_params, stride=self._shard_size, start=0, end=self._net_total_param_size)
if self.peek_overflow:
print("Reverting step")
self.revert_step()
else:
# Copy self._new_params to model params
with torch.no_grad():
param_i = 0
for group in self.param_groups:
for p in group['params']:
if not p.requires_grad:
continue
state = self.state[p]
if len(state) == 0:
state['step'] = 0
state['step'] += 1
nels = p.numel()
offset = self._grads_info[param_i]['param_offset']
p.set_(self._new_params[offset:offset+nels].view_as(p))
param_i += 1
self._new_params = None
torch.cuda.current_stream().wait_stream(self._blk_st[0])
return loss
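For the new predivide flag added in the first hunk: when it is enabled (the default), each incoming gradient is divided by the world size as it is copied into the flat gradient buffer, so the subsequent reduce-scatter/all-reduce yields an average rather than a sum; with predivide=False the gradient is copied unscaled. A rough sketch of that staging step, assuming a free-standing helper with illustrative names (flat_grads, offset) rather than the optimizer's actual _do_overlapped_reduction internals:

import torch

def stage_grad(flat_grads, grad, offset, world_size, predivide=True):
    # Copy one parameter's gradient into the flat buffer, optionally
    # pre-dividing by world_size so later collective reductions average
    # instead of sum (mirrors the predivide branch in the diff above).
    numel = grad.numel()
    divisor = float(world_size) if predivide else 1.0
    torch.div(grad.view(-1), divisor, out=flat_grads[offset:offset + numel])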