Commit 0b5dd020 authored by Michael Carilli

Adding option to ensure that model outputs are a desired type

parent 4dc711bc
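For a sense of how the new option is meant to be used, here is an illustrative sketch (not part of the commit; the model, optimizer, and dtype choices are placeholders). Under O1, the last op of a model may return half-precision tensors; ``cast_model_outputs`` pins the outputs to a dtype of your choosing.

# Illustrative only: model, optimizer, and dtypes are placeholders, not part
# of this commit. cast_model_outputs pins the dtype of whatever the patched
# forward returns, independent of what opt_level would otherwise produce.
import torch
from apex import amp

model = torch.nn.Linear(128, 10).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

model, optimizer = amp.initialize(model, optimizer, opt_level="O1",
                                  cast_model_outputs=torch.float32)

logits = model(torch.randn(32, 128).cuda())
print(logits.dtype)  # torch.float32, even though the matmul ran in fp16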
@@ -123,7 +123,7 @@ def wrap_fused_adam(optimizer, properties):
         return FP16_Optimizer_for_fused(optimizer, static_loss_scale=properties.loss_scale)
 
-def _initialize(models, optimizers, properties, num_losses=1):
+def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs=None):
     from apex.parallel import DistributedDataParallel as apex_DDP
     from .amp import init as amp_init
@@ -164,7 +164,10 @@ def _initialize(models, optimizers, properties, num_losses=1):
                 model.to(properties.cast_model_type)
 
         input_caster = functools.partial(to_type, properties.cast_model_type)
-        output_caster = functools.partial(to_type, torch.float32)
+        if cast_model_outputs is not None:
+            output_caster = functools.partial(to_type, cast_model_outputs)
+        else:
+            output_caster = functools.partial(to_type, torch.float32)
 
         for model in models:
             # Patch the forward method to cast incoming data to the correct type, and
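For context on the branch above: ``to_type`` and ``applier`` are existing Amp helpers that cast floating-point tensors and map a function over (possibly nested) outputs. A standalone sketch with simplified stand-ins for those helpers shows what the ``output_caster`` built here ends up doing:

# Simplified stand-ins for Amp's to_type/applier helpers (the real versions
# handle more cases); only meant to show how the partial-based caster works.
import functools
import torch

def to_type(dtype, t):
    # Cast only floating-point tensors; leave everything else untouched.
    if isinstance(t, torch.Tensor) and t.is_floating_point():
        return t.to(dtype)
    return t

def applier(value, fn):
    # Recursively apply fn to tensors inside dicts, lists, and tuples.
    if isinstance(value, torch.Tensor):
        return fn(value)
    if isinstance(value, dict):
        return {k: applier(v, fn) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        return type(value)(applier(v, fn) for v in value)
    return value

# With cast_model_outputs=torch.float16, outputs are cast to fp16 instead of
# the fp32 default that applies when the model itself was cast:
output_caster = functools.partial(to_type, torch.float16)
outputs = {"logits": torch.randn(2, 3), "loss": torch.tensor(0.5)}
outputs = applier(outputs, output_caster)
print(outputs["logits"].dtype)  # torch.float16
print(outputs["loss"].dtype)    # torch.float16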
@@ -182,6 +185,17 @@ def _initialize(models, optimizers, properties, num_losses=1):
         # State dict trick to recast any preexisting per-param state tensors
         for optimizer in optimizers:
             optimizer.load_state_dict(optimizer.state_dict())
+    elif cast_model_outputs is not None:
+        output_caster = functools.partial(to_type, cast_model_outputs)
+
+        for model in models:
+            def patch_forward(old_fwd):
+                def new_fwd(*args, **kwargs):
+                    output = old_fwd(*args, **kwargs)
+                    return applier(output, output_caster)
+                return new_fwd
+
+            model.forward = patch_forward(model.forward)
 
     for i, optimizer in enumerate(optimizers):
         # Still need to special case this for the first pass
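The new ``elif`` branch handles opt_levels that do not cast the model itself (so the patched forward skips input casting) but should still return outputs of a fixed dtype. A standalone sketch of the same forward-patching pattern, with a toy caster standing in for ``applier``/``output_caster``:

# Sketch of the closure-based forward patching used above; the caster is a
# toy stand-in that only handles a bare floating-point tensor output.
import torch

def output_caster(t, dtype=torch.float16):
    if isinstance(t, torch.Tensor) and t.is_floating_point():
        return t.to(dtype)
    return t

models = [torch.nn.Linear(4, 4), torch.nn.Linear(4, 2)]
for model in models:
    # Passing old_fwd in as an argument binds *this* model's original forward
    # at patch time; relying on the loop variable inside new_fwd would hit
    # Python's late-binding closures (and recurse into the patched forward).
    def patch_forward(old_fwd):
        def new_fwd(*args, **kwargs):
            output = old_fwd(*args, **kwargs)
            return output_caster(output)
        return new_fwd

    model.forward = patch_forward(model.forward)

x = torch.randn(1, 4)
print(models[0](x).dtype)  # torch.float16
print(models[1](x).dtype)  # torch.float16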
@@ -201,6 +201,7 @@ def initialize(
     keep_batchnorm_fp32=None,
     master_weights=None,
     loss_scale=None,
+    cast_model_outputs=None,
     num_losses=1,
     verbosity=1,
     ):
@@ -240,6 +241,8 @@ def initialize(
         master_weights (bool, optional, default=None): Optional property override.
         loss_scale (float or str, optional, default=None): Optional property override. If passed as a string,
             must be a string representing a number, e.g., "128.0", or the string "dynamic".
+        cast_model_outputs (torch.dtype, optional, default=None): Option to ensure that the outputs
+            of your model(s) are always cast to a particular type regardless of ``opt_level``.
         num_losses (int, optional, default=1): Option to tell Amp in advance how many losses/backward
             passes you plan to use. When used in conjunction with the ``loss_id`` argument to
             ``amp.scale_loss``, enables Amp to use a different loss scale per loss/backward pass,
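As the new docstring entry says, the option applies regardless of ``opt_level``. One more illustrative case (again not from the commit, with placeholder model and optimizer): under O2, where outputs are normally cast back to fp32, ``cast_model_outputs=torch.float16`` keeps them in half precision.

# Illustrative: under O2 the default output caster is fp32; overriding it
# with cast_model_outputs=torch.float16 keeps outputs in half precision so a
# downstream fp16 consumer needs no extra cast. Model/optimizer are placeholders.
import torch
from apex import amp

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU()).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

model, optimizer = amp.initialize(model, optimizer, opt_level="O2",
                                  cast_model_outputs=torch.float16)

out = model(torch.randn(8, 64).cuda())
print(out.dtype)  # torch.float16 instead of the default torch.float32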
@@ -333,7 +336,7 @@ def initialize(
     for k, v in _amp_state.opt_properties.options.items():
         maybe_print("{:22} : {}".format(k, v), True)
 
-    return _initialize(models, optimizers, _amp_state.opt_properties, num_losses)
+    return _initialize(models, optimizers, _amp_state.opt_properties, num_losses, cast_model_outputs)
 
 # TODO: is this necessary/useful?