Commit d69011de (OpenDAS / apex)
Authored Apr 16, 2019 by Michael Carilli
Adding option to ensure that model outputs are a desired type
Parent: eea4c0aa

Showing 2 changed files with 20 additions and 3 deletions:

apex/amp/_initialize.py    +16 -2
apex/amp/frontend.py        +4 -1
apex/amp/_initialize.py
@@ -107,7 +107,7 @@ def check_optimizers(optimizers):
             "on the specified opt_level (and optional overridden properties)."
             )

-def _initialize(models, optimizers, properties, num_losses=1):
+def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs=None):
     from apex.parallel import DistributedDataParallel as apex_DDP
     from .amp import init as amp_init
@@ -148,7 +148,10 @@ def _initialize(models, optimizers, properties, num_losses=1):
                 model.to(properties.cast_model_type)

         input_caster = functools.partial(to_type, properties.cast_model_type)
-        output_caster = functools.partial(to_type, torch.float32)
+        if cast_model_outputs is not None:
+            output_caster = functools.partial(to_type, cast_model_outputs)
+        else:
+            output_caster = functools.partial(to_type, torch.float32)

         for model in models:
             # Patch the forward method to cast incoming data to the correct type, and
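For readers unfamiliar with the pattern above, functools.partial freezes the target dtype into a one-argument casting function that can later be mapped over whatever the model returns. A minimal sketch of the idea, using a simplified stand-in for Apex's internal to_type helper (the real helper also warns about non-CUDA inputs):

import functools

import torch

# Simplified stand-in for the to_type helper in _initialize.py: cast
# floating-point tensors to the requested dtype, pass everything else through.
def to_type(dtype, t):
    if isinstance(t, torch.Tensor) and t.is_floating_point():
        return t.to(dtype)
    return t

# Freezing the dtype yields a one-argument caster, mirroring
# functools.partial(to_type, cast_model_outputs) in the hunk above.
output_caster = functools.partial(to_type, torch.float16)

x = torch.randn(4)             # float32 by default
print(output_caster(x).dtype)  # torch.float16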
@@ -166,6 +169,17 @@ def _initialize(models, optimizers, properties, num_losses=1):
         # State dict trick to recast any preexisting per-param state tensors
         for optimizer in optimizers:
             optimizer.load_state_dict(optimizer.state_dict())
+    elif cast_model_outputs is not None:
+        output_caster = functools.partial(to_type, cast_model_outputs)
+
+        for model in models:
+            def patch_forward(old_fwd):
+                def new_fwd(*args, **kwargs):
+                    output = old_fwd(*args, **kwargs)
+                    return applier(output, output_caster)
+                return new_fwd
+
+            model.forward = patch_forward(model.forward)

     for i, optimizer in enumerate(optimizers):
         optimizers[i] = _process_optimizer(optimizer, properties)
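The new elif branch applies the forward-patching trick even when the model itself is not being cast: each model's forward is wrapped in a closure that runs the original forward and then maps the output caster over the result. A self-contained sketch of that wrapping pattern, with simplified stand-ins for Apex's to_type and applier helpers (the real applier also walks dicts and other container types):

import functools

import torch
import torch.nn as nn

def to_type(dtype, t):
    # Simplified caster: convert floating-point tensors, pass everything else through.
    return t.to(dtype) if t.is_floating_point() else t

def applier(value, fn):
    # Simplified recursive apply over tensors, lists, and tuples.
    if isinstance(value, torch.Tensor):
        return fn(value)
    if isinstance(value, (list, tuple)):
        return type(value)(applier(v, fn) for v in value)
    return value

def patch_forward(old_fwd, output_caster):
    # Wrap the original forward so its outputs are cast on the way out.
    def new_fwd(*args, **kwargs):
        output = old_fwd(*args, **kwargs)
        return applier(output, output_caster)
    return new_fwd

model = nn.Linear(8, 4)
output_caster = functools.partial(to_type, torch.float16)
model.forward = patch_forward(model.forward, output_caster)

out = model(torch.randn(2, 8))
print(out.dtype)  # torch.float16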
apex/amp/frontend.py
@@ -201,6 +201,7 @@ def initialize(
     keep_batchnorm_fp32=None,
     master_weights=None,
     loss_scale=None,
+    cast_model_outputs=None,
     num_losses=1,
     verbosity=1,
     ):
@@ -240,6 +241,8 @@ def initialize(
         master_weights (bool, optional, default=None): Optional property override.
         loss_scale (float or str, optional, default=None): Optional property override. If passed as a string,
             must be a string representing a number, e.g., "128.0", or the string "dynamic".
+        cast_model_outputs (torch.dtype, optional, default=None): Option to ensure that the outputs
+            of your model(s) are always cast to a particular type regardless of ``opt_level``.
         num_losses (int, optional, default=1): Option to tell Amp in advance how many losses/backward
             passes you plan to use. When used in conjunction with the ``loss_id`` argument to
             ``amp.scale_loss``, enables Amp to use a different loss scale per loss/backward pass,
@@ -333,7 +336,7 @@ def initialize(
     for k, v in _amp_state.opt_properties.options.items():
         maybe_print("{:22} : {}".format(k, v), True)

-    return _initialize(models, optimizers, _amp_state.opt_properties, num_losses)
+    return _initialize(models, optimizers, _amp_state.opt_properties, num_losses, cast_model_outputs)

 # TODO: is this necessary/useful?
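Taken together, the frontend change simply threads the new keyword from amp.initialize down to _initialize. A hedged usage sketch, assuming Apex is installed and a CUDA device is available (the model, optimizer, and opt_level below are placeholders, not part of the commit):

import torch
import torch.nn as nn
from apex import amp

# Placeholder model and optimizer; any CUDA model would do.
model = nn.Linear(16, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# cast_model_outputs is the keyword added by this commit: whatever the
# chosen opt_level would otherwise produce, model outputs come back in
# the requested dtype.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1",
                                  cast_model_outputs=torch.float16)

output = model(torch.randn(2, 16).cuda())
assert output.dtype == torch.float16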