OpenDAS / apex · Commits · 613997ea

Commit 613997ea, authored Feb 26, 2019 by Michael Carilli
Parent: ed8236fa

No need for casts during optimizer step

Showing 20 of 24 changed files, with 44 additions and 124 deletions (+44 −124).
apex/amp/__init__.py                               +1   −1
apex/amp/_initialize.py                            +17  −4
apex/amp/handle.py                                 +10  −1
apex/amp/wrap.py                                   +16  −0
tests/L0/run_amp/__init__.py                       +0   −0
tests/L0/run_amp/test_basic_casts.py               +0   −0
tests/L0/run_amp/test_cache.py                     +0   −0
tests/L0/run_amp/test_multi_tensor_scale.py        +0   −0
tests/L0/run_amp/test_promotion.py                 +0   −0
tests/L0/run_amp/test_rnn.py                       +0   −0
tests/L0/run_amp/test_scale.py                     +0   −0
tests/L0/run_amp/utils.py                          +0   −0
tests/L0/run_fp16util/__init__.py                  +0   −0
tests/L0/run_fp16util/test_fp16util.py             +0   −0
tests/L0/run_mixed_adam/__init__.py                +0   −0
tests/L0/run_mixed_adam/test_fp16_optimizer.py     +0   −0
tests/L0/run_mixed_adam/test_mixed_adam.py         +0   −0
tests/L0/run_test.py                               +0   −0
tests/RNN/RNN_tests.py                             +0   −118
tests/distributed/DDP/ddp_race_condition_test.py   +0   −0
apex/amp/__init__.py (view file @ 613997ea)

 from .amp import init, half_function, float_function, promote_function, \
     register_half_function, register_float_function, register_promote_function
-from .handle import scale_loss
+from .handle import scale_loss, disable_casts
 from .frontend import initialize
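The one-line change above re-exports disable_casts at the package level. A minimal usage sketch, assuming apex is installed and Amp has already been initialized elsewhere so the global handle exists; the optimizer argument is a placeholder, not code from this commit:

    from apex.amp import disable_casts

    def master_param_step(optimizer):
        # Inside this block, Amp's patched torch functions skip their
        # half/float casting and call the original implementations directly.
        with disable_casts():
            optimizer.step()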
apex/amp/_initialize.py (view file @ 613997ea)

@@ -2,6 +2,7 @@ import torch
 from torch._six import container_abcs, string_classes
 import functools
 from ._amp_state import _amp_state
+from .handle import disable_casts
 from .scaler import LossScaler
 from apex.fp16_utils import convert_network
 from ..fp16_utils import FP16_Optimizer as FP16_Optimizer_general

@@ -111,8 +112,8 @@ def _initialize(models, optimizers, properties):
     check_optimizers(optimizers)
 
-    # Stash master weights before casting the model.
-    # if properties.master_weights:
+    # In the future, when FP16_Optimizer can be deprecated and master weights can
+    # become an attribute, remember to stash master weights before casting the model.
 
     if properties.cast_model_type:
         if properties.keep_batchnorm_fp32:
@@ -125,6 +126,7 @@ def _initialize(models, optimizers, properties):
         caster = functools.partial(to_type, properties.cast_model_type)
 
+        # Patch the forward method to cast incoming data to the correct type.
         # I like writing things explicitly more than decorators.
         def patch_forward(old_fwd):
             def new_fwd(*args, **kwargs):
                 return old_fwd(*applier(args, caster),
@@ -142,10 +144,10 @@ def _initialize(models, optimizers, properties):
         if isinstance(optimizer, FusedAdam):
             optimizers[i] = wrap_fused_adam(optimizer, properties)
         if properties.loss_scale == "dynamic":
-            optimizers[i] = FP16_Optimizer_general(optimizers[i],
+            optimizers[i] = FP16_Optimizer_general(optimizer,
                                                    dynamic_loss_scale=True)
         else:
-            optimizers[i] = FP16_Optimizer_general(optimizers[i],
+            optimizers[i] = FP16_Optimizer_general(optimizer,
                                                    static_loss_scale=properties.loss_scale)
     else:
         for optimizer in optimizers:
@@ -154,6 +156,17 @@ def _initialize(models, optimizers, properties):
     if properties.patch_torch_functions:
         # handle is unused here. It's accessible later through a global value anyway.
         handle = amp_init(loss_scale=properties.loss_scale)
+        for optimizer in optimizers:
+            # Disable Amp casting for the optimizer step, because it should only be
+            # applied to FP32 master params anyway.
+            def patch_step(old_step):
+                def new_step(*args, **kwargs):
+                    with disable_casts():
+                        output = old_step(*args, **kwargs)
+                    return output
+                return new_step
+            optimizer.step = patch_step(optimizer.step)
+
     if optimizers_was_list:
         if models_was_list:
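The hunk above wraps each optimizer's step so that Amp casting is switched off while the step runs on FP32 master parameters. A self-contained sketch of the same method-patching pattern, using a plain SGD optimizer and a stand-in context manager (both are illustrative assumptions, not code from this commit):

    import contextlib
    import torch

    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    @contextlib.contextmanager
    def no_casts():
        # Stand-in for apex.amp.disable_casts.
        print("casting disabled")
        yield
        print("casting re-enabled")

    def patch_step(old_step):
        def new_step(*args, **kwargs):
            # Run the original step with casting turned off, mirroring the diff above.
            with no_casts():
                output = old_step(*args, **kwargs)
            return output
        return new_step

    optimizer.step = patch_step(optimizer.step)

    # One dummy iteration to exercise the patched step.
    loss = model(torch.randn(3, 4)).sum()
    loss.backward()
    optimizer.step()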
apex/amp/handle.py (view file @ 613997ea)

@@ -50,7 +50,7 @@ def scale_loss(loss,
                 iter_params(optimizer.param_groups),
                 iter_params(optimizer.param_groups),
                 loss_scale)
-    # In the future, once I have fused optimizers that enable sync-free dynamic loss scaling,
+    # For future fused optimizers that enable sync-free dynamic loss scaling,
     # should_skip will always be False.
     should_skip = optimizer.loss_scaler.update_scale()
     if should_skip:

@@ -66,6 +66,15 @@ def scale_loss(loss,
             _amp_state.handle._clear_cache()
 
 
+# Free function version of AmpHandle.disable_casts, another step on the
+# path to removing the concept of "AmpHandle"
+@contextlib.contextmanager
+def disable_casts():
+    _amp_state.handle._is_active = False
+    yield
+    _amp_state.handle._is_active = True
+
+
 class AmpHandle(object):
     def __init__(self, loss_scale="dynamic", enable_caching=True, verbose=False):
         self._enable_caching = enable_caching
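The new disable_casts above is a context manager that flips the global handle's _is_active flag off for the duration of the block. A self-contained sketch of that flag-toggling pattern, with a toy handle object standing in for _amp_state.handle (all names here are illustrative, not apex code):

    import contextlib

    class ToyHandle(object):
        def __init__(self):
            self._is_active = True
        def is_active(self):
            return self._is_active

    toy_handle = ToyHandle()

    @contextlib.contextmanager
    def toy_disable_casts():
        # Mirror the diff: clear the flag, yield control, then restore it.
        toy_handle._is_active = False
        yield
        toy_handle._is_active = True

    print(toy_handle.is_active())      # True
    with toy_disable_casts():
        print(toy_handle.is_active())  # False
    print(toy_handle.is_active())      # True

As in the diff, the flag is restored after a plain yield; wrapping the yield in try/finally would also restore it if an exception escaped the block.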
apex/amp/wrap.py (view file @ 613997ea)

 from . import compat
 from . import utils
+from ._amp_state import _amp_state
 import functools

@@ -37,10 +38,16 @@ def cached_cast(mod, fn, cast_fn, handle,
     utils.set_func_save(handle, mod, fn, wrapper)
 
 # `handle` arg is unused, but simplifies API to make `make_cast_wrapper`
+# Annoyingly, make_promote_wrapper still uses the global handle. Once everyone
+# is on the new API and I am free to get rid of handle, I can clean this up.
 def make_promote_wrapper(orig_fn, cast_fn, handle=None):
     @functools.wraps(orig_fn)
     def wrapper(*args, **kwargs):
+        if not _amp_state.handle.is_active():
+            return orig_fn(*args, **kwargs)
         types = utils.collect_fp_tensor_types(args, kwargs)
         if len(types) <= 1:
             return orig_fn(*args, **kwargs)
         elif len(types) == 2 and types == set(['HalfTensor', 'FloatTensor']):

@@ -65,6 +72,9 @@ def sequence_promote(mod, fn, handle, verbose=False):
     maybe_float = utils.verbosify(utils.maybe_float, fn, verbose)
     @functools.wraps(orig_fn)
     def wrapper(seq, *args, **kwargs):
+        if not _amp_state.handle.is_active():
+            return orig_fn(seq, *args, **kwargs)
         types = set([utils.type_string(x) for x in seq])
         if len(types) <= 1:
             return orig_fn(seq, *args, **kwargs)

@@ -86,6 +96,9 @@ def promote_match_arg0(mod, fn, handle, verbose=False):
     @functools.wraps(orig_fn)
     def wrapper(arg0, *args, **kwargs):
         assert compat.is_tensor_like(arg0)
+        if not _amp_state.handle.is_active():
+            return orig_fn(arg0, *args, **kwargs)
         if utils.type_string(arg0) == 'HalfTensor':
             cast_fn = utils.maybe_half
         elif utils.type_string(arg0) == 'FloatTensor':

@@ -215,6 +228,9 @@ def new_rnn_cast(fn, handle, verbose=False):
         assert len(args) == 9
         assert len(kwargs) == 0
+        if not _amp_state.handle.is_active():
+            return orig_fn(*args, **kwargs)
         if isinstance(args[6], bool):
             params_idx = 2 # Not PackedSequence case
         else:
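Each wrapper above now consults the global handle's is_active() flag first and falls straight through to the original function when casting has been disabled. A minimal standalone version of that guard pattern; the handle object and the cast function are toy stand-ins, not apex code:

    import functools

    class ToyHandle(object):
        def __init__(self):
            self._is_active = True
        def is_active(self):
            return self._is_active

    toy_handle = ToyHandle()

    def make_guarded_wrapper(orig_fn, cast_fn):
        @functools.wraps(orig_fn)
        def wrapper(*args, **kwargs):
            # Short-circuit: when the handle is inactive, skip casting entirely.
            if not toy_handle.is_active():
                return orig_fn(*args, **kwargs)
            return orig_fn(*[cast_fn(a) for a in args], **kwargs)
        return wrapper

    add = make_guarded_wrapper(lambda x, y: x + y, float)
    print(add(1, 2))             # args cast to float -> 3.0
    toy_handle._is_active = False
    print(add(1, 2))             # casting skipped -> 3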
tests/run_amp/__init__.py → tests/L0/run_amp/__init__.py (file moved)
tests/run_amp/test_basic_casts.py → tests/L0/run_amp/test_basic_casts.py (file moved)
tests/run_amp/test_cache.py → tests/L0/run_amp/test_cache.py (file moved)
tests/run_amp/test_multi_tensor_scale.py → tests/L0/run_amp/test_multi_tensor_scale.py (file moved)
tests/run_amp/test_promotion.py → tests/L0/run_amp/test_promotion.py (file moved)
tests/run_amp/test_rnn.py → tests/L0/run_amp/test_rnn.py (file moved)
tests/run_amp/test_scale.py → tests/L0/run_amp/test_scale.py (file moved)
tests/run_amp/utils.py → tests/L0/run_amp/utils.py (file moved)
tests/run_fp16util/__init__.py → tests/L0/run_fp16util/__init__.py (file moved)
tests/run_fp16util/test_fp16util.py → tests/L0/run_fp16util/test_fp16util.py (file moved)
tests/run_mixed_adam/__init__.py → tests/L0/run_mixed_adam/__init__.py (file moved)
tests/run_mixed_adam/test_fp16_optimizer.py → tests/L0/run_mixed_adam/test_fp16_optimizer.py (file moved)
tests/run_mixed_adam/test_mixed_adam.py → tests/L0/run_mixed_adam/test_mixed_adam.py (file moved)
tests/run_test.py → tests/L0/run_test.py (file moved)
tests/RNN/RNN_tests.py (deleted, 100644 → 0; content as of parent ed8236fa)

import torch
import torch.nn as nn
from torch.autograd import Variable
import apex
from apex.RNN.models import bidirectionalRNN, stackedRNN, RNNCell
from torch.nn._functions.rnn import LSTMCell
import itertools

torch.backends.cudnn.enabled = False

batch_first = False    #not implemented yet
dropout = 0.0          #How to validate?
bidirectional = False  #True works, but differs in definition to PyTorch

rnn_types = ['LSTM', 'GRU', 'ReLU', 'Tanh']

sizes = [8, 4, 2]
seq_sizes = sizes
hidden_sizes = sizes
inp_sizes = sizes
batch_sizes = sizes
num_layerss = sizes
biases = [True]

def copy_param_set(pyt_rnn, my_rnn, layer=0, reverse=False):
    my_params = None
    rnn = None
    if isinstance(my_rnn, bidirectionalRNN):
        rnn = my_rnn.fwd.rnns[layer] if not reverse else my_rnn.bckwrd.rnns[layer]
    elif isinstance(my_rnn, stackedRNN):
        rnn = my_rnn.rnns[layer]
    else:
        raise RuntimeError()

    param_names = ['w_ih', 'w_hh', 'b_ih', 'b_hh']
    if not hasattr(rnn, 'b_hh'):
        param_names = param_names[:2]
    my_params = [getattr(rnn, param_name) for param_name in param_names]

    pyt_params = None
    param_names = ['weight_ih_', 'weight_hh_', 'bias_ih_', 'bias_hh_']
    reverse_str = '_reverse' if reverse else ''
    if not hasattr(pyt_rnn, 'bias_hh_l0'):
        param_names = param_names[:2]
    pyt_params = [getattr(pyt_rnn, param_name + 'l' + str(layer) + reverse_str)
                  for param_name in param_names]

    for pyt_param, my_param in zip(pyt_params, my_params):
        pyt_param.data.copy_(my_param.data)

def copy_all_params(pyt_rnn, my_rnn):
    for layer in range(num_layers):
        copy_param_set(pyt_rnn, my_rnn, layer)
        if bidirectional:
            copy_param_set(pyt_rnn, my_rnn, layer, bidirectional)

def compare_variables(v1, v2, msg, params):
    diff = float((v1.data - v2.data).abs().max())
    if diff > 1e-5:
        print("Error of ", diff, " found for ", msg, " for case: ", str(params))

def compare_tuple_variables(t1, t2, msg, params):
    for var1, var2 in zip(t1, t2):
        compare_variables(var1, var2, msg, params)

def maybe_compare(v1, v2, msg, params):
    if isinstance(v1, Variable) and isinstance(v2, Variable):
        compare_variables(v1, v2, msg, params)
    else:
        compare_tuple_variables(v1, v2, msg, params)

product = list(itertools.product(rnn_types, seq_sizes, hidden_sizes, inp_sizes,
                                 batch_sizes, num_layerss, biases))

for test_case in product:
    rnn_type, seq_size, hidden_size, inp_size, batch_size, num_layers, bias = test_case
    inp = torch.cuda.FloatTensor(seq_size, batch_size, inp_size).uniform_()

    if rnn_type == 'ReLU' or rnn_type == 'Tanh':
        pytorch_rnn = nn.RNN(inp_size, hidden_size, num_layers, bias, batch_first,
                             dropout, bidirectional,
                             nonlinearity=rnn_type.lower()).cuda()
    else:
        pytorch_rnn = getattr(nn, rnn_type)(inp_size, hidden_size, num_layers, bias,
                                            batch_first, dropout, bidirectional).cuda()

    my_rnn = getattr(apex.RNN.models, rnn_type)(inp_size, hidden_size, num_layers, bias,
                                                batch_first, dropout, bidirectional).cuda()

    copy_all_params(pytorch_rnn, my_rnn)

    pyt_inp = Variable(inp, requires_grad=True)
    my_inp = Variable(inp, requires_grad=True)

    my_out, my_hiddens = my_rnn(my_inp)
    pyt_out, pyt_hiddens = pytorch_rnn(pyt_inp)

    pyt_out.sum().backward()
    my_out.sum().backward()

    maybe_compare(pyt_out, my_out, "out", test_case)

    #If there's only one hidden state PyTorch doesn't return it in a tuple,
    #apex does, so we wrap PyTorch's returned hidden state in a tuple.
    if not isinstance(pyt_hiddens, tuple):
        pyt_hiddens = (pyt_hiddens,)

    try:
        for i, (pyt_hid, my_hid) in enumerate(zip(pyt_hiddens, my_hiddens)):
            maybe_compare(pyt_hid, my_hid, "hx_" + str(i), test_case)
    except ValueError:
        maybe_compare(pyt_hiddens, my_hiddens, "hx_0", test_case)

    maybe_compare(pyt_inp.grad, my_inp.grad, "inp.grad", test_case)

print("Test passed.")
tests/distributed/ddp_race_condition_test.py → tests/distributed/DDP/ddp_race_condition_test.py (file moved)