Commit d1f74a3e authored by Michael Carilli

Casting model output as well as input, for #195

parent 80185371
@@ -16,10 +16,10 @@ def to_type(dtype, t):
     if not t.is_cuda:
         # This should not be a hard error, since it may be legitimate.
         print("Warning: An input tensor was not cuda. ")
-    if t.requires_grad:
-        # This should be a hard-ish error.
-        warn_or_err("input data requires grad. Since input data is not a model parameter,\n"
-                    "its gradients will not be properly allreduced by DDP.")
+    # GANs require this.
+    # if t.requires_grad:
+    #     warn_or_err("input data requires grad. Since input data is not a model parameter,\n"
+    #                 "its gradients will not be properly allreduced by DDP.")
     if t.is_floating_point():
         return t.to(dtype)
     return t
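For context on the newly added "GANs require this." comment: in a GAN, the tensor fed to the discriminator is produced by the generator, so it is a non-leaf tensor with requires_grad=True even though it is not a model parameter, and the old hard-ish warning fired on perfectly valid code. A minimal sketch of that situation, using hypothetical stand-in modules that are not part of this patch:

    import torch

    netG = torch.nn.Linear(16, 32).cuda()   # stand-in generator
    netD = torch.nn.Linear(32, 1).cuda()    # stand-in discriminator

    noise = torch.randn(8, 16).cuda()
    fake = netG(noise)       # non-leaf output: fake.requires_grad is True
    scores = netD(fake)      # if netD's forward were patched by amp, to_type would
                             # receive this grad-requiring "input data", which the
                             # removed check warned or errored on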
@@ -155,17 +155,21 @@ def _initialize(models, optimizers, properties):
     for model in models:
         model.to(properties.cast_model_type)
 
-        caster = functools.partial(to_type, properties.cast_model_type)
-
-        # Patch the forward method to cast incoming data to the correct type.
-        # I like writing things explicitly more than decorators.
-        def patch_forward(old_fwd):
-            def new_fwd(*args, **kwargs):
-                return old_fwd(*applier(args, caster),
-                               **applier(kwargs, caster))
-            return new_fwd
-
-        model.forward = patch_forward(model.forward)
+    input_caster = functools.partial(to_type, properties.cast_model_type)
+    output_caster = functools.partial(to_type, torch.float32)
+
+    for model in models:
+        # Patch the forward method to cast incoming data to the correct type, and
+        # outgoing data to float32, so "the user never needs to call .half()."
+        # I like writing things explicitly more than decorators.
+        def patch_forward(old_fwd):
+            def new_fwd(*args, **kwargs):
+                output = old_fwd(*applier(args, input_caster),
+                                 **applier(kwargs, input_caster))
+                return applier(output, output_caster)
+            return new_fwd
+
+        model.forward = patch_forward(model.forward)
 
     # State dict trick to recast any preexisting per-param state tensors
     for optimizer in optimizers:
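Taken together, the patched forward means a model whose parameters were cast to half still accepts and returns float32 tensors, so "the user never needs to call .half()". A minimal usage sketch of that behavior, assuming apex is installed and a CUDA device is available (the toy model and tensor names are illustrative, not from this patch):

    import torch
    from apex import amp

    model = torch.nn.Linear(4, 2).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    # O2 selects cast_model_type=torch.float16, so the forward patch above applies.
    model, optimizer = amp.initialize(model, optimizer, opt_level="O2")

    x = torch.randn(8, 4).cuda()      # plain float32 input, no manual .half()
    y = model(x)                      # input_caster casts x to float16 internally
    assert y.dtype == torch.float32   # output_caster casts the result back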
......
@@ -13,6 +13,10 @@ on the Github page.
 
 GANs are a tricky case that many people have requested. A `comprehensive DCGAN example`_
 is under construction.
 
+If you already implemented Amp based on the instructions below, but it isn't behaving as expected,
+please review `Advanced Amp Usage`_ to see if any topics match your use case. If that doesn't help,
+file an issue.
+
 ``opt_level``\ s and Properties
 -------------------------------
@@ -55,6 +59,9 @@ In this way, there's no risk adhering to the Amp API, and a lot of potential per
 
 .. _`comprehensive DCGAN example`:
     https://github.com/NVIDIA/apex/tree/master/examples/dcgan
 
+.. _`Advanced Amp Usage`:
+    https://nvidia.github.io/apex/advanced.html
+
 Properties
 **********
......
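For readers skimming the docs hunk above: the truncated context line ("...no risk adhering to the Amp API, and a lot of potential per...") refers to the standard two-step Amp pattern, i.e. calling amp.initialize once and wrapping the backward pass in amp.scale_loss. A minimal self-contained sketch of that pattern (the toy model and loss are illustrative only, not from this commit):

    import torch
    from apex import amp

    model = torch.nn.Linear(10, 10).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    data = torch.randn(4, 10).cuda()
    loss = model(data).float().sum()         # toy loss, cast to fp32 for clarity
    optimizer.zero_grad()
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()               # backward runs on the scaled loss
    optimizer.step()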
@@ -68,7 +68,6 @@ parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument('--sync_bn', action='store_true',
help='enabling apex sync BN.')
parser.add_argument('--has-ext', action='store_true')
parser.add_argument('--opt-level', type=str)
parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
parser.add_argument('--loss-scale', type=str, default=None)
......
 #!/bin/bash
 cp ../common/* .
-bash run_test.sh single_gpu
+bash run_test.sh single_gpu $1