OpenDAS / apex — commit de3f3fea (parent 6e14df49)
Authored May 08, 2020 by rohithkrn

    add bfloat16 register functions, enable rnn functions, enable promote functions

Showing 8 changed files with 55 additions and 30 deletions (+55 -30):
  apex/amp/__init__.py            +2  -2
  apex/amp/_initialize.py         +1  -1
  apex/amp/_process_optimizer.py  +2  -2
  apex/amp/amp.py                 +17 -9
  apex/amp/frontend.py            +5  -2
  apex/amp/rnn_compat.py          +2  -2
  apex/amp/utils.py               +8  -2
  apex/amp/wrap.py                +18 -10
apex/amp/__init__.py

-from .amp import init, half_function, float_function, promote_function, \
-    register_half_function, register_float_function, register_promote_function
+from .amp import init, half_function, bfloat16_function, float_function, promote_function, \
+    register_half_function, register_bfloat16_function, register_float_function, register_promote_function
 from .handle import scale_loss, disable_casts
 from .frontend import initialize, state_dict, load_state_dict
 from ._amp_state import master_params, _amp_state
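For orientation, a minimal usage sketch of the newly exported names, assuming apex is installed and that registration happens before amp initialization; the choice of torch.nn.functional.gelu as the target is purely illustrative, not something this commit prescribes:

    import torch
    import torch.nn.functional
    from apex import amp

    # Ask amp to cast this function's floating-point inputs to bfloat16
    # (target function chosen only for illustration).
    amp.register_bfloat16_function(torch.nn.functional, 'gelu')

    # Decorator form for a user-defined function.
    @amp.bfloat16_function
    def scaled_matmul(a, b):
        return torch.matmul(a, b) / a.size(-1) ** 0.5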
apex/amp/_initialize.py

@@ -189,7 +189,7 @@ def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs
     for model in models:
         # Patch the forward method to cast incoming data to the correct type, and
-        # outgoing data to float32, so "the user never needs to call .half()."
+        # outgoing data to float32, so "the user never needs to call .half()/.bfloat16()."
         # I like writing things explicitly more than decorators.
         def patch_forward(old_fwd):
             def new_fwd(*args, **kwargs):
apex/amp/_process_optimizer.py

@@ -213,8 +213,8 @@ def lazy_init_no_master_weights(self):
             elif param.type() == 'torch.cuda.FloatTensor':
                 stash.all_fp32_params.append(param)
             else:
-                raise TypeError("Optimizer's parameters must be either "
-                                "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
+                raise TypeError("Optimizer's parameters must be one of "
+                                "torch.cuda.FloatTensor, torch.cuda.HalfTensor, torch.BFloat16Tensor. "
                                 "Received {}".format(param.type()))

     stash.all_fp16_grad_stash = [None for _ in stash.all_fp16_params]
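The branch above dispatches on the string returned by Tensor.type(). A quick standalone check of those strings (requires a CUDA device; shown only to illustrate the comparison, and the exact bfloat16 string depends on the PyTorch build):

    import torch

    if torch.cuda.is_available():
        for dtype in (torch.float32, torch.float16, torch.bfloat16):
            p = torch.zeros(2, dtype=dtype, device='cuda')
            # e.g. torch.float16 -> 'torch.cuda.HalfTensor'
            print(dtype, '->', p.type())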
apex/amp/amp.py

@@ -30,6 +30,9 @@ def half_function(fn):
     wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=True)
     return _decorator_helper(fn, utils.maybe_half, wrap_fn)

+def bfloat16_function(fn):
+    wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=True)
+    return _decorator_helper(fn, utils.maybe_bfloat16, wrap_fn)

 def float_function(fn):
     wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=False)

@@ -48,6 +51,11 @@ def register_half_function(module, name):
                          name, module))
     _USER_CAST_REGISTRY.add((module, name, utils.maybe_half))

+def register_bfloat16_function(module, name):
+    if not hasattr(module, name):
+        raise ValueError('No function named {} in module {}.'.format(
+            name, module))
+    _USER_CAST_REGISTRY.add((module, name, utils.maybe_bfloat16))

 def register_float_function(module, name):
     if not hasattr(module, name):

@@ -116,11 +124,11 @@ def init(enabled=True, loss_scale="dynamic", patch_type=torch.float16, enable_ca
     # 1.5) Pre-0.4, put the blacklist methods on HalfTensor and whitelist
     # methods on FloatTensor, since they're distinct types.
     if compat.tensor_is_float_tensor():
-        for fn in getattr(tensor_overrides, low_prec_funcs):
-            wrap.cached_cast(torch.cuda.FloatTensor, fn, utils.maybe_low_prec,
+        for fn in getattr(tensor_overrides, 'FP16_FUNCS'):
+            wrap.cached_cast(torch.cuda.FloatTensor, fn, utils.maybe_half,
                              handle, try_caching=True, verbose=verbose)
         for fn in tensor_overrides.FP32_FUNCS:
-            wrap.cached_cast(low_prec_tensor, fn, utils.maybe_float,
+            wrap.cached_cast(torch.cuda.HalfTensor, fn, utils.maybe_float,
                              handle, try_caching=False, verbose=verbose)

     # 2) Enable type-promotion on multi-arg functions and methods.

@@ -136,17 +144,17 @@ def init(enabled=True, loss_scale="dynamic", patch_type=torch.float16, enable_ca
     # 2.5) Pre-0.4, add blacklist methods directly to HalfTensor and FloatTensor types
     if compat.tensor_is_float_tensor():
         for cls, (list_name, promote_fn) in itertools.product([torch.cuda.FloatTensor,
-                                                                low_prec_tensor],
+                                                                torch.cuda.HalfTensor],
                                                                promote_table):
             for fn in getattr(tensor_overrides, list_name):
                 promote_fn(cls, fn, handle, verbose)

-    # 3) For any in-place version of a blacklist function, error if any input is fp16.
+    # 3) For any in-place version of a blacklist function, error if any input is fp16/bfloat16.
     # NB: this is overly conservative.
     for fn in utils.as_inplace(torch_overrides.FP32_FUNCS):
         wrap.err_if_any_half(torch_overrides.MODULE, fn, handle)

-    # 3.5) For any in-place blacklist method, error if called on fp16 tensor
+    # 3.5) For any in-place blacklist method, error if called on fp16/bfloat16 tensor
     for fn in utils.as_inplace(tensor_overrides.FP32_FUNCS):
         wrap.err_if_arg0_half(tensor_overrides.MODULE, fn, handle, verbose)
         if compat.tensor_is_float_tensor():

@@ -158,7 +166,7 @@ def init(enabled=True, loss_scale="dynamic", patch_type=torch.float16, enable_ca
                            tensor_overrides.CASTS)):
         wrap.promote_match_arg0(tensor_overrides.MODULE, fn, handle, verbose)
         if compat.tensor_is_float_tensor():
-            wrap.promote_match_arg0(low_prec_tensor, fn, handle, verbose)
+            wrap.promote_match_arg0(torch.cuda.HalfTensor, fn, handle, verbose)
             wrap.promote_match_arg0(torch.cuda.FloatTensor, fn, handle, verbose)

     # 5) RNNs + RNN cells are whitelisted specially

@@ -169,10 +177,10 @@ def init(enabled=True, loss_scale="dynamic", patch_type=torch.float16, enable_ca
         torch.nn.modules.rnn._VF = rnn_compat.VariableFunctionsShim()
         # Wrap all the rnns
         for x in rnn_compat.RNN_NAMES:
-            wrap.new_rnn_cast(x.upper(), handle, verbose)
+            wrap.new_rnn_cast(x.upper(), maybe_low_prec, handle, verbose)

         # Wrap all the RNN cells
-        rnn_compat.whitelist_rnn_cells(handle, verbose)
+        rnn_compat.whitelist_rnn_cells(maybe_low_prec, handle, verbose)

     # 6) Place error+print message on banned functions.
     # Or, if allow_banned, then cast to FP32.
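The register_* additions above only append entries to _USER_CAST_REGISTRY; the actual casting is installed when init() walks the registry and monkey-patches each registered function. A simplified, self-contained sketch of that pattern (names such as _REGISTRY and apply_registry are illustrative, not apex's internals):

    import functools
    import torch

    # Illustrative registry of (module, function name, target dtype) entries.
    _REGISTRY = set()

    def register_bfloat16(module, name):
        _REGISTRY.add((module, name, torch.bfloat16))

    def _wrap_with_cast(orig_fn, dtype):
        @functools.wraps(orig_fn)
        def wrapper(*args, **kwargs):
            # Cast floating-point tensor args to the requested low-precision dtype.
            args = [a.to(dtype) if torch.is_tensor(a) and a.is_floating_point() else a
                    for a in args]
            return orig_fn(*args, **kwargs)
        return wrapper

    def apply_registry():
        # Roughly what amp.init does with _USER_CAST_REGISTRY at patch time.
        for module, name, dtype in _REGISTRY:
            setattr(module, name, _wrap_with_cast(getattr(module, name), dtype))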
apex/amp/frontend.py

@@ -16,6 +16,9 @@ class Properties(object):
             "opt_level" : None,
             "cast_model_type" : None,
             "patch_torch_functions" : False,
+            # TODO: patch_torch_functions_type could probably be unified with
+            # patch_torch_functions. Currently introducing a new attribute
+            # to be on the safer side and not break stuff.
             "patch_torch_functions_type" : None,
             "keep_batchnorm_fp32" : None,
             "master_weights" : None,

@@ -390,7 +393,7 @@ def initialize(
         maybe_print("Selected optimization level {}".format(opt_levels[opt_level].brief), True)
         maybe_print("Defaults for this optimization level are:", True)
         for k, v in _amp_state.opt_properties.options.items():
-            maybe_print("{:22} : {}".format(k, v), True)
+            maybe_print("{:26} : {}".format(k, v), True)

     _amp_state.min_loss_scale = min_loss_scale
     _amp_state.max_loss_scale = max_loss_scale

@@ -417,7 +420,7 @@ def initialize(
     maybe_print("After processing overrides, optimization options are:", True)
     for k, v in _amp_state.opt_properties.options.items():
-        maybe_print("{:22} : {}".format(k, v), True)
+        maybe_print("{:26} : {}".format(k, v), True)

     return _initialize(models, optimizers, _amp_state.opt_properties, num_losses, cast_model_outputs)
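The field width bump from {:22} to {:26} simply accommodates the new option key, which is exactly 26 characters long:

    key = "patch_torch_functions_type"
    print(len(key))                         # 26
    print("{:26} : {}".format(key, None))   # patch_torch_functions_type : None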
apex/amp/rnn_compat.py

@@ -28,7 +28,7 @@ def has_old_rnns():
     except:
         return False

-def whitelist_rnn_cells(handle, verbose):
+def whitelist_rnn_cells(cast_fn, handle, verbose):
     # Different module + function names in old/new RNN cases
     if has_old_rnns():
         fn_names = ['RNNReLUCell', 'RNNTanhCell', 'LSTMCell', 'GRUCell']

@@ -40,7 +40,7 @@ def whitelist_rnn_cells(handle, verbose):
     # Insert casts on cell functions
     for fn in fn_names:
-        wrap.cached_cast(mod, fn, utils.maybe_half, handle,
+        wrap.cached_cast(mod, fn, cast_fn, handle,
                          try_caching=True, verbose=verbose)

     if has_old_rnns():
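whitelist_rnn_cells now receives the cast function instead of hard-coding utils.maybe_half, so one wrapping path serves both fp16 and bfloat16. Standalone stand-ins for the two cast functions, simplified from what apex's utils actually do (the real helpers also handle verbosity and device checks):

    import torch

    def maybe_half(x):
        # Simplified stand-in: cast only floating-point tensors.
        return x.half() if torch.is_tensor(x) and x.is_floating_point() else x

    def maybe_bfloat16(x):
        return x.bfloat16() if torch.is_tensor(x) and x.is_floating_point() else x

    w = torch.randn(4, 4)
    print(maybe_half(w).dtype, maybe_bfloat16(w).dtype)  # torch.float16 torch.bfloat16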
apex/amp/utils.py

@@ -200,22 +200,28 @@ def synthesize_flattened_rnn_weights(fp32_weights,
         fp16_weights.append(fp16_layer_weights)
     return fp16_weights

+def _str_from_dtype(dtype=torch.float16):
+    type_to_str = {torch.float16: 'Half', torch.bfloat16: 'BFloat16'}
+    return type_to_str[dtype]

 # Roughly same as above, just the `fp32_weights` aren't nested.
 # Code kept separate for readability.
 def new_synthesize_flattened_rnn_weights(fp32_weights,
                                          fp16_flat_tensor,
                                          rnn_fn='',
+                                         dtype=torch.float16,
                                          verbose=False):
     fp16_weights = []
     fp32_base_ptr = fp32_weights[0].data_ptr()
     for w_fp32 in fp32_weights:
-        w_fp16 = w_fp32.new().half()
+        w_fp16 = w_fp32.new().to(dtype=dtype)
         offset = (w_fp32.data_ptr() - fp32_base_ptr) // w_fp32.element_size()
         w_fp16.set_(fp16_flat_tensor.storage(),
                     offset,
                     w_fp32.shape)
         w_fp16.copy_(w_fp32)
         if verbose:
-            print('Float->Half ({})'.format(rnn_fn))
+            print('Float->{} ({})'.format(_str_from_dtype(dtype), rnn_fn))
         fp16_weights.append(w_fp16)
     return fp16_weights
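The new dtype parameter generalizes the flat-buffer trick: each low-precision weight becomes a view into one contiguous buffer of the requested dtype. A standalone illustration of the same idea (simplified: the real function derives offsets from data_ptr() because the fp32 weights already share a flat buffer; sequential offsets are used here instead):

    import torch

    fp32_weights = [torch.randn(3, 4), torch.randn(4)]
    dtype = torch.bfloat16          # could equally be torch.float16

    num_params = sum(w.numel() for w in fp32_weights)
    flat = fp32_weights[0].new_empty((num_params,), dtype=dtype)

    views, offset = [], 0
    for w in fp32_weights:
        v = flat[offset:offset + w.numel()].view(w.shape)
        v.copy_(w)                  # copy_ casts fp32 -> bfloat16 on the fly
        views.append(v)
        offset += w.numel()

    print(flat.dtype, [tuple(v.shape) for v in views])  # torch.bfloat16 [(3, 4), (4,)]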
apex/amp/wrap.py

@@ -51,7 +51,8 @@ def make_promote_wrapper(orig_fn, cast_fn, handle=None):
         if len(types) <= 1:
             return orig_fn(*args, **kwargs)
-        elif len(types) == 2 and types == set(['HalfTensor', 'FloatTensor']):
+        elif len(types) == 2 and (types == set(['HalfTensor', 'FloatTensor']) or
+                                  types == set(['BFloat16Tensor', 'FloatTensor'])):
             new_args = utils.casted_args(cast_fn,
                                          args,
                                          kwargs)

@@ -79,7 +80,8 @@ def sequence_promote(mod, fn, handle, verbose=False):
         types = set([utils.type_string(x) for x in seq])
         if len(types) <= 1:
             return orig_fn(seq, *args, **kwargs)
-        elif types == set(['HalfTensor', 'FloatTensor']):
+        elif (types == set(['HalfTensor', 'FloatTensor']) or
+              types == set(['BFloat16Tensor', 'FloatTensor'])):
             cast_seq = utils.casted_args(maybe_float,
                                          seq, {})
             return orig_fn(cast_seq, *args, **kwargs)

@@ -121,12 +123,12 @@ def err_if_any_half(mod, fn, handle, custom_err_msg=None):
     @functools.wraps(orig_fn)
     def wrapper(*args, **kwargs):
         types = utils.collect_fp_tensor_types(args, kwargs)

-        if 'HalfTensor' in types:
+        if 'HalfTensor' in types or 'BFloat16Tensor' in types:
             if custom_err_msg:
                 raise NotImplementedError(custom_err_msg)
             else:
                 raise NotImplementedError('Cannot call in-place function ' +
-                                          '{} with fp16 arguments.'.format(fn))
+                                          '{} with fp16 or bfloat16 args.'.format(fn))
         else:
             return orig_fn(*args, **kwargs)
     utils.set_func_save(handle, mod, fn, wrapper)

@@ -139,9 +141,9 @@ def err_if_arg0_half(mod, fn, handle, verbose=False):
     @functools.wraps(orig_fn)
     def wrapper(arg0, *args, **kwargs):
         assert compat.is_tensor_like(arg0)
-        if utils.type_string(arg0) == 'HalfTensor':
+        if utils.type_string(arg0) in {'HalfTensor', 'BFloat16Tensor'}:
             raise NotImplementedError('Cannot call in-place method ' +
-                                      '{} on fp16 Tensors.'.format(fn))
+                                      '{} with fp16 or bfloat16 args.'.format(fn))
         else:
             cast_fn = utils.verbosify(utils.maybe_float, fn, verbose)
             new_args = utils.casted_args(cast_fn, args, kwargs)

@@ -221,7 +223,7 @@ def rnn_cast(backend, fn, handle, verbose=False):
             return fwd_wrapper
     utils.set_func_save(handle, backend, fn, rnn_wrapper)

-def new_rnn_cast(fn, handle, verbose=False):
+def new_rnn_cast(fn, cast_fn, handle, verbose=False):
     # Forward+backward compatibility around https://github.com/pytorch/pytorch/pull/15744
     # For rnn backend calls that route through _rnn_impls, we must patch the ref
     # that _rnn_impls stashed. For rnn backend calls that directly invoke

@@ -234,7 +236,7 @@ def new_rnn_cast(fn, handle, verbose=False):
         assert isinstance(mod, rnn_compat.VariableFunctionsShim)
         fn = fn.lower()
     orig_fn = utils.get_func(mod, fn)
-    cast_fn = utils.verbosify(utils.maybe_half, fn, verbose)
+    cast_fn = utils.verbosify(cast_fn, fn, verbose)
     @functools.wraps(orig_fn)
     def wrapper(*args, **kwargs):
         # Exact call signature from modules/rnn.py

@@ -249,14 +251,20 @@ def new_rnn_cast(fn, handle, verbose=False):
         else:
             params_idx = 3 # PackedSequence case

+        if cast_fn == utils.maybe_half:
+            dtype = torch.half
+        elif cast_fn == utils.maybe_bfloat16:
+            dtype = torch.bfloat16
+        else:
+            raise RuntimeError("Unsupported cast_fn passed. Supports only maybe_half and maybe_bfloat16")

         new_args = []
         for i, arg in enumerate(args):
             if i == params_idx:
                 num_params = sum([x.numel() for x in arg])
                 fp16_weight_buf = args[0].new_empty((num_params,),
-                                                    dtype=torch.half)
+                                                    dtype=dtype)
                 casted_weights = utils.new_synthesize_flattened_rnn_weights(
-                    arg, fp16_weight_buf, fn, verbose)
+                    arg, fp16_weight_buf, fn, dtype, verbose)
                 new_args.append(casted_weights)
             elif utils.is_fp_tensor(arg):
                 new_args.append(cast_fn(arg))
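Taken together, the wrap.py changes extend every fp16/fp32 decision point so that a bfloat16/fp32 mix is handled the same way. A standalone sketch of the widened promotion rule only (apex's real wrappers also handle kwargs, caching, handles, and the error paths shown above):

    import torch

    def type_string(x):
        # e.g. 'FloatTensor', 'HalfTensor', 'BFloat16Tensor'
        return x.type().rsplit('.', 1)[-1]

    def promote_call(fn, *tensors):
        types = set(type_string(t) for t in tensors)
        if types in ({'HalfTensor', 'FloatTensor'}, {'BFloat16Tensor', 'FloatTensor'}):
            # Mixed low-precision / fp32 inputs: promote everything to fp32.
            tensors = [t.float() for t in tensors]
        return fn(*tensors)

    out = promote_call(torch.add, torch.randn(3), torch.ones(3, dtype=torch.bfloat16))
    print(out.dtype)  # torch.float32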