Unverified Commit bd7e25a5 authored by Benjamin Lefaudeux, committed by GitHub

[fix] grad scaler optional process group (#257)

parent 2df5ca2d
@@ -4,7 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import logging
-from typing import Dict
+from typing import Any, Dict
 
 import torch
 from torch.cuda.amp import GradScaler as TorchGradScaler
@@ -30,9 +30,10 @@ class ShardedGradScaler(TorchGradScaler):
     documentation https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler
     """
 
-    def __init__(self) -> None:
+    def __init__(self, process_group: Any = dist.group.WORLD) -> None:
         super().__init__()
         self.display_warning = True
+        self.group = process_group
 
     def unscale_(self, optimizer: Optimizer) -> None:
         # Could be a mistake, this scaler is supposed to work with ZeroRedundancyOptimizer only
@@ -48,7 +49,10 @@ class ShardedGradScaler(TorchGradScaler):
 
         # Synchronize the detected inf across the ranks
         optimizer_state = self._per_optimizer_states[id(optimizer)]
-        handles = [dist.all_reduce(v, async_op=True) for v in optimizer_state["found_inf_per_device"].values()]
+        handles = [
+            dist.all_reduce(v, async_op=True, group=self.group)
+            for v in optimizer_state["found_inf_per_device"].values()
+        ]
 
         # Make sure that the calls are done before moving out
         _ = list(map(lambda x: x.wait(), handles))
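
A minimal usage sketch of the new optional process_group argument: the scaler now runs its found_inf all_reduce on the given group, while the default (dist.group.WORLD) preserves the previous behaviour. The import path and the training-loop wiring below are assumptions for illustration, not part of this diff.

import torch
import torch.distributed as dist
from fairscale.optim.grad_scaler import ShardedGradScaler  # assumed import path

# dist.init_process_group(...) is expected to have been called already.
subgroup = dist.new_group(ranks=[0, 1])  # hypothetical subset of ranks

# Omitting the argument keeps the old behaviour (sync across dist.group.WORLD);
# passing a subgroup restricts the inf/nan synchronization to those ranks.
scaler = ShardedGradScaler(process_group=subgroup)

def train_step(model, optimizer, loss_fn, inputs, targets):
    # Standard torch.cuda.amp pattern; per the comment in unscale_, the
    # optimizer is expected to be a sharded one (e.g. fairscale's OSS).
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = loss_fn(model(inputs), targets)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()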