Commit 65ca6177 authored by Michael Carilli's avatar Michael Carilli
Browse files

Fix for #186

parent d1f74a3e
...@@ -5,6 +5,15 @@ ...@@ -5,6 +5,15 @@
import os import os
import torch import torch
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
if TORCH_MAJOR == 0:
import collections.abc as container_abcs
else:
from torch._six import container_abcs
class AmpState(object): class AmpState(object):
def __init__(self): def __init__(self):
self.hard_override=False self.hard_override=False
......
import torch import torch
from torch._six import container_abcs, string_classes from torch._six import string_classes
import functools import functools
from ._amp_state import _amp_state, warn_or_err from ._amp_state import _amp_state, warn_or_err, container_abcs
from .handle import disable_casts from .handle import disable_casts
from .scaler import LossScaler from .scaler import LossScaler
from apex.fp16_utils import convert_network from apex.fp16_utils import convert_network
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment