Commit ffed6e80 authored by Thor Johnsen

Merge branch 'revertable_fused_adam_with_mt_support' of https://github.com/NVIDIA/apex into revertable_fused_adam_with_mt_support
parents eca3a2c4 6b2ef787
import types
import math
import torch
import importlib
from apex.multi_tensor_apply import multi_tensor_applier

class DistributedFusedAdam(torch.optim.Optimizer):
@@ -202,13 +200,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
    def _pipeline_block_step(self, block_id, flat_grads, new_params):
        start = block_id * self._block_size
-       end = start + self._block_size
-       grad_block = flat_grads[start:end]
-       grad_shards = [grad_block[i*self._shard_size:(i+1)*self._shard_size] for i in range(self._group_size)]
        new_params_shards = [new_params[start+shard_i*self._shard_size:start+(shard_i+1)*self._shard_size] for shard_i in range(self._group_size)]
-       shard_start = start + self._rank_in_group * self._shard_size
-       shard_end = shard_start + self._shard_size
-       block_id = start // self._block_size
        self._partial_step_single_shard(block_id)
        if self._pg_supports_no_copy:
            work = torch.distributed.all_gather(new_params_shards,new_params_shards[self._rank_in_group],group=self._ag_pg[block_id%len(self._ag_pg)],async_op=True,no_copy=True)
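
For readers following the hunk above, here is a minimal standalone sketch of the block/shard slicing that _pipeline_block_step relies on. The helper name and its explicit arguments (block_size, shard_size, group_size, rank_in_group, ag_group) are illustrative stand-ins for the class attributes used in the diff, not part of apex's API, and the no_copy fast path (which depends on process-group support) is omitted.

import torch
import torch.distributed as dist

def pipeline_block_step_sketch(block_id, flat_grads, new_params,
                               block_size, shard_size, group_size,
                               rank_in_group, ag_group=None):
    # Hypothetical sketch: each block of the flat parameter/gradient buffer
    # is split into one shard per rank in the data-parallel group.
    start = block_id * block_size
    new_params_shards = [
        new_params[start + i * shard_size : start + (i + 1) * shard_size]
        for i in range(group_size)
    ]
    # The local optimizer step (self._partial_step_single_shard in the class)
    # updates only this rank's shard, new_params_shards[rank_in_group].
    # Every rank then gathers the other ranks' updated shards so that all
    # ranks end up holding the complete, updated block.
    work = dist.all_gather(new_params_shards,
                           new_params_shards[rank_in_group],
                           group=ag_group,
                           async_op=True)
    return work  # the caller waits on this handle before reading the block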
@@ -418,8 +410,6 @@ class DistributedFusedAdam(torch.optim.Optimizer):
    def _do_compute_L2_grad_norm(self):
        partial_sum = torch.zeros([]).cuda()
        for block in range(self._num_blocks):
-           start = block * self._block_size
-           end = start + self._block_size
            grad_block = self._flat_grads[block*self._block_size:(block+1)*self._block_size]
            grad_shards = [grad_block[i*self._shard_size:(i+1)*self._shard_size] for i in range(self._group_size)]
            shard_grad_norm = grad_shards[self._rank_in_group].float().norm()
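
The second hunk computes the global gradient L2 norm from per-shard partial norms: each rank takes the norm of only its own shard in every block, accumulates the squared norms, all-reduces the partial sums across ranks, and takes the square root. A minimal sketch of that pattern follows; the function name and argument list are illustrative, not the class's actual interface.

import torch
import torch.distributed as dist

def sharded_l2_grad_norm_sketch(flat_grads, num_blocks, block_size,
                                shard_size, rank_in_group, group=None):
    # Hypothetical sketch: accumulate squared norms of this rank's shard only.
    partial_sum = torch.zeros([], device=flat_grads.device)
    for block in range(num_blocks):
        grad_block = flat_grads[block * block_size:(block + 1) * block_size]
        shard = grad_block[rank_in_group * shard_size:
                           (rank_in_group + 1) * shard_size]
        partial_sum += shard.float().norm() ** 2
    # Sum the per-rank partial sums, then take the square root to get the
    # global L2 norm over the whole flat gradient buffer.
    dist.all_reduce(partial_sum, group=group)
    return partial_sum.sqrt()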