Commit fab319f1 authored by Bram Vanroy, committed by mcarilli

allow for non-distributed envs (Windows) (#531)

parent 753c427a
@@ -2,7 +2,9 @@
 import torch
 import warnings
-from . import parallel
+if torch.distributed.is_available():
+    from . import parallel
 from . import amp
 from . import fp16_utils
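The first hunk gates apex's top-level `from . import parallel` on `torch.distributed.is_available()`, so `import apex` succeeds on PyTorch builds that ship without distributed support (the default on Windows at the time). A minimal sketch of the same guarded-import pattern for a generic package; `mypkg` and `dist_utils` are placeholder names, not apex's real layout:

```python
# mypkg/__init__.py -- hypothetical package used only to illustrate the pattern
import warnings
import torch

if torch.distributed.is_available():
    # Only import modules that touch torch.distributed when the build supports it.
    from . import dist_utils
else:
    # The package still imports cleanly; only the distributed helpers are skipped.
    warnings.warn("torch.distributed is unavailable; distributed helpers are disabled.")
```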
@@ -2,6 +2,7 @@ import torch
 from torch._six import string_classes
 import functools
 import numpy as np
+import sys
 import warnings
 from ._amp_state import _amp_state, warn_or_err, container_abcs
 from .handle import disable_casts
@@ -10,8 +11,10 @@ from ._process_optimizer import _process_optimizer
 from apex.fp16_utils import convert_network
 from ..fp16_utils import FP16_Optimizer as FP16_Optimizer_general
 from ..contrib.optimizers import FP16_Optimizer as FP16_Optimizer_for_fused
-from ..parallel import DistributedDataParallel as apex_DDP
-from ..parallel.LARC import LARC
+if torch.distributed.is_available():
+    from ..parallel import DistributedDataParallel as apex_DDP
+    from ..parallel.LARC import LARC

 def to_type(dtype, t):
@@ -62,7 +65,7 @@ def check_models(models):
         parallel_type = None
         if isinstance(model, torch.nn.parallel.DistributedDataParallel):
             parallel_type = "torch.nn.parallel.DistributedDataParallel"
-        if isinstance(model, apex_DDP):
+        if ('apex_DDP' in sys.modules) and isinstance(model, apex_DDP):
             parallel_type = "apex.parallel.DistributedDataParallel"
         if isinstance(model, torch.nn.parallel.DataParallel):
             parallel_type = "torch.nn.parallel.DataParallel"
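Because `apex_DDP` is only bound when the guarded import above ran, the `isinstance` test in `check_models` is now preceded by a `sys.modules` membership check so it cannot raise `NameError` on non-distributed builds. A standalone sketch of that idea, assuming apex is installed; it keys the guard on the fully qualified module name `'apex.parallel'` rather than the local alias used in the diff, and `detect_parallel_type` is an illustrative helper, not apex API:

```python
import sys
import torch

if torch.distributed.is_available():
    from apex.parallel import DistributedDataParallel as apex_DDP

def detect_parallel_type(model):
    # apex_DDP only exists if the guarded import ran; the sys.modules check
    # keeps the isinstance test from raising NameError otherwise.
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        return "torch.nn.parallel.DistributedDataParallel"
    if 'apex.parallel' in sys.modules and isinstance(model, apex_DDP):
        return "apex.parallel.DistributedDataParallel"
    if isinstance(model, torch.nn.parallel.DataParallel):
        return "torch.nn.parallel.DataParallel"
    return None
```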
@@ -139,11 +142,10 @@ class O2StateDictHook(object):
 def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs=None):
-    from apex.parallel import DistributedDataParallel as apex_DDP
     from .amp import init as amp_init

     optimizers_was_list = False
-    if isinstance(optimizers, torch.optim.Optimizer) or isinstance(optimizers, LARC):
+    if isinstance(optimizers, torch.optim.Optimizer) or ('LARC' in sys.modules and isinstance(optimizers, LARC)):
         optimizers = [optimizers]
     elif optimizers is None:
         optimizers = []
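The same guard protects `_initialize`'s normalization of the `optimizers` argument, which may be a single optimizer, a `LARC` wrapper, or a list. A hedged sketch of that normalization in isolation; `normalize_optimizers` is an illustrative helper, and the `LARC` class is resolved lazily through `sys.modules` so the code also runs when `apex.parallel.LARC` was never imported:

```python
import sys
import torch

def normalize_optimizers(optimizers):
    # Resolve LARC lazily: on non-distributed builds apex.parallel.LARC is
    # never imported, so sys.modules.get() simply returns None.
    larc_module = sys.modules.get('apex.parallel.LARC')
    is_larc = larc_module is not None and isinstance(optimizers, larc_module.LARC)

    if isinstance(optimizers, torch.optim.Optimizer) or is_larc:
        return [optimizers]          # single optimizer -> one-element list
    if optimizers is None:
        return []
    return list(optimizers)          # already an iterable of optimizers
```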
 import contextlib
 import warnings
+import sys
 import torch

 from . import utils
 from .opt import OptimWrapper
 from .scaler import LossScaler
 from ._amp_state import _amp_state, master_params, maybe_print
-from ..parallel.LARC import LARC
+if torch.distributed.is_available():
+    from ..parallel.LARC import LARC

 # There's no reason to expose the notion of a "handle". Everything can happen through amp.* calls.
@@ -84,7 +87,7 @@ def scale_loss(loss,
         yield loss
         return

-    if isinstance(optimizers, torch.optim.Optimizer) or isinstance(optimizers, LARC):
+    if isinstance(optimizers, torch.optim.Optimizer) or ('LARC' in sys.modules and isinstance(optimizers, LARC)):
         optimizers = [optimizers]

     loss_scaler = _amp_state.loss_scalers[loss_id]
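Taken together, the changes let `apex.amp` be imported and used on a PyTorch build without `torch.distributed` (e.g. the stock Windows wheels at the time), as long as none of the distributed-only features are requested. A hedged usage sketch with an arbitrary model and optimizer, assuming a CUDA-capable machine:

```python
import torch
from apex import amp  # no longer fails on builds without torch.distributed

model = torch.nn.Linear(10, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

inputs = torch.randn(4, 10).cuda()
loss = model(inputs).sum()
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()  # works without apex.parallel / DistributedDataParallel
optimizer.step()
```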