Unverified Commit bd7e25a5 authored by Benjamin Lefaudeux, committed by GitHub

[fix] grad scaler optional process group (#257)

parent 2df5ca2d
@@ -4,7 +4,7 @@
 # LICENSE file in the root directory of this source tree.

 import logging
-from typing import Dict
+from typing import Any, Dict

 import torch
 from torch.cuda.amp import GradScaler as TorchGradScaler
@@ -30,9 +30,10 @@ class ShardedGradScaler(TorchGradScaler):
     documentation https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler
     """

-    def __init__(self) -> None:
+    def __init__(self, process_group: Any = dist.group.WORLD) -> None:
         super().__init__()
         self.display_warning = True
+        self.group = process_group

     def unscale_(self, optimizer: Optimizer) -> None:
         # Could be a mistake, this scaler is supposed to work with ZeroRedundancyOptimizer only
@@ -48,7 +49,10 @@ class ShardedGradScaler(TorchGradScaler):
         # Synchronize the detected inf across the ranks
         optimizer_state = self._per_optimizer_states[id(optimizer)]
-        handles = [dist.all_reduce(v, async_op=True) for v in optimizer_state["found_inf_per_device"].values()]
+        handles = [
+            dist.all_reduce(v, async_op=True, group=self.group)
+            for v in optimizer_state["found_inf_per_device"].values()
+        ]

         # Make sure that the calls are done before moving out
         _ = list(map(lambda x: x.wait(), handles))
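
The change makes the process group used by ShardedGradScaler configurable: the inf/NaN all_reduce now runs over a caller-supplied process_group rather than always over the default group, and dist.group.WORLD remains the default so existing callers are unaffected. Below is a minimal usage sketch, assuming the class is importable from fairscale.optim.grad_scaler (the file path is not visible in this diff) and that torch.distributed has already been initialized; the two-rank subgroup is purely illustrative.

import torch.distributed as dist

# Assumed import path; the diff does not show which file is being modified.
from fairscale.optim.grad_scaler import ShardedGradScaler

# Default behaviour is unchanged: inf/NaN flags are all-reduced across dist.group.WORLD.
scaler = ShardedGradScaler()

# New with this commit: restrict the synchronization to a subset of ranks by
# passing an explicit process group (hypothetical two-rank subgroup here).
subgroup = dist.new_group(ranks=[0, 1])
scaler = ShardedGradScaler(process_group=subgroup)

# The training loop then follows the standard torch.cuda.amp.GradScaler pattern:
#     scaler.scale(loss).backward()
#     scaler.step(optimizer)   # unscale_() all-reduces found_inf over `subgroup`
#     scaler.update()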