Commit d1f74a3e authored by Michael Carilli

Casting model output as well as input, for #195

parent 80185371
@@ -16,10 +16,10 @@ def to_type(dtype, t):
     if not t.is_cuda:
         # This should not be a hard error, since it may be legitimate.
         print("Warning: An input tensor was not cuda. ")
-    if t.requires_grad:
-        # This should be a hard-ish error.
-        warn_or_err("input data requires grad. Since input data is not a model parameter,\n"
-                    "its gradients will not be properly allreduced by DDP.")
+    # GANs require this.
+    # if t.requires_grad:
+    #     warn_or_err("input data requires grad. Since input data is not a model parameter,\n"
+    #                 "its gradients will not be properly allreduced by DDP.")
     if t.is_floating_point():
         return t.to(dtype)
     return t
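For context, `to_type` gets applied to every tensor in the forward arguments through Amp's `applier` utility, which walks nested containers. Below is a minimal sketch of that casting behavior, using a hypothetical `apply_to_tensors` walker in place of the real `applier`; the names and structure are illustrative, not the library's actual implementation.

    import torch

    def apply_to_tensors(value, fn):
        # Hypothetical stand-in for Amp's internal `applier`: recursively walk
        # lists/tuples/dicts and apply `fn` to each tensor leaf.
        if isinstance(value, torch.Tensor):
            return fn(value)
        if isinstance(value, dict):
            return {k: apply_to_tensors(v, fn) for k, v in value.items()}
        if isinstance(value, (list, tuple)):
            return type(value)(apply_to_tensors(v, fn) for v in value)
        return value

    # Cast floating-point tensors to FP16 while leaving integer labels alone,
    # mirroring the is_floating_point() check in to_type above.
    caster = lambda t: t.half() if t.is_floating_point() else t
    batch = {"images": torch.randn(2, 3), "labels": torch.tensor([0, 1])}
    half_batch = apply_to_tensors(batch, caster)
    assert half_batch["images"].dtype == torch.float16
    assert half_batch["labels"].dtype == torch.int64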
@@ -155,17 +155,21 @@ def _initialize(models, optimizers, properties):
     for model in models:
         model.to(properties.cast_model_type)

-        caster = functools.partial(to_type, properties.cast_model_type)
+    input_caster = functools.partial(to_type, properties.cast_model_type)
+    output_caster = functools.partial(to_type, torch.float32)

-        # Patch the forward method to cast incoming data to the correct type.
-        # I like writing things explicitly more than decorators.
-        def patch_forward(old_fwd):
-            def new_fwd(*args, **kwargs):
-                return old_fwd(*applier(args, caster),
-                               **applier(kwargs, caster))
-            return new_fwd
-
-        model.forward = patch_forward(model.forward)
+    for model in models:
+        # Patch the forward method to cast incoming data to the correct type, and
+        # outgoing data to float32, so "the user never needs to call .half()."
+        # I like writing things explicitly more than decorators.
+        def patch_forward(old_fwd):
+            def new_fwd(*args, **kwargs):
+                output = old_fwd(*applier(args, input_caster),
+                                 **applier(kwargs, input_caster))
+                return applier(output, output_caster)
+            return new_fwd
+
+        model.forward = patch_forward(model.forward)

     # State dict trick to recast any preexisting per-param state tensors
     for optimizer in optimizers:
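The forward patch above is the heart of the change: inputs are cast to the model's dtype on the way in, and outputs are cast back to float32 on the way out, so user code stays in FP32 end to end. Here is a simplified, self-contained sketch of that pattern under the assumption of flat tensor args only (the real code also walks nested containers and kwargs via `applier`); `patch_forward_to_half` is an illustrative name, not the library's API.

    import torch

    def patch_forward_to_half(model):
        # Wrap model.forward in a closure that casts floating-point tensor args
        # to FP16 before the call and casts an FP16 output back to FP32 after,
        # so the caller never needs to call .half() or .float() explicitly.
        old_fwd = model.forward

        def new_fwd(*args, **kwargs):
            cast_args = [a.half() if torch.is_tensor(a) and a.is_floating_point() else a
                         for a in args]
            output = old_fwd(*cast_args, **kwargs)
            return output.float() if torch.is_tensor(output) else output

        model.forward = new_fwd
        return model

    # Usage (requires a GPU, since FP16 compute targets CUDA):
    if torch.cuda.is_available():
        model = patch_forward_to_half(torch.nn.Linear(4, 2).half().cuda())
        out = model(torch.randn(1, 4, device="cuda"))  # FP32 tensor in...
        assert out.dtype == torch.float32              # ...FP32 tensor out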
@@ -13,6 +13,10 @@ on the Github page.
 GANs are a tricky case that many people have requested. A `comprehensive DCGAN example`_
 is under construction.

+If you already implemented Amp based on the instructions below, but it isn't behaving as expected,
+please review `Advanced Amp Usage`_ to see if any topics match your use case. If that doesn't help,
+file an issue.
+
 ``opt_level``\ s and Properties
 -------------------------------
@@ -55,6 +59,9 @@ In this way, there's no risk adhering to the Amp API, and a lot of potential per
 .. _`comprehensive DCGAN example`:
     https://github.com/NVIDIA/apex/tree/master/examples/dcgan

+.. _`Advanced Amp Usage`:
+    https://nvidia.github.io/apex/advanced.html
+
 Properties
 **********
@@ -68,7 +68,6 @@ parser.add_argument("--local_rank", default=0, type=int)
 parser.add_argument('--sync_bn', action='store_true',
                     help='enabling apex sync BN.')

-parser.add_argument('--has-ext', action='store_true')
 parser.add_argument('--opt-level', type=str)
 parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
 parser.add_argument('--loss-scale', type=str, default=None)
 #!/bin/bash
 cp ../common/* .
-bash run_test.sh single_gpu
+bash run_test.sh single_gpu $1