"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "67f5cce2943e0f32f4e0c6b53177dc3107a7955f"
Commit 3f717d95 authored by Thor Johnsen

Bug fix in internal pipelining

parent 17160f34
@@ -46,7 +46,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                  compute_L2_grad_norm=False, distributed_weight_update=0,
                  dwu_group_size=0, dwu_num_blocks=4, dwu_num_rs_pg=1, dwu_num_ar_pg=4,
                  dwu_num_ag_pg=0, dwu_num_blk_st=1, revert_method=1, flat_mt=False,
-                 dwu_num_chunks=4):
+                 dwu_num_chunks=4, predivide=True):
         global fused_adam_cuda
         fused_adam_cuda = importlib.import_module("fused_adam_cuda")
@@ -78,6 +78,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         self._global_scale = None
         self._num_blocks = dwu_num_blocks
         self._num_chunks = dwu_num_chunks
+        self._predivide = predivide
         self._full_pipeline = full_pipeline
         self._compute_L2_grad_norm = compute_L2_grad_norm
         self._L2_grad_norm = torch.zeros([]).cuda() if self._compute_L2_grad_norm else None
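For context (not part of the diff): the new `predivide` flag is consumed in `_do_overlapped_reduction` below, where each incoming gradient is divided by the world size before being copied into the flat gradient buffer, so that a later sum-reduction yields the averaged gradient. A minimal sketch of that behaviour, with a hypothetical `copy_grad_to_flat_buffer` helper standing in for the real method:

```python
import torch

def copy_grad_to_flat_buffer(grad, flat_grads, offset, numel, world_size, predivide=True):
    # Hypothetical helper mirroring _do_overlapped_reduction: flatten the gradient
    # and copy it into the flat buffer, pre-dividing by the world size only when
    # predivide is enabled (otherwise the raw gradient is copied as-is).
    divisor = world_size if predivide else 1.0
    torch.div(grad.view(-1), divisor, out=flat_grads[offset:offset + numel])
```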
@@ -160,7 +161,6 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         self._blk_st = []
         for i in range(self._num_blk_st):
             self._blk_st.append(torch.cuda.Stream())
-        self._works = []
         import inspect
         if 'no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args:
@@ -197,19 +197,20 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         end = start + self._block_size
         grad_block = flat_grads[start:end]
         grad_shards = [grad_block[i*self._shard_size:(i+1)*self._shard_size] for i in range(self._group_size)]
         if self._pg_supports_no_copy:
             work = torch.distributed.reduce_scatter(grad_shards[self._rank_in_group],grad_shards,group=self._rs_pg[block_id%len(self._rs_pg)],async_op=True,no_copy=True)
         else:
             work = torch.distributed.reduce_scatter(grad_shards[self._rank_in_group],grad_shards,group=self._rs_pg[block_id%len(self._rs_pg)],async_op=True)
         works = [work]
         if self._num_groups > 1:
-            sliver_size = self._shard_size // self._num_chunks
-            assert ((sliver_size*self._num_chunks) == self._shard_size), "Shard size not a multiple of dwu_num_chunks"
-            works = []
             work.wait()
+            works = []
+            chunk_size = self._shard_size // self._num_chunks
             for i in range(self._num_chunks):
-                works.append( torch.distributed.all_reduce(grad_shards[self._rank_in_group][i*sliver_size:(i+1)*sliver_size],group=self._ar_pg[i%len(self._ar_pg)],async_op=True) )
+                chunks = [grad_shards[j][i*chunk_size:(i+1)*chunk_size] for j in range(self._group_size)]
+                work = torch.distributed.all_reduce(chunks[self._rank_in_group],group=self._ar_pg[i%len(self._ar_pg)],async_op=True)
+                works.append(work)
         if self._compute_L2_grad_norm:
             with torch.cuda.stream(self._blk_st[0]):
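The rewritten chunk loop keeps the same two-level reduction as before: a reduce-scatter inside each group so every rank owns one shard, followed by chunked all-reduces of that shard across groups. A simplified, hedged sketch of the pattern (the process groups, shard and group sizes are assumed to be set up elsewhere):

```python
import torch.distributed as dist

def two_level_reduce(grad_block, shard_size, group_size, rank_in_group,
                     num_chunks, rs_group, ar_groups):
    # 1) Reduce-scatter the block across the ranks of one group: afterwards this
    #    rank holds the reduced values of its own shard only.
    shards = [grad_block[i * shard_size:(i + 1) * shard_size] for i in range(group_size)]
    dist.reduce_scatter(shards[rank_in_group], shards, group=rs_group)

    # 2) All-reduce that shard across groups in num_chunks slices, spreading the
    #    chunks over several process groups so they can proceed concurrently.
    chunk_size = shard_size // num_chunks
    works = []
    for i in range(num_chunks):
        chunk = shards[rank_in_group][i * chunk_size:(i + 1) * chunk_size]
        works.append(dist.all_reduce(chunk, group=ar_groups[i % len(ar_groups)], async_op=True))
    for work in works:
        work.wait()
```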
@@ -224,7 +225,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                 torch.distributed.all_reduce(self._L2_grad_norm,group=self._rs_pg[0])
                 self._L2_grad_norm.sqrt_()
-        return works
+        for work in works:
+            work.wait()

     # NB!
     # self._global_scale is used by this method.
@@ -234,21 +236,17 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         new_params_shards = [new_params[start+shard_i*self._shard_size:start+(shard_i+1)*self._shard_size] for shard_i in range(self._group_size)]
         self._partial_step_single_shard(block_id)
         if self._pg_supports_no_copy:
-            work = torch.distributed.all_gather(new_params_shards,new_params_shards[self._rank_in_group],group=self._ag_pg[block_id%len(self._ag_pg)],async_op=True,no_copy=True)
+            torch.distributed.all_gather(new_params_shards,new_params_shards[self._rank_in_group],group=self._ag_pg[block_id%len(self._ag_pg)],no_copy=True)
         else:
-            work = torch.distributed.all_gather(new_params_shards,new_params_shards[self._rank_in_group],group=self._ag_pg[block_id%len(self._ag_pg)],async_op=True)
-        return work
+            torch.distributed.all_gather(new_params_shards,new_params_shards[self._rank_in_group],group=self._ag_pg[block_id%len(self._ag_pg)])

     def _pipeline_block(self, block_id, flat_grads, new_params):
-        works = self._pipeline_block_reductions(block_id, flat_grads)
-        for work in works:
-            if work is not None:
-                work.wait()
-        return self._pipeline_block_step(block_id, flat_grads, new_params)
+        self._pipeline_block_reductions(block_id, flat_grads)
+        self._pipeline_block_step(block_id, flat_grads, new_params)

     def _do_overlapped_reduction(self, param_i, param_grads_size, param_offset, grad):
         # handle overlapped reductions
-        torch.div(grad.view(-1), self._world_size, out=self._flat_grads[param_offset:param_offset+param_grads_size])
+        torch.div(grad.view(-1), self._world_size if self._predivide else 1.0, out=self._flat_grads[param_offset:param_offset+param_grads_size])
         self._grads_generated[param_i]=True
         if not self._last_step:
             if self._overlap_reductions:
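After this change `_pipeline_block_step` no longer returns a work handle: the all-gather is issued with the default `async_op=False`, so nothing has to be tracked and waited on by the caller. A hedged sketch of that blocking all-gather with hypothetical names (`gather_new_params`, `ag_group`):

```python
import torch.distributed as dist

def gather_new_params(new_params, block_start, shard_size, group_size,
                      rank_in_group, ag_group):
    # Every rank contributes its freshly updated shard and receives the shards of
    # all other ranks in the group, reassembling the full parameter block in place.
    shards = [new_params[block_start + i * shard_size:block_start + (i + 1) * shard_size]
              for i in range(group_size)]
    # async_op defaults to False, so no work handle needs to be kept around.
    dist.all_gather(shards, shards[rank_in_group], group=ag_group)
```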
@@ -260,20 +258,12 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                     if self._full_pipeline:
                         if self._new_params is None:
                             self._new_params = torch.zeros_like(self._flat_grads)
-                        work = self._pipeline_block(block_id, self._flat_grads, self._new_params)
-                        self._works.append(work)
+                        self._pipeline_block(block_id, self._flat_grads, self._new_params)
                     else:
-                        works = self._pipeline_block_reductions(block_id, self._flat_grads)
-                        self._works += works
+                        self._pipeline_block_reductions(block_id, self._flat_grads)
                     flush_block = self._get_flush_block()

-    def _wait_works(self):
-        for work in self._works:
-            if work is not None:
-                work.wait()
-        self._works = []
-
     def set_global_scale(self, global_scale):
         """Set global scale.
         """
@@ -457,7 +447,6 @@ class DistributedFusedAdam(torch.optim.Optimizer):
     def complete_reductions(self):
         """Complete reductions if full pipeline is not selected or overlap is not allowed.
         """
-        self._wait_works()
         if self._last_step:
             # zero out gradients that have not been completed yet
@@ -475,8 +464,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
             block_id = self._num_blocks - inv_block_id - 1
             self._blk_st[block_id%len(self._blk_st)].wait_stream(torch.cuda.current_stream())
             with torch.cuda.stream(self._blk_st[block_id%len(self._blk_st)]):
-                works = self._pipeline_block_reductions(block_id, self._flat_grads)
-                self._works += works
+                self._pipeline_block_reductions(block_id, self._flat_grads)
         self._copy_to_fp32 = False
         self._decomp_stats = None
@@ -486,7 +474,6 @@ class DistributedFusedAdam(torch.optim.Optimizer):
     def revert_step(self):
         """Revert effect of previously calling partial_step.
         """
-        self._wait_works()
         for block_id in range(self._num_blocks):
             self._partial_step_single_shard(block_id, undo=True)
@@ -500,37 +487,38 @@ class DistributedFusedAdam(torch.optim.Optimizer):
             self._new_params = torch.zeros_like(self._flat_grads)
         for inv_block_id in range(self._num_blocks):
             block_id = self._num_blocks - inv_block_id - 1
+            self._blk_st[block_id%len(self._blk_st)].wait_stream(torch.cuda.current_stream())
             with torch.cuda.stream(self._blk_st[block_id%len(self._blk_st)]):
-                work = self._pipeline_block_step(block_id, self._flat_grads, self._new_params)
-                self._works.append(work)
-        with torch.cuda.stream(self._blk_st[0]):
-            self._wait_works()
-
-            # Check for overflow
-            # Store state for loss scaler calculation
-            self.strided_check_finite(self._new_params, stride=self._shard_size, start=0, end=self._net_total_param_size)
-            if self.peek_overflow:
-                print("Reverting step")
-                self.revert_step()
-            else:
-                # Copy self._new_params to model params
-                with torch.no_grad():
-                    param_i = 0
-                    for group in self.param_groups:
-                        for p in group['params']:
-                            if not p.requires_grad:
-                                continue
-                            state = self.state[p]
-                            if len(state) == 0:
-                                state['step'] = 0
-                            state['step'] += 1
-                            nels = p.numel()
-                            offset = self._grads_info[param_i]['param_offset']
-                            p.set_(self._new_params[offset:offset+nels].view_as(p))
-                            param_i += 1
-            self._new_params = None
-        torch.cuda.current_stream().wait_stream(self._blk_st[0])
+                self._pipeline_block_step(block_id, self._flat_grads, self._new_params)
+
+        for i, blk_st in enumerate(self._blk_st):
+            torch.cuda.current_stream().wait_stream(blk_st)
+
+        # Check for overflow
+        # Store state for loss scaler calculation
+        self.strided_check_finite(self._new_params, stride=self._shard_size, start=0, end=self._net_total_param_size)
+        if self.peek_overflow:
+            print("Reverting step")
+            self.revert_step()
+        else:
+            # Copy self._new_params to model params
+            with torch.no_grad():
+                param_i = 0
+                for group in self.param_groups:
+                    for p in group['params']:
+                        if not p.requires_grad:
+                            continue
+                        state = self.state[p]
+                        if len(state) == 0:
+                            state['step'] = 0
+                        state['step'] += 1
+                        nels = p.numel()
+                        offset = self._grads_info[param_i]['param_offset']
+                        p.set_(self._new_params[offset:offset+nels].view_as(p))
+                        param_i += 1
+        self._new_params = None

         return loss
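The rewritten `step` relies on the usual CUDA stream handshake: each block stream first waits on the current stream, the per-block work is launched onto it, and the current stream then waits on every block stream before the overflow check reads the results. An illustrative sketch of that pattern (not this optimizer's code; `launch_fn` is a placeholder):

```python
import torch

def run_on_block_streams(block_ids, launch_fn, streams):
    # Fan work out over several CUDA side streams and join back on the current
    # stream before any result is consumed on it.
    current = torch.cuda.current_stream()
    for block_id in block_ids:
        stream = streams[block_id % len(streams)]
        stream.wait_stream(current)      # side stream sees all prior work
        with torch.cuda.stream(stream):
            launch_fn(block_id)          # e.g. one _pipeline_block_step call
    for stream in streams:
        current.wait_stream(stream)      # results are now safe to read here
```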