Commit 4212b3e9 authored by Michael Carilli's avatar Michael Carilli
Browse files

Attempt to fix 97 (not sure why it's happening to begin with)

parent 197bcc48
...@@ -6,9 +6,20 @@ from . import amp ...@@ -6,9 +6,20 @@ from . import amp
try: try:
from . import optimizers from . import optimizers
except ImportError: except ImportError:
print("Warning: apex was installed without --cuda_ext. FusedAdam will be unavailable.") # An attempt to fix https://github.com/NVIDIA/apex/issues/97. I'm not sure why 97 is even
# happening because Python modules should only be imported once, even if import is called
# multiple times.
try:
_ = warned_optimizers
except NameError:
print("Warning: apex was installed without --cuda_ext. FusedAdam will be unavailable.")
warned_optimizers = True
try: try:
from . import normalization from . import normalization
except ImportError: except ImportError:
print("Warning: apex was installed without --cuda_ext. FusedLayerNorm will be unavailable.") try:
_ = warned_normalization
except NameError:
print("Warning: apex was installed without --cuda_ext. FusedLayerNorm will be unavailable.")
warned_normalization = True
...@@ -5,7 +5,11 @@ try: ...@@ -5,7 +5,11 @@ try:
import syncbn import syncbn
from .optimized_sync_batchnorm import SyncBatchNorm from .optimized_sync_batchnorm import SyncBatchNorm
except ImportError: except ImportError:
print("Warning: apex was installed without --cuda_ext. Fused syncbn kernels will be unavailable. Python fallbacks will be used instead.") try:
_ = warned_syncbn
except NameError:
print("Warning: apex was installed without --cuda_ext. Fused syncbn kernels will be unavailable. Python fallbacks will be used instead.")
warned_syncbn = True
from .sync_batchnorm import SyncBatchNorm from .sync_batchnorm import SyncBatchNorm
def convert_syncbn_model(module, process_group=None): def convert_syncbn_model(module, process_group=None):
......
...@@ -4,7 +4,11 @@ try: ...@@ -4,7 +4,11 @@ try:
from apex_C import flatten from apex_C import flatten
from apex_C import unflatten from apex_C import unflatten
except ImportError: except ImportError:
print("Warning: apex was installed without --cpp_ext. Falling back to Python flatten and unflatten.") try:
_ = warned_flatten
except NameError:
print("Warning: apex was installed without --cpp_ext. Falling back to Python flatten and unflatten.")
warned_flatten = True
from torch._utils import _flatten_dense_tensors as flatten from torch._utils import _flatten_dense_tensors as flatten
from torch._utils import _unflatten_dense_tensors as unflatten from torch._utils import _unflatten_dense_tensors as unflatten
import torch.distributed as dist import torch.distributed as dist
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment