OpenDAS / apex · Commit de3f3fea

add bfloat16 register functions, enable rnn functions, enable promote functions

Authored May 08, 2020 by rohithkrn
Parent commit: 6e14df49

Showing 8 changed files with 55 additions and 30 deletions:

  apex/amp/__init__.py            +2   -2
  apex/amp/_initialize.py         +1   -1
  apex/amp/_process_optimizer.py  +2   -2
  apex/amp/amp.py                 +17  -9
  apex/amp/frontend.py            +5   -2
  apex/amp/rnn_compat.py          +2   -2
  apex/amp/utils.py               +8   -2
  apex/amp/wrap.py                +18  -10
apex/amp/__init__.py  (+2, -2)

-from .amp import init, half_function, float_function, promote_function, \
-    register_half_function, register_float_function, register_promote_function
+from .amp import init, half_function, bfloat16_function, float_function, promote_function, \
+    register_half_function, register_bfloat16_function, register_float_function, register_promote_function
 from .handle import scale_loss, disable_casts
 from .frontend import initialize, state_dict, load_state_dict
 from ._amp_state import master_params, _amp_state
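
The package-level exports now mirror the fp16 entry points with bfloat16 counterparts. A minimal import sketch (illustrative only, not part of the diff):

# Illustrative: the new bfloat16 entry points sit next to the existing fp16 ones.
from apex.amp import bfloat16_function, register_bfloat16_function
from apex.amp import half_function, register_half_function  # unchanged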
apex/amp/_initialize.py  (+1, -1)

@@ -189,7 +189,7 @@ def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs
     for model in models:
         # Patch the forward method to cast incoming data to the correct type, and
-        # outgoing data to float32, so "the user never needs to call .half()."
+        # outgoing data to float32, so "the user never needs to call .half()/.bfloat16()."
         # I like writing things explicitly more than decorators.
         def patch_forward(old_fwd):
             def new_fwd(*args, **kwargs):
apex/amp/_process_optimizer.py  (+2, -2)

@@ -213,8 +213,8 @@ def lazy_init_no_master_weights(self):
         elif param.type() == 'torch.cuda.FloatTensor':
             stash.all_fp32_params.append(param)
         else:
-            raise TypeError("Optimizer's parameters must be either "
-                            "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
+            raise TypeError("Optimizer's parameters must be one of "
+                            "torch.cuda.FloatTensor, torch.cuda.HalfTensor, torch.BFloat16Tensor. "
                             "Received {}".format(param.type()))

     stash.all_fp16_grad_stash = [None for _ in stash.all_fp16_params]
apex/amp/amp.py  (+17, -9)

@@ -30,6 +30,9 @@ def half_function(fn):
     wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=True)
     return _decorator_helper(fn, utils.maybe_half, wrap_fn)

+def bfloat16_function(fn):
+    wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=True)
+    return _decorator_helper(fn, utils.maybe_bfloat16, wrap_fn)
+
 def float_function(fn):
     wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=False)

@@ -48,6 +51,11 @@ def register_half_function(module, name):
             name, module))
     _USER_CAST_REGISTRY.add((module, name, utils.maybe_half))

+def register_bfloat16_function(module, name):
+    if not hasattr(module, name):
+        raise ValueError('No function named {} in module {}.'.format(
+            name, module))
+    _USER_CAST_REGISTRY.add((module, name, utils.maybe_bfloat16))
+
 def register_float_function(module, name):
     if not hasattr(module, name):

@@ -116,11 +124,11 @@ def init(enabled=True, loss_scale="dynamic", patch_type=torch.float16, enable_ca
     # 1.5) Pre-0.4, put the blacklist methods on HalfTensor and whitelist
     # methods on FloatTensor, since they're distinct types.
     if compat.tensor_is_float_tensor():
-        for fn in getattr(tensor_overrides, low_prec_funcs):
-            wrap.cached_cast(torch.cuda.FloatTensor, fn, utils.maybe_low_prec,
+        for fn in getattr(tensor_overrides, 'FP16_FUNCS'):
+            wrap.cached_cast(torch.cuda.FloatTensor, fn, utils.maybe_half,
                              handle, try_caching=True, verbose=verbose)
         for fn in tensor_overrides.FP32_FUNCS:
-            wrap.cached_cast(low_prec_tensor, fn, utils.maybe_float,
+            wrap.cached_cast(torch.cuda.HalfTensor, fn, utils.maybe_float,
                              handle, try_caching=False, verbose=verbose)

     # 2) Enable type-promotion on multi-arg functions and methods.

@@ -136,17 +144,17 @@ def init(enabled=True, loss_scale="dynamic", patch_type=torch.float16, enable_ca
     # 2.5) Pre-0.4, add blacklist methods directly to HalfTensor and FloatTensor types
     if compat.tensor_is_float_tensor():
         for cls, (list_name, promote_fn) in itertools.product([torch.cuda.FloatTensor,
-                                                                low_prec_tensor],
+                                                                torch.cuda.HalfTensor],
                                                               promote_table):
             for fn in getattr(tensor_overrides, list_name):
                 promote_fn(cls, fn, handle, verbose)

-    # 3) For any in-place version of a blacklist function, error if any input is fp16.
+    # 3) For any in-place version of a blacklist function, error if any input is fp16/bfloat16.
     # NB: this is overly conservative.
     for fn in utils.as_inplace(torch_overrides.FP32_FUNCS):
         wrap.err_if_any_half(torch_overrides.MODULE, fn, handle)

-    # 3.5) For any in-place blacklist method, error if called on fp16 tensor
+    # 3.5) For any in-place blacklist method, error if called on fp16/bfloat16 tensor
     for fn in utils.as_inplace(tensor_overrides.FP32_FUNCS):
         wrap.err_if_arg0_half(tensor_overrides.MODULE, fn, handle, verbose)
     if compat.tensor_is_float_tensor():

@@ -158,7 +166,7 @@ def init(enabled=True, loss_scale="dynamic", patch_type=torch.float16, enable_ca
                      tensor_overrides.CASTS)):
         wrap.promote_match_arg0(tensor_overrides.MODULE, fn, handle, verbose)
         if compat.tensor_is_float_tensor():
-            wrap.promote_match_arg0(low_prec_tensor, fn, handle, verbose)
+            wrap.promote_match_arg0(torch.cuda.HalfTensor, fn, handle, verbose)
             wrap.promote_match_arg0(torch.cuda.FloatTensor, fn, handle, verbose)

     # 5) RNNs + RNN cells are whitelisted specially

@@ -169,10 +177,10 @@ def init(enabled=True, loss_scale="dynamic", patch_type=torch.float16, enable_ca
         torch.nn.modules.rnn._VF = rnn_compat.VariableFunctionsShim()
         # Wrap all the rnns
         for x in rnn_compat.RNN_NAMES:
-            wrap.new_rnn_cast(x.upper(), handle, verbose)
+            wrap.new_rnn_cast(x.upper(), maybe_low_prec, handle, verbose)

         # Wrap all the RNN cells
-        rnn_compat.whitelist_rnn_cells(handle, verbose)
+        rnn_compat.whitelist_rnn_cells(maybe_low_prec, handle, verbose)

     # 6) Place error+print message on banned functions.
     # Or, if allow_banned, then cast to FP32.
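
For context, a hedged usage sketch of the new decorator and registration helper, modeled on the existing half_function/register_half_function pattern; the function name fused_op below is illustrative and not part of this commit:

import torch
from apex import amp

# Decorator form: inputs are cast to bfloat16 while amp is active.
@amp.bfloat16_function
def fused_op(x, y):
    return torch.bmm(x, y)

# Registration form: patch an existing function on a module by name,
# before amp is initialized.
amp.register_bfloat16_function(torch, 'bmm')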
apex/amp/frontend.py  (+5, -2)

@@ -16,6 +16,9 @@ class Properties(object):
             "opt_level" : None,
             "cast_model_type" : None,
             "patch_torch_functions" : False,
+            # TODO: patch_torch_functions_type could probably be unified with
+            # patch_torch_functions. Currently introducing a new attribute
+            # to be on the safer side and not break stuff.
+            "patch_torch_functions_type" : None,
             "keep_batchnorm_fp32" : None,
             "master_weights" : None,

@@ -390,7 +393,7 @@ def initialize(
         maybe_print("Selected optimization level {}".format(opt_levels[opt_level].brief), True)
         maybe_print("Defaults for this optimization level are:", True)
         for k, v in _amp_state.opt_properties.options.items():
-            maybe_print("{:22} : {}".format(k, v), True)
+            maybe_print("{:26} : {}".format(k, v), True)

     _amp_state.min_loss_scale = min_loss_scale
     _amp_state.max_loss_scale = max_loss_scale

@@ -417,7 +420,7 @@ def initialize(
     maybe_print("After processing overrides, optimization options are:", True)
     for k, v in _amp_state.opt_properties.options.items():
-        maybe_print("{:22} : {}".format(k, v), True)
+        maybe_print("{:26} : {}".format(k, v), True)

     return _initialize(models, optimizers, _amp_state.opt_properties, num_losses, cast_model_outputs)
apex/amp/rnn_compat.py  (+2, -2)

@@ -28,7 +28,7 @@ def has_old_rnns():
     except:
         return False

-def whitelist_rnn_cells(handle, verbose):
+def whitelist_rnn_cells(cast_fn, handle, verbose):
     # Different module + function names in old/new RNN cases
     if has_old_rnns():
         fn_names = ['RNNReLUCell', 'RNNTanhCell', 'LSTMCell', 'GRUCell']

@@ -40,7 +40,7 @@ def whitelist_rnn_cells(handle, verbose):
     # Insert casts on cell functions
     for fn in fn_names:
-        wrap.cached_cast(mod, fn, utils.maybe_half, handle,
+        wrap.cached_cast(mod, fn, cast_fn, handle,
                          try_caching=True, verbose=verbose)

     if has_old_rnns():
apex/amp/utils.py  (+8, -2)

@@ -200,22 +200,28 @@ def synthesize_flattened_rnn_weights(fp32_weights,
         fp16_weights.append(fp16_layer_weights)
     return fp16_weights

+def _str_from_dtype(dtype=torch.float16):
+    type_to_str = {torch.float16: 'Half',
+                   torch.bfloat16: 'BFloat16'}
+    return type_to_str[dtype]
+
 # Roughly same as above, just the `fp32_weights` aren't nested.
 # Code kept separate for readability.
 def new_synthesize_flattened_rnn_weights(fp32_weights,
                                          fp16_flat_tensor,
                                          rnn_fn='',
+                                         dtype=torch.float16,
                                          verbose=False):
     fp16_weights = []
     fp32_base_ptr = fp32_weights[0].data_ptr()
     for w_fp32 in fp32_weights:
-        w_fp16 = w_fp32.new().half()
+        w_fp16 = w_fp32.new().to(dtype=dtype)
         offset = (w_fp32.data_ptr() - fp32_base_ptr) // w_fp32.element_size()
         w_fp16.set_(fp16_flat_tensor.storage(),
                     offset,
                     w_fp32.shape)
         w_fp16.copy_(w_fp32)
         if verbose:
-            print('Float->Half ({})'.format(rnn_fn))
+            print('Float->{} ({})'.format(_str_from_dtype(dtype), rnn_fn))
         fp16_weights.append(w_fp16)
     return fp16_weights
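
As a side note, the buffer-aliasing trick used above (allocate one flat low-precision buffer, then point per-weight views into its storage) can be shown in isolation. The sketch below is a toy reconstruction, assuming CPU tensors and offsets computed from element counts rather than raw data pointers; it is not part of the commit:

import torch

dtype = torch.bfloat16
w0 = torch.randn(4, 3)   # stand-ins for fp32 RNN weights
w1 = torch.randn(4)

flat = torch.empty(w0.numel() + w1.numel(), dtype=dtype)  # shared flat buffer

views, offset = [], 0
for w in (w0, w1):
    v = w.new().to(dtype=dtype)              # empty tensor of the target dtype
    v.set_(flat.storage(), offset, w.shape)  # alias a slice of the flat buffer
    v.copy_(w)                               # downcast-copy the fp32 values in
    views.append(v)
    offset += w.numel()
# `views` now mirror w0/w1 in bfloat16 while sharing storage with `flat`.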
apex/amp/wrap.py  (+18, -10)

@@ -51,7 +51,8 @@ def make_promote_wrapper(orig_fn, cast_fn, handle=None):
         if len(types) <= 1:
             return orig_fn(*args, **kwargs)
-        elif len(types) == 2 and types == set(['HalfTensor', 'FloatTensor']):
+        elif len(types) == 2 and (types == set(['HalfTensor', 'FloatTensor']) or
+                                  types == set(['BFloat16Tensor', 'FloatTensor'])):
             new_args = utils.casted_args(cast_fn, args, kwargs)

@@ -79,7 +80,8 @@ def sequence_promote(mod, fn, handle, verbose=False):
         types = set([utils.type_string(x) for x in seq])
         if len(types) <= 1:
             return orig_fn(seq, *args, **kwargs)
-        elif types == set(['HalfTensor', 'FloatTensor']):
+        elif (types == set(['HalfTensor', 'FloatTensor']) or
+              types == set(['BFloat16Tensor', 'FloatTensor'])):
             cast_seq = utils.casted_args(maybe_float, seq, {})
             return orig_fn(cast_seq, *args, **kwargs)

@@ -121,12 +123,12 @@ def err_if_any_half(mod, fn, handle, custom_err_msg=None):
     @functools.wraps(orig_fn)
     def wrapper(*args, **kwargs):
         types = utils.collect_fp_tensor_types(args, kwargs)

-        if 'HalfTensor' in types:
+        if 'HalfTensor' in types or 'BFloat16Tensor' in types:
             if custom_err_msg:
                 raise NotImplementedError(custom_err_msg)
             else:
                 raise NotImplementedError('Cannot call in-place function ' +
-                                          '{} with fp16 arguments.'.format(fn))
+                                          '{} with fp16 or bfloat16 args.'.format(fn))
         else:
             return orig_fn(*args, **kwargs)
     utils.set_func_save(handle, mod, fn, wrapper)

@@ -139,9 +141,9 @@ def err_if_arg0_half(mod, fn, handle, verbose=False):
     @functools.wraps(orig_fn)
     def wrapper(arg0, *args, **kwargs):
         assert compat.is_tensor_like(arg0)
-        if utils.type_string(arg0) == 'HalfTensor':
+        if utils.type_string(arg0) in {'HalfTensor', 'BFloat16Tensor'}:
             raise NotImplementedError('Cannot call in-place method ' +
-                                      '{} on fp16 Tensors.'.format(fn))
+                                      '{} with fp16 or bfloat16 args.'.format(fn))
         else:
             cast_fn = utils.verbosify(utils.maybe_float, fn, verbose)
             new_args = utils.casted_args(cast_fn, args, kwargs)

@@ -221,7 +223,7 @@ def rnn_cast(backend, fn, handle, verbose=False):
         return fwd_wrapper
     utils.set_func_save(handle, backend, fn, rnn_wrapper)

-def new_rnn_cast(fn, handle, verbose=False):
+def new_rnn_cast(fn, cast_fn, handle, verbose=False):
     # Forward+backward compatibility around https://github.com/pytorch/pytorch/pull/15744
     # For rnn backend calls that route through _rnn_impls, we must patch the ref
     # that _rnn_impls stashed. For rnn backend calls that directly invoke

@@ -234,7 +236,7 @@ def new_rnn_cast(fn, handle, verbose=False):
         assert isinstance(mod, rnn_compat.VariableFunctionsShim)
         fn = fn.lower()
     orig_fn = utils.get_func(mod, fn)
-    cast_fn = utils.verbosify(utils.maybe_half, fn, verbose)
+    cast_fn = utils.verbosify(cast_fn, fn, verbose)
     @functools.wraps(orig_fn)
     def wrapper(*args, **kwargs):
         # Exact call signature from modules/rnn.py

@@ -249,14 +251,20 @@ def new_rnn_cast(fn, handle, verbose=False):
         else:
             params_idx = 3 # PackedSequence case

+        if cast_fn == utils.maybe_half:
+            dtype = torch.half
+        elif cast_fn == utils.maybe_bfloat16:
+            dtype = torch.bfloat16
+        else:
+            raise RuntimeError("Unsupported cast_fn passed. Supports only maybe_half and maybe_bfloat16")
         new_args = []
         for i, arg in enumerate(args):
             if i == params_idx:
                 num_params = sum([x.numel() for x in arg])
                 fp16_weight_buf = args[0].new_empty((num_params,),
-                                                    dtype=torch.half)
+                                                    dtype=dtype)
                 casted_weights = utils.new_synthesize_flattened_rnn_weights(
-                    arg, fp16_weight_buf, fn, verbose)
+                    arg, fp16_weight_buf, fn, dtype, verbose)
                 new_args.append(casted_weights)
             elif utils.is_fp_tensor(arg):
                 new_args.append(cast_fn(arg))
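
The two promotion wrappers above now treat a {BFloat16Tensor, FloatTensor} mix exactly like the existing {HalfTensor, FloatTensor} mix. A standalone sketch of that widened rule (illustrative only; the real check runs inside the amp-generated wrappers):

# Illustrative mirror of the two-type promotion check used by make_promote_wrapper
# and sequence_promote.
def should_promote(type_strings):
    types = set(type_strings)
    if len(types) <= 1:
        return False  # nothing to promote
    return (types == {'HalfTensor', 'FloatTensor'} or
            types == {'BFloat16Tensor', 'FloatTensor'})

assert should_promote(['BFloat16Tensor', 'FloatTensor'])
assert not should_promote(['HalfTensor', 'BFloat16Tensor'])  # still an unsupported mix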