Commit 1c464b48 authored by ptrblck, committed by mcarilli

move and fix check_optimizers (#268)

parent 28097c99
@@ -111,7 +111,7 @@ def check_optimizers(optimizers):
         if isinstance(optim, FP16_Optimizer_for_fused):
             bad_optim_type = "apex.optimizers.FP16_Optimizer"
         if bad_optim_type is not None:
-            raise RuntimeError("An incoming optimizer is an instance of {}. ".format(optim_type) +
+            raise RuntimeError("An incoming optimizer is an instance of {}. ".format(bad_optim_type) +
                                "The optimizer(s) passed to amp.initialize() must be bare \n"
                                "instances of either ordinary Pytorch optimizers, or Apex fused \n"
                                "optimizers (currently just FusedAdam, but FusedSGD will be added \n"
@@ -148,7 +148,9 @@ def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs
         optimizers = []
     elif isinstance(optimizers, list):
         optimizers_was_list = True
+        check_optimizers(optimizers)
     else:
+        check_optimizers([optimizers])
         raise TypeError("optimizers must be either a single optimizer or a list of optimizers.")
 
     if isinstance(models, torch.nn.Module):
@@ -164,7 +166,6 @@ def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs
     if not _amp_state.allow_incoming_model_not_fp32:
         check_params_fp32(models)
 
-    check_optimizers(optimizers)
 
     # In the future, when FP16_Optimizer can be deprecated and master weights can
     # become an attribute, remember to stash master weights before casting the model.
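
Below is a hedged usage sketch, not part of this commit, of the failure mode the diff addresses. It assumes the public apex.amp API and the apex.fp16_utils.FP16_Optimizer wrapper named in the error message; the model and optimizer here are arbitrary placeholders.

# Sketch only: illustrates the error path guarded by check_optimizers()
# after this commit.  Assumes CUDA is available, as apex.amp requires.
import torch
from apex import amp
from apex.fp16_utils import FP16_Optimizer

model = torch.nn.Linear(16, 16).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# Manually wrapping the optimizer before amp.initialize() is the mistake
# that check_optimizers() is meant to catch.
wrapped = FP16_Optimizer(optimizer)

try:
    model, wrapped = amp.initialize(model, wrapped, opt_level="O1")
except RuntimeError as err:
    # After this commit the message names the offending wrapper type
    # ("apex.fp16_utils.FP16_Optimizer").  Before the fix, formatting the
    # message referenced the undefined name `optim_type`, so a NameError
    # surfaced instead; and because the check ran after the type dispatch,
    # a single wrapped optimizer only hit the generic TypeError.
    print(err)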