Commit d1f74a3e (apex)
authored Mar 12, 2019 by Michael Carilli

Casting model output as well as input, for #195

parent 80185371
Showing 4 changed files with 27 additions and 17 deletions:

  apex/amp/_initialize.py         +19  -15
  docs/source/amp.rst              +7   -0
  examples/imagenet/main_amp.py    +0   -1
  tests/L1/cross_product/run.sh    +1   -1
apex/amp/_initialize.py
@@ -16,10 +16,10 @@ def to_type(dtype, t):
     if not t.is_cuda:
         # This should not be a hard error, since it may be legitimate.
         print("Warning: An input tensor was not cuda. ")
-    if t.requires_grad:
-        warn_or_err("input data requires grad. Since input data is not a model parameter,\n"
-            "its gradients will not be properly allreduced by DDP.")
+    # GANs require this.
+    # This should be a hard-ish error.
+    # if t.requires_grad:
+    #     warn_or_err("input data requires grad. Since input data is not a model parameter,\n"
+    #         "its gradients will not be properly allreduced by DDP.")
     if t.is_floating_point():
         return t.to(dtype)
     return t
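
The relaxation above matters because, per the "GANs require this" comment, the tensor fed to a discriminator is the generator's output, which legitimately has requires_grad=True even though it is not a model parameter. A minimal illustration of the situation (hypothetical code, not part of the commit):

    import torch

    # In GAN training, the discriminator consumes the generator's output,
    # which carries requires_grad=True despite not being a model parameter.
    G = torch.nn.Linear(8, 8)
    fake = G(torch.randn(2, 8))
    assert fake.requires_grad  # a hard error on such inputs would break GANs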
@@ -155,17 +155,21 @@ def _initialize(models, optimizers, properties):
     for model in models:
         model.to(properties.cast_model_type)
-        caster = functools.partial(to_type, properties.cast_model_type)
+        input_caster = functools.partial(to_type, properties.cast_model_type)
+        output_caster = functools.partial(to_type, torch.float32)
 
-        # Patch the forward method to cast incoming data to the correct type.
+        # Patch the forward method to cast incoming data to the correct type, and
+        # outgoing data to float32, so "the user never needs to call .half()."
         # I like writing things explicitly more than decorators.
         def patch_forward(old_fwd):
             def new_fwd(*args, **kwargs):
-                return old_fwd(*applier(args, caster),
-                               **applier(kwargs, caster))
+                output = old_fwd(*applier(args, input_caster),
+                                 **applier(kwargs, input_caster))
+                return applier(output, output_caster)
             return new_fwd
 
         model.forward = patch_forward(model.forward)
 
     # State dict trick to recast any preexisting per-param state tensors
     for optimizer in optimizers:
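
Taken together, the two hunks implement what the commit message describes: inputs are cast to the model's compute type, and outputs are cast back to float32. Below is a self-contained sketch of that pattern; the applier here is a simplified stand-in for apex's internal helper (the real one handles more container types), and patch_forward is a condensed, single-model version of the loop above:

    import functools
    import torch

    def to_type(dtype, t):
        # Cast only floating-point tensors; integer/bool tensors pass through.
        if t.is_floating_point():
            return t.to(dtype)
        return t

    def applier(value, fn):
        # Simplified stand-in: apply fn to every tensor in nested containers.
        if isinstance(value, torch.Tensor):
            return fn(value)
        if isinstance(value, dict):
            return {k: applier(v, fn) for k, v in value.items()}
        if isinstance(value, (list, tuple)):
            return type(value)(applier(v, fn) for v in value)
        return value

    def patch_forward(model, cast_model_type=torch.float16):
        model.to(cast_model_type)
        input_caster = functools.partial(to_type, cast_model_type)
        output_caster = functools.partial(to_type, torch.float32)
        old_fwd = model.forward

        def new_fwd(*args, **kwargs):
            # Cast incoming data to the model's type and outgoing data back
            # to float32, so the caller never touches .half() directly.
            output = old_fwd(*applier(args, input_caster),
                             **applier(kwargs, input_caster))
            return applier(output, output_caster)

        model.forward = new_fwd

On a CUDA device, patch_forward(model) followed by model(x) returns float32 outputs even though the forward pass ran in float16, which is exactly the "user never needs to call .half()" behavior the patched comment promises.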
docs/source/amp.rst
@@ -13,6 +13,10 @@ on the Github page.
 GANs are a tricky case that many people have requested. A `comprehensive DCGAN example`_
 is under construction.
 
+If you already implemented Amp based on the instructions below, but it isn't behaving as expected,
+please review `Advanced Amp Usage`_ to see if any topics match your use case. If that doesn't help,
+file an issue.
+
 ``opt_level``\ s and Properties
 -------------------------------
@@ -55,6 +59,9 @@ In this way, there's no risk adhering to the Amp API, and a lot of potential per
 .. _`comprehensive DCGAN example`:
     https://github.com/NVIDIA/apex/tree/master/examples/dcgan
 
+.. _`Advanced Amp Usage`:
+    https://nvidia.github.io/apex/advanced.html
+
 Properties
 **********
examples/imagenet/main_amp.py
@@ -68,7 +68,6 @@ parser.add_argument("--local_rank", default=0, type=int)
 parser.add_argument('--sync_bn', action='store_true',
                     help='enabling apex sync BN.')
-parser.add_argument('--has-ext', action='store_true')
 parser.add_argument('--opt-level', type=str)
 parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
 parser.add_argument('--loss-scale', type=str, default=None)
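
For context, the example script parses flags like these and hands them to amp.initialize. A rough sketch of that hand-off, with the wiring assumed from the flag names rather than quoted verbatim from main_amp.py:

    import argparse
    import torch
    from apex import amp

    parser = argparse.ArgumentParser()
    parser.add_argument('--opt-level', type=str, default='O1')
    parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
    parser.add_argument('--loss-scale', type=str, default=None)
    args = parser.parse_args()

    model = torch.nn.Linear(16, 16).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # The parsed flags map onto amp.initialize's keyword arguments.
    model, optimizer = amp.initialize(model, optimizer,
                                      opt_level=args.opt_level,
                                      keep_batchnorm_fp32=args.keep_batchnorm_fp32,
                                      loss_scale=args.loss_scale)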
tests/L1/cross_product/run.sh
 #!/bin/bash
 cp ../common/* .
-bash run_test.sh single_gpu
+bash run_test.sh single_gpu $1