Commit 5c1cf020 authored by Thor Johnsen

Move partial_step out of complete_reductions

complete_reductions() now only finishes the gradient reductions: the blockwise _pipeline_block_step launches move into step(), the separate _do_compute_L2_grad_norm pass is removed in favor of accumulating the L2 gradient norm inside _pipeline_block_reductions, and a flat_mt flag is added to the constructor.

parent 3f4fb81f
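In short: before this change, `complete_reductions()` both finished the gradient reductions and launched the blockwise parameter update; afterwards the `_pipeline_block_step` launches live in `step()`, and the L2 gradient norm is accumulated on-device inside `_pipeline_block_reductions` instead of in a separate synchronous pass. A hedged sketch of the resulting call pattern from a training loop (`model`, `loader`, and `loss_fn` are hypothetical stand-ins, not part of this commit):

```python
# Illustrative training loop; only the optimizer calls reflect this commit.
optimizer = DistributedFusedAdam(model.parameters(),
                                 compute_L2_grad_norm=True,
                                 flat_mt=False)  # flag introduced by this commit

for batch in loader:
    loss_fn(model(batch)).backward()  # block reductions may overlap with backward
    optimizer.complete_reductions()   # now launches reductions (and the L2 norm) only
    optimizer.step()                  # now also launches the blockwise partial step
```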
@@ -45,7 +45,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                  amp_scale_adjustment=1.0, overlap_reductions=True, full_pipeline=True,
                  compute_L2_grad_norm=False, distributed_weight_update=0,
                  dwu_group_size=0, dwu_num_blocks=4, dwu_num_rs_pg=1, dwu_num_ar_pg=4,
-                 dwu_num_ag_pg=0, dwu_num_blk_st=1, revert_method=1):
+                 dwu_num_ag_pg=0, dwu_num_blk_st=1, revert_method=1, flat_mt=False):
         global fused_adam_cuda
         fused_adam_cuda = importlib.import_module("fused_adam_cuda")
@@ -78,7 +78,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         self._num_blocks = dwu_num_blocks
         self._full_pipeline = full_pipeline
         self._compute_L2_grad_norm = compute_L2_grad_norm
-        self._L2_grad_norm = None
+        self._L2_grad_norm = torch.zeros([]).cuda() if self._compute_L2_grad_norm else None
         self._group_size = torch.cuda.device_count() if dwu_group_size <= 0 else dwu_group_size
         self._world_size = torch.distributed.get_world_size()
         self._num_groups = self._world_size // self._group_size
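Initializing `_L2_grad_norm` as a zero-dimensional CUDA tensor (rather than the previous `None` placeholder, later overwritten with a Python float) lets the per-block squared norms be accumulated and all-reduced entirely on-device; only an explicit `.item()` would force a host synchronization. The same pattern in isolation, as a minimal sketch:

```python
import torch

# A 0-dim CUDA tensor works both as an in-place accumulator and as an
# all_reduce buffer; a Python float would need a device round-trip per block.
acc = torch.zeros([]).cuda()
for chunk in torch.randn(4, 1024, device='cuda'):
    acc += chunk.norm(dtype=torch.float32, p=2) ** 2
acc.sqrt_()          # finish the L2 norm in place, still on-device
value = acc.item()   # only this line synchronizes with the GPU
```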
@@ -202,6 +202,17 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         if self._num_groups > 1:
             work.wait()
             work = torch.distributed.all_reduce(grad_shards[self._rank_in_group],group=self._ar_pg[block_id%len(self._ar_pg)],async_op=True)
+        if self._compute_L2_grad_norm:
+            with torch.cuda.stream(self._blk_st[0]):
+                work.wait()
+                if block_id+1 == self._num_blocks:
+                    self._L2_grad_norm = grad_shards[self._rank_in_group].norm(dtype=torch.float32,p=2)**2
+                elif block_id != 0:
+                    self._L2_grad_norm += grad_shards[self._rank_in_group].norm(dtype=torch.float32,p=2)**2
+                else:
+                    self._L2_grad_norm += grad_shards[self._rank_in_group].norm(dtype=torch.float32,p=2)**2
+                    torch.distributed.all_reduce(self._L2_grad_norm,group=self._rs_pg[0])
+                    self._L2_grad_norm.sqrt_()
         return work
     # NB!
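The branching above relies on `_pipeline_block_reductions` being launched with `block_id` counting down from `self._num_blocks - 1` to `0`: the highest block initializes the squared-norm accumulator, intermediate blocks add their terms, and block 0 adds the last term, all-reduces the partial sums across ranks, and takes the square root, i.e. ||g||_2 = sqrt(sum_b sum_r ||g_{b,r}||^2). Stripped of the pipelining, the computation is equivalent to this sketch (function and argument names are illustrative):

```python
import torch
import torch.distributed as dist

def blockwise_l2_norm(my_shards, group=None):
    """Non-pipelined equivalent: `my_shards` is this rank's list of
    per-block gradient shards; returns the global gradient L2 norm."""
    sq_sum = torch.zeros([], device='cuda')
    for shard in my_shards:
        sq_sum += shard.norm(dtype=torch.float32, p=2) ** 2  # local ||g_br||^2
    dist.all_reduce(sq_sum, group=group)  # sum squared partials over ranks
    return sq_sum.sqrt()
```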
@@ -431,16 +442,6 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                              bias_correction,
                              group['weight_decay'])
-    def _do_compute_L2_grad_norm(self):
-        partial_sum = torch.zeros([]).cuda()
-        for block in range(self._num_blocks):
-            grad_block = self._flat_grads[block*self._block_size:(block+1)*self._block_size]
-            grad_shards = [grad_block[i*self._shard_size:(i+1)*self._shard_size] for i in range(self._group_size)]
-            shard_grad_norm = grad_shards[self._rank_in_group].norm(dtype=torch.float32,p=2)
-            partial_sum += (shard_grad_norm*shard_grad_norm)
-        torch.distributed.all_reduce(partial_sum,group=self._rs_pg[0], async_op=False)
-        self._L2_grad_norm = partial_sum.sqrt().item()
-
     def complete_reductions(self):
         """Complete reductions if full pipeline is not selected or overlap is not allowed.
         """
@@ -456,49 +457,14 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                 self._flat_grads[param_offset:param_offset+param_size].zero_()
                 self._grads_generated[param_i] = True
-        if self._last_step or not self._overlap_reductions or not self._full_pipeline:
-            if self._new_params is None:
-                self._new_params = torch.zeros_like(self._flat_grads)
-            if self._last_step or not self._overlap_reductions:
-                # nothing done so far, run full pipeline after reductions
-                if self._compute_L2_grad_norm:
-                    # do reductions, wait, complete L2, do step
-                    for inv_block_id in range(self._num_blocks):
-                        block_id = self._num_blocks - inv_block_id - 1
-                        self._blk_st[block_id%len(self._blk_st)].wait_stream(torch.cuda.current_stream())
-                        with torch.cuda.stream(self._blk_st[block_id%len(self._blk_st)]):
-                            work = self._pipeline_block_reductions(block_id, self._flat_grads)
-                            self._works.append(work)
-                    self._wait_works()
-                    self._do_compute_L2_grad_norm()
-                    for inv_block_id in range(self._num_blocks):
-                        block_id = self._num_blocks - inv_block_id - 1
-                        self._blk_st[block_id%len(self._blk_st)].wait_stream(torch.cuda.current_stream())
-                        with torch.cuda.stream(self._blk_st[block_id%len(self._blk_st)]):
-                            work = self._pipeline_block_step(block_id, self._flat_grads, self._new_params)
-                            self._works.append(work)
-                else:
-                    # run full pipeline
-                    for inv_block_id in range(self._num_blocks):
-                        block_id = self._num_blocks - inv_block_id - 1
-                        self._blk_st[block_id%len(self._blk_st)].wait_stream(torch.cuda.current_stream())
-                        with torch.cuda.stream(self._blk_st[block_id%len(self._blk_st)]):
-                            work = self._pipeline_block(block_id, self._flat_grads, self._new_params)
-                            self._works.append(work)
-            else:
-                # reductions done.
-                if self._compute_L2_grad_norm:
-                    self._do_compute_L2_grad_norm()
-                # do step
-                for inv_block_id in range(self._num_blocks):
-                    block_id = self._num_blocks - inv_block_id - 1
-                    self._blk_st[block_id%len(self._blk_st)].wait_stream(torch.cuda.current_stream())
-                    with torch.cuda.stream(self._blk_st[block_id%len(self._blk_st)]):
-                        work = self._pipeline_block_step(block_id, self._flat_grads, self._new_params)
-                        self._works.append(work)
-        else:
-            if self._compute_L2_grad_norm:
-                self._do_compute_L2_grad_norm()
+        if self._last_step or not self._overlap_reductions:
+            # nothing done so far, run full pipeline after reductions
+            for inv_block_id in range(self._num_blocks):
+                block_id = self._num_blocks - inv_block_id - 1
+                self._blk_st[block_id%len(self._blk_st)].wait_stream(torch.cuda.current_stream())
+                with torch.cuda.stream(self._blk_st[block_id%len(self._blk_st)]):
+                    work = self._pipeline_block_reductions(block_id, self._flat_grads)
+                    self._works.append(work)
         self._copy_to_fp32 = False
         self._decomp_stats = None
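Both the removed and the surviving loops use the same stream-pipelining idiom: a block's dedicated stream first waits on the current stream, then the block's work is enqueued on it so successive blocks can overlap. In isolation the idiom looks roughly like this (the worker callable is a stand-in):

```python
import torch

block_streams = [torch.cuda.Stream() for _ in range(2)]  # cf. self._blk_st

def launch_on_block_stream(block_id, fn, *args):
    stream = block_streams[block_id % len(block_streams)]
    # Order the block stream after everything already queued on the
    # current stream, then enqueue this block's work on it.
    stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(stream):
        return fn(block_id, *args)
```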
@@ -517,6 +483,16 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         if closure is not None:
             loss = closure()
+        if self._last_step or not self._full_pipeline:
+            if self._new_params is None:
+                self._new_params = torch.zeros_like(self._flat_grads)
+            for inv_block_id in range(self._num_blocks):
+                block_id = self._num_blocks - inv_block_id - 1
+                self._blk_st[block_id%len(self._blk_st)].wait_stream(torch.cuda.current_stream())
+                with torch.cuda.stream(self._blk_st[block_id%len(self._blk_st)]):
+                    work = self._pipeline_block_step(block_id, self._flat_grads, self._new_params)
+                    self._works.append(work)
+        self._wait_works()
         # Check for overflow
......
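A consequence worth noting: since the norm is now finished inside the reduction pipeline, it can be read between `complete_reductions()` and `step()` without an extra collective. A hedged usage sketch (the private `_L2_grad_norm` is read directly here purely for illustration; after this commit it is a 0-dim CUDA tensor rather than a float):

```python
import torch

optimizer.complete_reductions()
grad_norm = optimizer._L2_grad_norm  # 0-dim CUDA tensor, or None if disabled
if grad_norm is None or torch.isfinite(grad_norm):  # e.g. skip step on inf/nan
    optimizer.step()
```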