Commit 1c464b48 authored by ptrblck's avatar ptrblck Committed by mcarilli
Browse files

move and fix check_optimizers (#268)

parent 28097c99
...@@ -111,7 +111,7 @@ def check_optimizers(optimizers): ...@@ -111,7 +111,7 @@ def check_optimizers(optimizers):
if isinstance(optim, FP16_Optimizer_for_fused): if isinstance(optim, FP16_Optimizer_for_fused):
bad_optim_type = "apex.optimizers.FP16_Optimizer" bad_optim_type = "apex.optimizers.FP16_Optimizer"
if bad_optim_type is not None: if bad_optim_type is not None:
raise RuntimeError("An incoming optimizer is an instance of {}. ".format(optim_type) + raise RuntimeError("An incoming optimizer is an instance of {}. ".format(bad_optim_type) +
"The optimizer(s) passed to amp.initialize() must be bare \n" "The optimizer(s) passed to amp.initialize() must be bare \n"
"instances of either ordinary Pytorch optimizers, or Apex fused \n" "instances of either ordinary Pytorch optimizers, or Apex fused \n"
"optimizers (currently just FusedAdam, but FusedSGD will be added \n" "optimizers (currently just FusedAdam, but FusedSGD will be added \n"
...@@ -148,7 +148,9 @@ def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs ...@@ -148,7 +148,9 @@ def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs
optimizers = [] optimizers = []
elif isinstance(optimizers, list): elif isinstance(optimizers, list):
optimizers_was_list = True optimizers_was_list = True
check_optimizers(optimizers)
else: else:
check_optimizers([optimizers])
raise TypeError("optimizers must be either a single optimizer or a list of optimizers.") raise TypeError("optimizers must be either a single optimizer or a list of optimizers.")
if isinstance(models, torch.nn.Module): if isinstance(models, torch.nn.Module):
...@@ -164,7 +166,6 @@ def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs ...@@ -164,7 +166,6 @@ def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs
if not _amp_state.allow_incoming_model_not_fp32: if not _amp_state.allow_incoming_model_not_fp32:
check_params_fp32(models) check_params_fp32(models)
check_optimizers(optimizers)
# In the future, when FP16_Optimizer can be deprecated and master weights can # In the future, when FP16_Optimizer can be deprecated and master weights can
# become an attribute, remember to stash master weights before casting the model. # become an attribute, remember to stash master weights before casting the model.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment