Unverified Commit c6f40418 authored by Benjamin Lefaudeux, committed by GitHub

[fix] ShardedGradScaler - remove the strict optimizer type requirement (#237)

* removing the strict typing requirement, which ClassyVision does not satisfy
parent bb468670
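For context, a minimal sketch of the pairing this scaler was written for: OSS sharding the optimizer state while ShardedGradScaler drives the AMP loop. This is not code from the commit; it assumes torch.distributed has already been initialized, a CUDA device is available, and the fairscale import paths of this era; the model, shapes and hyper-parameters are illustrative only.

```python
# Illustrative sketch (not part of this commit): ShardedGradScaler used with OSS.
# Assumes torch.distributed is initialized and CUDA is available.
import torch
from fairscale.optim.grad_scaler import ShardedGradScaler
from fairscale.optim.oss import OSS

model = torch.nn.Linear(16, 16).cuda()
# OSS shards the wrapped optimizer's state across ranks; SGD and lr are arbitrary.
optimizer = OSS(model.parameters(), optim=torch.optim.SGD, lr=0.1)
scaler = ShardedGradScaler()

with torch.cuda.amp.autocast():
    loss = model(torch.randn(8, 16, device="cuda")).sum()

scaler.scale(loss).backward()
scaler.unscale_(optimizer)  # the isinstance(optimizer, OSS) check discussed below lives here
scaler.step(optimizer)
scaler.update()
```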
@@ -3,6 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+import logging
 from typing import Dict
 
 import torch
@@ -31,9 +32,16 @@ class ShardedGradScaler(TorchGradScaler):
+    def __init__(self) -> None:
+        super().__init__()
+        self.display_warning = True
+
     def unscale_(self, optimizer: Optimizer) -> None:
-        assert isinstance(optimizer, OSS), "ShardedGradScaler is to be used in combination with a sharded optimizer"
+        # Could be a mistake, this scaler is supposed to work with ZeroRedundancyOptimizer only
+        if self.display_warning and not isinstance(optimizer, OSS):
+            logging.warning(
+                "ShardedGradScaler is to be used in combination with a sharded optimizer, this could not be checked"
+            )
+        self.display_warning = False  # Only warn once
 
         # Call the upstream unscale_ method which will only act on this rank's gradients
         super().unscale_(optimizer)
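With the assert replaced by a warning, the scaler no longer hard-fails when the object handed to unscale_() is not literally an OSS instance, for example a framework-level optimizer wrapper as used by ClassyVision. A sketch of the relaxed behaviour, under the same distributed/CUDA assumptions as the example above and with illustrative names:

```python
# Sketch of the relaxed check: a non-OSS optimizer now triggers a single
# logging.warning() inside unscale_() instead of raising an AssertionError.
import torch
from fairscale.optim.grad_scaler import ShardedGradScaler

model = torch.nn.Linear(16, 16).cuda()
plain_optimizer = torch.optim.SGD(model.parameters(), lr=0.1)  # not an OSS instance
scaler = ShardedGradScaler()

loss = model(torch.randn(8, 16, device="cuda")).sum()
scaler.scale(loss).backward()
scaler.unscale_(plain_optimizer)  # before this commit: AssertionError; now: one warning
scaler.step(plain_optimizer)
scaler.update()
```

Only the first call warns: display_warning is flipped to False after the first unscale_(), so later steps proceed silently.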