Commit 6b2ef787 authored by Thor Johnsen

Remove dead code

parent d662f9ca
import types
import math
import torch
import importlib
from apex.multi_tensor_apply import multi_tensor_applier
class DistributedFusedAdam(torch.optim.Optimizer):
@@ -202,13 +200,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
    def _pipeline_block_step(self, block_id, flat_grads, new_params):
        start = block_id * self._block_size
        end = start + self._block_size
        grad_block = flat_grads[start:end]
        grad_shards = [grad_block[i*self._shard_size:(i+1)*self._shard_size] for i in range(self._group_size)]
        new_params_shards = [new_params[start+shard_i*self._shard_size:start+(shard_i+1)*self._shard_size] for shard_i in range(self._group_size)]
        shard_start = start + self._rank_in_group * self._shard_size
        shard_end = shard_start + self._shard_size
        block_id = start // self._block_size
        self._partial_step_single_shard(block_id)
        if self._pg_supports_no_copy:
            work = torch.distributed.all_gather(new_params_shards,new_params_shards[self._rank_in_group],group=self._ag_pg[block_id%len(self._ag_pg)],async_op=True,no_copy=True)
@@ -418,8 +410,6 @@ class DistributedFusedAdam(torch.optim.Optimizer):
    def _do_compute_L2_grad_norm(self):
        partial_sum = torch.zeros([]).cuda()
        for block in range(self._num_blocks):
            start = block * self._block_size
            end = start + self._block_size
            grad_block = self._flat_grads[block*self._block_size:(block+1)*self._block_size]
            grad_shards = [grad_block[i*self._shard_size:(i+1)*self._shard_size] for i in range(self._group_size)]
            shard_grad_norm = grad_shards[self._rank_in_group].float().norm()
...
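
For context, the _pipeline_block_step hunk above slices the flat parameter buffer into blocks, splits each block into one shard per rank of the group, updates only the local rank's shard, and then all-gathers the updated shards so every rank holds the full block. Below is a minimal single-process sketch of that shard arithmetic; the names (block_size, group_size, rank_in_group, ag_group) are illustrative rather than the optimizer's actual attributes, and the collective call (including the no_copy fast path used above) is only indicated in a comment.

# Sketch of the block/shard bookkeeping, assuming the flat buffer length is a
# multiple of block_size and block_size is a multiple of group_size.
import torch

def shard_views(flat_buffer, block_id, block_size, group_size):
    """Return per-rank shard views of one block of a flat buffer."""
    shard_size = block_size // group_size
    start = block_id * block_size
    block = flat_buffer[start:start + block_size]
    return [block[i * shard_size:(i + 1) * shard_size] for i in range(group_size)]

flat_params = torch.arange(16.0)          # stand-in for the flat parameter buffer
shards = shard_views(flat_params, block_id=1, block_size=8, group_size=4)
# Rank r updates shards[r] in place, then the group exchanges shards so every
# rank ends up with the whole updated block, e.g.:
#   torch.distributed.all_gather(shards, shards[rank_in_group],
#                                group=ag_group, async_op=True)
print([s.tolist() for s in shards])       # [[8.0, 9.0], [10.0, 11.0], [12.0, 13.0], [14.0, 15.0]]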
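
Similarly, the _do_compute_L2_grad_norm hunk above builds a partial sum from the norms of this rank's gradient shards; the full L2 norm is the square root of the sum of the per-shard squared norms across the group. The following is a standalone illustration of that identity, not apex's implementation, with the cross-rank reduction only indicated in a comment.

# Each "rank" contributes the squared L2 norm of its own shard; summing those
# contributions and taking the square root recovers the norm of the full
# gradient buffer.
import torch

flat_grads = torch.randn(16)
group_size = 4
shard_size = flat_grads.numel() // group_size

partial_sums = [
    flat_grads[r * shard_size:(r + 1) * shard_size].float().norm() ** 2
    for r in range(group_size)
]
# In the distributed setting each rank holds only its own entry; the sum would
# be obtained with torch.distributed.all_reduce(partial_sum) before the sqrt.
global_norm = torch.stack(partial_sums).sum().sqrt()
assert torch.allclose(global_norm, flat_grads.norm())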