Commit fab319f1 authored by Bram Vanroy's avatar Bram Vanroy Committed by mcarilli
Browse files

allow for non-distributed envs (Windows) (#531)

parent 753c427a
...@@ -2,7 +2,9 @@ ...@@ -2,7 +2,9 @@
import torch import torch
import warnings import warnings
from . import parallel if torch.distributed.is_available():
from . import parallel
from . import amp from . import amp
from . import fp16_utils from . import fp16_utils
......
...@@ -2,6 +2,7 @@ import torch ...@@ -2,6 +2,7 @@ import torch
from torch._six import string_classes from torch._six import string_classes
import functools import functools
import numpy as np import numpy as np
import sys
import warnings import warnings
from ._amp_state import _amp_state, warn_or_err, container_abcs from ._amp_state import _amp_state, warn_or_err, container_abcs
from .handle import disable_casts from .handle import disable_casts
...@@ -10,8 +11,10 @@ from ._process_optimizer import _process_optimizer ...@@ -10,8 +11,10 @@ from ._process_optimizer import _process_optimizer
from apex.fp16_utils import convert_network from apex.fp16_utils import convert_network
from ..fp16_utils import FP16_Optimizer as FP16_Optimizer_general from ..fp16_utils import FP16_Optimizer as FP16_Optimizer_general
from ..contrib.optimizers import FP16_Optimizer as FP16_Optimizer_for_fused from ..contrib.optimizers import FP16_Optimizer as FP16_Optimizer_for_fused
from ..parallel import DistributedDataParallel as apex_DDP
from ..parallel.LARC import LARC if torch.distributed.is_available():
from ..parallel import DistributedDataParallel as apex_DDP
from ..parallel.LARC import LARC
def to_type(dtype, t): def to_type(dtype, t):
...@@ -62,7 +65,7 @@ def check_models(models): ...@@ -62,7 +65,7 @@ def check_models(models):
parallel_type = None parallel_type = None
if isinstance(model, torch.nn.parallel.DistributedDataParallel): if isinstance(model, torch.nn.parallel.DistributedDataParallel):
parallel_type = "torch.nn.parallel.DistributedDataParallel" parallel_type = "torch.nn.parallel.DistributedDataParallel"
if isinstance(model, apex_DDP): if ('apex_DDP' in sys.modules) and isinstance(model, apex_DDP):
parallel_type = "apex.parallel.DistributedDataParallel" parallel_type = "apex.parallel.DistributedDataParallel"
if isinstance(model, torch.nn.parallel.DataParallel): if isinstance(model, torch.nn.parallel.DataParallel):
parallel_type = "torch.nn.parallel.DataParallel" parallel_type = "torch.nn.parallel.DataParallel"
...@@ -139,11 +142,10 @@ class O2StateDictHook(object): ...@@ -139,11 +142,10 @@ class O2StateDictHook(object):
def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs=None): def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs=None):
from apex.parallel import DistributedDataParallel as apex_DDP
from .amp import init as amp_init from .amp import init as amp_init
optimizers_was_list = False optimizers_was_list = False
if isinstance(optimizers, torch.optim.Optimizer) or isinstance(optimizers, LARC): if isinstance(optimizers, torch.optim.Optimizer) or ('LARC' in sys.modules and isinstance(optimizers, LARC)):
optimizers = [optimizers] optimizers = [optimizers]
elif optimizers is None: elif optimizers is None:
optimizers = [] optimizers = []
......
import contextlib import contextlib
import warnings import warnings
import sys
import torch import torch
from . import utils from . import utils
from .opt import OptimWrapper from .opt import OptimWrapper
from .scaler import LossScaler from .scaler import LossScaler
from ._amp_state import _amp_state, master_params, maybe_print from ._amp_state import _amp_state, master_params, maybe_print
from ..parallel.LARC import LARC
if torch.distributed.is_available():
from ..parallel.LARC import LARC
# There's no reason to expose the notion of a "handle". Everything can happen through amp.* calls. # There's no reason to expose the notion of a "handle". Everything can happen through amp.* calls.
...@@ -84,7 +87,7 @@ def scale_loss(loss, ...@@ -84,7 +87,7 @@ def scale_loss(loss,
yield loss yield loss
return return
if isinstance(optimizers, torch.optim.Optimizer) or isinstance(optimizers, LARC): if isinstance(optimizers, torch.optim.Optimizer) or ('LARC' in sys.modules and isinstance(optimizers, LARC)):
optimizers = [optimizers] optimizers = [optimizers]
loss_scaler = _amp_state.loss_scalers[loss_id] loss_scaler = _amp_state.loss_scalers[loss_id]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment