Unverified Commit 8e85ce8c authored by Benjamin Lefaudeux, committed by GitHub

[fix] Adding a GradScaler import guard for amp with pytorch 1.5 (#210)

parent 7a062894
fairscale/optim/__init__.py
@@ -6,11 +6,16 @@
 """
 :mod:`fairscale.optim` is a package implementing various torch optimization algorithms.
 """
+import logging
 
+from .adascale import AdaScale
+from .oss import OSS
+
 try:
     from .adam import Adam, Precision
 except ImportError:  # pragma: no cover
     pass  # pragma: no cover
-from .adascale import AdaScale
-from .grad_scaler import GradScaler
-from .oss import OSS
+try:
+    from .grad_scaler import GradScaler
+except ImportError:
+    logging.warning("Torch AMP is not available on this platform")