Commit 415e2646 authored by Thor Johnsen

Perf improvement (less CPU work)

parent 9d6d2e01
@@ -124,7 +124,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                 p_i += 1
         self._grads_generated = [False]*len(self._grads_info)
         self._flat_mt = flat_mt
-        self._grads = [None]*len(self._grads_info) if self._flat_mt else None
+        self._grads = []
         if self._overlap_reductions:
             self._current_block = self._num_blocks
@@ -155,6 +155,10 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         self._fp16_p = torch.zeros([self._mega_shard_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
         self._fp16_g = torch.zeros([self._mega_shard_size], dtype=torch.float16, device='cuda')
+        self._individual_flat_grads = []
+        for p_i, grads_info in enumerate(self._grads_info):
+            self._individual_flat_grads.append(self._flat_grads[grads_info["param_offset"]:grads_info["param_offset"]+grads_info["param_grads_size"]])
+
         def _flat_split(p):
             def __blockify(p):
                 return [p[block_id*self._block_size:(block_id+1)*self._block_size] for block_id in range(self._num_blocks)]
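The new `_individual_flat_grads` list caches a view into `_flat_grads` for every parameter once, at setup time, so the offset arithmetic and slicing no longer run on every training step. A minimal standalone sketch of that caching pattern, with made-up sizes and offsets standing in for the real `_grads_info` metadata:

```python
import torch

# Made-up metadata mirroring self._grads_info: where each parameter's
# gradient lives inside the flat gradient buffer.
grads_info = [
    {"param_offset": 0, "param_grads_size": 4},
    {"param_offset": 4, "param_grads_size": 6},
]
flat_grads = torch.zeros(10)

# Slice the flat buffer into per-parameter views once, up front.  Slicing a
# tensor returns a view, so writes through a view land in flat_grads itself.
individual_flat_grads = [
    flat_grads[gi["param_offset"]:gi["param_offset"] + gi["param_grads_size"]]
    for gi in grads_info
]

# Later, per-step code can write into the cached view directly, with no
# repeated offset arithmetic or re-slicing.
individual_flat_grads[1].copy_(torch.arange(6, dtype=torch.float32))
print(flat_grads)  # elements 4..9 now hold 0..5
```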
@@ -393,26 +397,19 @@ class DistributedFusedAdam(torch.optim.Optimizer):
             torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)

     def _flatten_grad_mt(self, scale):
-        if self._flat_mt:
-            grads = []
-            flat_grads = []
-            for p_i, (grads_info, grad) in enumerate(zip(self._grads_info, self._grads)):
-                if grad is not None:
-                    grads.append(grad)
-                    flat_grads.append( self._flat_grads[grads_info["param_offset"]:grads_info["param_offset"]+grads_info["param_grads_size"]] )
-            self._grads = [None]*len(self._grads_info)
-            if len(grads) > 0:
-                self._overflow_buf.zero_()
-                multi_tensor_applier(
-                    amp_C.multi_tensor_scale,
-                    self._overflow_buf,
-                    [grads, flat_grads],
-                    scale)
+        if self._flat_mt and len(self._grads) > 0:
+            self._overflow_buf.zero_()
+            multi_tensor_applier(
+                amp_C.multi_tensor_scale,
+                self._overflow_buf,
+                list(zip(*self._grads)),
+                scale)
+            self._grads = []

     def _do_overlapped_reduction(self, param_i, param_grads_size, param_offset, grad):
         # handle overlapped reductions
         if self._flat_mt:
-            self._grads[param_i] = grad.view(-1)
+            self._grads.append( (grad.view(-1), self._individual_flat_grads[param_i]) )
         else:
             torch.div(grad.view(-1), self._world_size if self._predivide else 1.0, out=self._flat_grads[param_offset:param_offset+param_grads_size])
         self._grads_generated[param_i]=True
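With the views cached, `_do_overlapped_reduction` simply appends a `(grad, flat_view)` pair as each gradient arrives, and `_flatten_grad_mt` transposes the accumulated pairs with `zip(*self._grads)` into the `[sources, destinations]` lists that `multi_tensor_applier(amp_C.multi_tensor_scale, ...)` expects, then clears the list. The old per-step Python loop over every parameter slot (and its `None` checks) goes away. A hedged sketch of the same accumulate-then-transpose pattern, using a plain Python helper `scale_into` in place of the fused apex kernel:

```python
import torch

def scale_into(pairs, scale):
    # Stand-in for the fused multi_tensor_scale call: scale each source
    # gradient and write it into its destination view of the flat buffer.
    if not pairs:
        return
    sources, dests = zip(*pairs)   # [(src, dst), ...] -> (srcs, dsts)
    for src, dst in zip(sources, dests):
        dst.copy_(src * scale)

flat_grads = torch.zeros(10)
views = [flat_grads[0:4], flat_grads[4:10]]   # cached per-parameter views

# As gradients are produced (e.g. in backward hooks), append (grad, view)
# pairs; only parameters that actually generated a gradient appear here.
pending = []
pending.append((torch.ones(4), views[0]))
pending.append((torch.full((6,), 2.0), views[1]))

scale_into(pending, scale=0.5)
pending.clear()                   # mirrors `self._grads = []` after the fused call
print(flat_grads)                 # first 4 elements 0.5, last 6 elements 1.0
```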