"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "c9bd4d433845921ddf7c0b0a50be3c7bdf7a80fc"
Unverified commit c6f40418, authored by Benjamin Lefaudeux, committed by GitHub

[fix] ShardedGradScaler - remove the strict optimizer type requirement (#237)

* removing strict typing requirement, broken by ClassyVision
parent bb468670
@@ -3,6 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

+import logging
 from typing import Dict

 import torch
@@ -31,9 +32,16 @@ class ShardedGradScaler(TorchGradScaler):
     def __init__(self) -> None:
         super().__init__()
+        self.display_warning = True

     def unscale_(self, optimizer: Optimizer) -> None:
-        assert isinstance(optimizer, OSS), "ShardedGradScaler is to be used in combination with a sharded optimizer"
+        # Could be a mistake, this scaler is supposed to work with ZeroRedundancyOptimizer only
+        if self.display_warning and not isinstance(optimizer, OSS):
+            logging.warning(
+                "ShardedGradScaler is to be used in combination with a sharded optimizer, this could not be checked"
+            )
+            self.display_warning = False  # Only warn once

         # Call the upstream unscale_ method which will only act on this rank's gradients
         super().unscale_(optimizer)
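For context, a minimal sketch of how ShardedGradScaler is typically paired with the OSS sharded optimizer. The model, data, and loop below are illustrative rather than part of this patch, and the sketch assumes torch.distributed has already been initialized with one CUDA device per rank. The relaxed check above means that an optimizer which merely wraps OSS (as ClassyVision's does) now triggers a one-time warning at the unscale_ call instead of an AssertionError.

    import torch
    from fairscale.optim.grad_scaler import ShardedGradScaler
    from fairscale.optim.oss import OSS

    # Illustrative setup: assumes dist.init_process_group() has already run
    # and this rank owns a single CUDA device.
    model = torch.nn.Linear(32, 32).cuda()
    optimizer = OSS(model.parameters(), optim=torch.optim.SGD, lr=0.1)
    scaler = ShardedGradScaler()

    for _ in range(10):
        inputs = torch.randn(8, 32, device="cuda")
        with torch.cuda.amp.autocast():
            loss = model(inputs).sum()
        scaler.scale(loss).backward()
        # Unscale this rank's gradients; with this patch, passing a non-OSS
        # optimizer here logs a single warning instead of raising.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()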