Unverified Commit 8e85ce8c authored by Benjamin Lefaudeux's avatar Benjamin Lefaudeux Committed by GitHub
Browse files

[fix] Adding a GradScaler import guard for amp with pytorch 1.5 (#210)

parent 7a062894
...@@ -6,11 +6,16 @@ ...@@ -6,11 +6,16 @@
""" """
:mod:`fairscale.optim` is a package implementing various torch optimization algorithms. :mod:`fairscale.optim` is a package implementing various torch optimization algorithms.
""" """
import logging
from .adascale import AdaScale
from .oss import OSS
try: try:
from .adam import Adam, Precision from .adam import Adam, Precision
except ImportError: # pragma: no cover except ImportError: # pragma: no cover
pass # pragma: no cover pass # pragma: no cover
from .adascale import AdaScale try:
from .grad_scaler import GradScaler from .grad_scaler import GradScaler
from .oss import OSS except ImportError:
logging.warning("Torch AMP is not available on this platform")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment