Unverified commit 6802ad49 authored by Min Xu, committed by GitHub

[fix] fixing adascale all_reduce (#155)

- Aurick noticed this bug and I ran into it yesterday.
- After the fix, our CIFAR training shows the same gain values across
  different replicas (a minimal repro sketch follows below the log):
```
20-Oct-20 16:00:19 - DEBUG - rank1 - scale 2, gain ratio 1.3512124098087777
20-Oct-20 16:00:19 - DEBUG - rank0 - scale 2, gain ratio 1.3512124098087777
20-Oct-20 16:00:19 - DEBUG - rank1 - timing: data 0:00:00.000600 fwd 0:00:00.003678 loss 0:00:00.000086 bwd 0:00:00.314158 update 0:00:00.002132 rest 0:00:00.000399
20-Oct-20 16:00:19 - DEBUG - rank0 - timing: data 0:00:00.000643 fwd 0:00:00.003460 loss 0:00:00.000084 bwd 0:00:00.314678 update 0:00:00.002001 rest 0:00:00.000408
20-Oct-20 16:00:19 - DEBUG - rank1 - scale 2, gain ratio 1.3514997779980324
20-Oct-20 16:00:19 - DEBUG - rank0 - scale 2, gain ratio 1.3514997779980324
20-Oct-20 16:00:19 - DEBUG - rank1 - timing: data 0:00:00.000732 fwd 0:00:00.003689 loss 0:00:00.000086 bwd 0:00:00.314176 update 0:00:00.002146 rest 0:00:00.000397
20-Oct-20 16:00:19 - DEBUG - rank0 - timing: data 0:00:00.000646 fwd 0:00:00.003542 loss 0:00:00.000089 bwd 0:00:00.314549 update 0:00:00.001956 rest 0:00:00.000392
20-Oct-20 16:00:19 - DEBUG - rank1 - scale 2, gain ratio 1.352149646693932
20-Oct-20 16:00:19 - DEBUG - rank0 - scale 2, gain ratio 1.352149646693932
```
parent 6f8a8652
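
The underlying issue: `torch.distributed.all_reduce` works in place, so passing the temporary `self._local_grad_sqr / self._world_size` sums a throwaway tensor and each rank keeps only its own local value. Below is a minimal sketch contrasting the buggy and fixed patterns; it is my own illustration, not code from this commit, and the gloo backend, the TCP port 29501, and the two CPU processes are arbitrary choices.

```python
# Minimal sketch, not from the repo: shows why reducing a temporary loses the
# result while in-place all_reduce + div_ keeps it. Assumes CPU-only gloo and a
# free local TCP port (29501); runs two worker processes.
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank: int, world_size: int) -> None:
    dist.init_process_group(
        "gloo", init_method="tcp://127.0.0.1:29501", rank=rank, world_size=world_size
    )

    # Buggy pattern: `grad_sqr / world_size` builds a new tensor; all_reduce
    # sums *that* temporary in place, so grad_sqr still holds the local value.
    grad_sqr = torch.tensor([float(rank + 1)])
    dist.all_reduce(grad_sqr / world_size)
    print(f"rank{rank} buggy: {grad_sqr.item()}")  # 1.0 on rank0, 2.0 on rank1

    # Fixed pattern: all_reduce the tensor itself (default op is SUM), then
    # divide in place -- every rank ends up with the same averaged value.
    grad_sqr = torch.tensor([float(rank + 1)])
    dist.all_reduce(grad_sqr)
    grad_sqr.div_(world_size)
    print(f"rank{rank} fixed: {grad_sqr.item()}")  # 1.5 on both ranks

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2)
```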
```diff
@@ -185,7 +185,7 @@ class ShardedDataParallel(nn.Module):
                     i_bucketed += 1
             if i_bucketed > 0:
-                buffer.div_(world_size)  # type: ignore
+                buffer.div_(world_size)
                 bucket_requests.append(
                     (
                         dist.reduce(tensor=buffer, dst=global_rank, group=group, async_op=True),  # type: ignore
@@ -199,7 +199,7 @@ class ShardedDataParallel(nn.Module):
             if p.grad.requires_grad:
                 raise RuntimeError("DistributedDataParallel only works with gradients that don't require grad")
-            p.grad.div_(world_size)  # type: ignore
+            p.grad.div_(world_size)
             requests.append(dist.reduce(tensor=p.grad, dst=global_rank, group=group, async_op=True))  # type: ignore
             # Unroll the initial packed small gradients, as soon as possible
```
```diff
@@ -199,7 +199,11 @@ class AdaScale(object):
         # gradients have been synchronized between each worker.
         self._final_callback_queued = False
         assert isinstance(self._local_grad_sqr, torch.Tensor)
-        torch.distributed.all_reduce(self._local_grad_sqr / self._world_size)
+        # self._local_grad_sqr is FP32, sum then div shouldn't overflow.
+        torch.distributed.all_reduce(self._local_grad_sqr)  # SUM
+        self._local_grad_sqr.div_(self._world_size)
         local_grad_sqr = self._local_grad_sqr.cpu().numpy()
         total_grad_sqr = np.array(
             [sum(param.grad.pow(2).sum().item() for param in group["params"]) for group in self._optimizer.param_groups]
@@ -243,3 +247,7 @@ class AdaScale(object):
             return self.step(*args, **kwargs)

         setattr(self._optimizer, "step", wrapper)
+
+    def zero_grad(self) -> None:
+        """Proxy function to optimizer"""
+        self._optimizer.zero_grad()
```
```diff
@@ -348,6 +348,8 @@ class Tensor:
     def digamma_(self) -> Tensor: ...
     def dim(self) -> _int: ...
     def dist(self, other: Tensor, p: Number=2) -> Tensor: ...
+    def div(self, denominator: Number) -> Tensor: ...
+    def div_(self, denominator: Number) -> Tensor: ...
     def dot(self, tensor: Tensor) -> Tensor: ...
     def double(self) -> Tensor: ...
     def eig(self, eigenvectors: _bool=False) -> Tuple[Tensor, Tensor]: ...
```
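
The stub additions above are also what allow the two `# type: ignore` comments to be dropped in `ShardedDataParallel`: once `div`/`div_` are declared on `Tensor`, mypy resolves the in-place division on its own. A tiny hypothetical example (not part of this commit) of a call that now type-checks cleanly:

```python
# Hypothetical snippet, not from the repo: with div_ declared in the Tensor
# stub, mypy accepts the in-place division without a "# type: ignore".
import torch


def average_inplace(buffer: torch.Tensor, world_size: int) -> torch.Tensor:
    # Divides `buffer` by world_size in place and returns it.
    return buffer.div_(world_size)


print(average_inplace(torch.ones(3), 4))  # tensor([0.2500, 0.2500, 0.2500])
```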