OpenDAS / apex · Commits · b2da92fc

Unverified commit b2da92fc, authored May 19, 2020 by Peng, committed by GitHub on May 19, 2020.

Merge pull request #5 from rohithkrn/apex_amp_bfp16

Introduce new optimization levels for BFloat16 training

Parents: 65490af6, e1267a9a

Changes: 30 files in total; this page shows 20 changed files with 230 additions and 76 deletions (+230 −76).
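As a quick orientation to what the new levels do, here is a minimal usage sketch (not part of the commit; it assumes an apex build containing this change and a GPU with bfloat16 support):

    import torch
    from apex import amp

    model = torch.nn.Linear(1024, 1024).cuda()   # weights stay FP32 under O4
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    # O4 patches Torch functions so Tensor Core-friendly ops run in bfloat16;
    # loss_scale defaults to 1 because bfloat16 shares FP32's dynamic range.
    model, optimizer = amp.initialize(model, optimizer, opt_level="O4")

    loss = model(torch.randn(32, 1024, device="cuda")).sum()
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()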
Changed files (this page):

  apex/amp/__init__.py                      +2   −2
  apex/amp/_initialize.py                   +8   −6
  apex/amp/_process_optimizer.py            +13  −13
  apex/amp/amp.py                           +30  −9
  apex/amp/compat.py                        +2   −1
  apex/amp/frontend.py                      +80  −13
  apex/amp/lists/functional_overrides.py    +11  −0
  apex/amp/lists/tensor_overrides.py        +5   −1
  apex/amp/lists/torch_overrides.py         +21  −0
  apex/amp/rnn_compat.py                    +2   −2
  apex/amp/utils.py                         +19  −2
  apex/amp/wrap.py                          +20  −10
  csrc/multi_tensor_adam.cu                 +1   −1
  csrc/multi_tensor_axpby_kernel.cu         +3   −3
  csrc/multi_tensor_l2norm_kernel.cu        +3   −3
  csrc/multi_tensor_lamb.cu                 +2   −2
  csrc/multi_tensor_lamb_stage_1.cu         +3   −3
  csrc/multi_tensor_lamb_stage_2.cu         +2   −2
  csrc/multi_tensor_novograd.cu             +1   −1
  csrc/multi_tensor_scale_kernel.cu         +2   −2
apex/amp/__init__.py

-from .amp import init, half_function, float_function, promote_function,\
-    register_half_function, register_float_function, register_promote_function
+from .amp import init, half_function, bfloat16_function, float_function, promote_function,\
+    register_half_function, register_bfloat16_function, register_float_function, register_promote_function
 from .handle import scale_loss, disable_casts
 from .frontend import initialize, state_dict, load_state_dict
 from ._amp_state import master_params, _amp_state
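The two new exports, bfloat16_function and register_bfloat16_function, mirror the existing half-precision helpers. A hedged sketch of how they might be used (the function names come from this diff; the surrounding example is illustrative only):

    import torch
    from apex import amp

    # Route an existing library function through the bfloat16 cast wrapper,
    # exactly like register_half_function does for FP16. Registration must
    # happen before amp.init()/amp.initialize().
    amp.register_bfloat16_function(torch, 'bmm')

    # Or decorate your own function so its floating-point inputs are cast
    # to bfloat16 before the call.
    @amp.bfloat16_function
    def fused_projection(x, weight):
        return torch.matmul(x, weight.t())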
apex/amp/_initialize.py

@@ -80,10 +80,10 @@ def check_params_fp32(models):
     for model in models:
         for name, param in model.named_parameters():
             if param.is_floating_point():
-                if 'Half' in param.type():
+                if 'Half' in param.type() or 'BFloat16' in param.type():
                     warn_or_err("Found param {} with type {}, expected torch.cuda.FloatTensor.\n"
-                                "When using amp.initialize, you do not need to call .half() on your model\n"
-                                "before passing it, no matter what optimization level you choose.".format(
+                                "When using amp.initialize, you do not need to call .half() or .bfloat16()\n"
+                                "on your model before passing it, no matter what optimization level you choose.".format(
                                     name, param.type()))
             elif not param.is_cuda:
                 warn_or_err("Found param {} with type {}, expected torch.cuda.FloatTensor.\n"

@@ -137,7 +137,7 @@ class O2StateDictHook(object):
     def __call__(self, module, state_dict, prefix, local_metadata):
         for key in state_dict:
             param = state_dict[key]
-            if 'Half' in param.type():
+            if 'Half' in param.type() or 'BFloat16' in param.type():
                 param = param.to(torch.float32)
                 state_dict[key] = param

@@ -189,7 +189,7 @@ def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs
     for model in models:
         # Patch the forward method to cast incoming data to the correct type, and
-        # outgoing data to float32, so "the user never needs to call .half()."
+        # outgoing data to float32, so "the user never needs to call .half()/.bfloat16()."
         # I like writing things explicitly more than decorators.
         def patch_forward(old_fwd):
             def new_fwd(*args, **kwargs):

@@ -232,7 +232,9 @@ def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs
     if properties.patch_torch_functions:
         # handle is unused here. It's accessible later through a global value anyway.
-        handle = amp_init(loss_scale=properties.loss_scale, verbose=(_amp_state.verbosity == 2))
+        handle = amp_init(loss_scale=properties.loss_scale,
+                          patch_type=properties.patch_torch_functions_type,
+                          verbose=(_amp_state.verbosity == 2))
         for optimizer in optimizers:
             # Disable Amp casting for the optimizer step, because it should only be
             # applied to FP32 master params anyway.
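The net effect of these checks is that models should still be handed to amp.initialize as FP32 CUDA modules; pre-converting with .half() or .bfloat16() now trips the same warning for both dtypes. A minimal sketch of the intended flow (illustrative; MyModel is a placeholder):

    import torch
    from apex import amp

    model = MyModel().cuda()          # leave parameters in torch.float32
    optimizer = torch.optim.Adam(model.parameters())

    # amp.initialize performs the dtype conversion itself; calling
    # model.bfloat16() beforehand would now be flagged by check_params_fp32.
    model, optimizer = amp.initialize(model, optimizer, opt_level="O5")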
apex/amp/_process_optimizer.py

 import types
 from ..fp16_utils import master_params_to_model_params
 from ..multi_tensor_apply import multi_tensor_applier
-from ._amp_state import maybe_print
+from ._amp_state import maybe_print, _amp_state
 import torch
 from ..optimizers import FusedSGD

@@ -37,7 +37,7 @@ def lazy_init_with_master_weights(self):
         fp32_from_fp16_params_this_group = []
         for i, param in enumerate(param_group['params']):
             if param.requires_grad:
-                if param.type() == 'torch.cuda.HalfTensor':
+                if param.type() in {'torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor'}:
                     # maybe_print("FP16_Optimizer received torch.cuda.HalfTensor with {}"
                     #             .format(param.size()))
                     fp16_params_this_group.append(param)

@@ -55,8 +55,8 @@ def lazy_init_with_master_weights(self):
                     fp32_params_this_group.append(param)
                     param_group['params'][i] = param
                 else:
-                    raise TypeError("Optimizer's parameters must be either "
-                                    "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
+                    raise TypeError("Optimizer's parameters must be one of "
+                                    "torch.cuda.FloatTensor, torch.cuda.HalfTensor, torch.cuda.BFloat16Tensor. "
                                     "Received {}".format(param.type()))
         stash.fp16_groups.append(fp16_params_this_group)

@@ -208,13 +208,13 @@ def lazy_init_no_master_weights(self):
     stash.all_fp32_params = []
     for i, param_group in enumerate(self.param_groups):
         for i, param in enumerate(param_group['params']):
-            if param.type() == 'torch.cuda.HalfTensor':
+            if param.type() in {'torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor'}:
                 stash.all_fp16_params.append(param)
             elif param.type() == 'torch.cuda.FloatTensor':
                 stash.all_fp32_params.append(param)
             else:
-                raise TypeError("Optimizer's parameters must be either "
-                                "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
+                raise TypeError("Optimizer's parameters must be one of "
+                                "torch.cuda.FloatTensor, torch.cuda.HalfTensor, torch.cuda.BFloat16Tensor. "
                                 "Received {}".format(param.type()))

     stash.all_fp16_grad_stash = [None for _ in stash.all_fp16_params]

@@ -435,7 +435,7 @@ def _process_optimizer(optimizer, properties):
             fp32_from_fp16_params_this_group = []
             for i, param in enumerate(new_group['params']):
                 if param.requires_grad:
-                    if param.type() == 'torch.cuda.HalfTensor':
+                    if param.type() in {'torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor'}:
                         fp16_params_this_group.append(param)
                         master_param = param.detach().clone().float()
                         master_param.requires_grad = True

@@ -445,8 +445,8 @@ def _process_optimizer(optimizer, properties):
                         fp32_params_this_group.append(param)
                         new_group['params'][i] = param
                     else:
-                        raise TypeError("Optimizer's parameters must be either "
-                                        "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
+                        raise TypeError("Optimizer's parameters must be one of "
+                                        "torch.cuda.FloatTensor, torch.cuda.HalfTensor, torch.cuda.BFloat16Tensor. "
                                         "Received {}".format(param.type()))
             stash.fp16_groups.append(fp16_params_this_group)

@@ -471,15 +471,15 @@ def _process_optimizer(optimizer, properties):
                     # param.grad = None
         else:
             for param in new_group['params']:
-                if param.type() == 'torch.cuda.HalfTensor':
+                if param.type() in {'torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor'}:
                     stash.all_fp16_params.append(param)
                     stash.all_fp16_grad_stash.append(None)
                 elif param.type() == 'torch.cuda.FloatTensor':
                     stash.all_fp32_params.append(param)
                     stash.all_fp32_grad_stash.append(None)
                 else:
-                    raise TypeError("Optimizer's parameters must be either "
-                                    "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
+                    raise TypeError("Optimizer's parameters must be one of "
+                                    "torch.cuda.FloatTensor, torch.cuda.HalfTensor, torch.cuda.BFloat16Tensor. "
                                     "Received {}".format(param.type()))
         old_add_param_group(new_group)
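Throughout this module, torch.cuda.BFloat16Tensor parameters are now sorted into the same bucket as HalfTensor ones, so each low-precision parameter gets an FP32 master copy that the optimizer actually updates. A stripped-down sketch of that master-weight pattern (illustrative; not the module's actual code path):

    import torch

    def build_master_params(params):
        # Mirrors the pattern in _process_optimizer: every fp16/bf16 parameter
        # is shadowed by a trainable FP32 master copy; FP32 params pass through.
        low_prec, masters, fp32 = [], [], []
        for p in params:
            if p.type() in {'torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor'}:
                master = p.detach().clone().float()
                master.requires_grad = True
                low_prec.append(p)
                masters.append(master)
            elif p.type() == 'torch.cuda.FloatTensor':
                fp32.append(p)
            else:
                raise TypeError("Unsupported parameter type {}".format(p.type()))
        return low_prec, masters, fp32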
apex/amp/amp.py

@@ -9,7 +9,6 @@ import itertools
 import torch

 _DECORATOR_HANDLE = None
 _USER_CAST_REGISTRY = set()
 _USER_PROMOTE_REGISTRY = set()

@@ -31,6 +30,9 @@ def half_function(fn):
     wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=True)
     return _decorator_helper(fn, utils.maybe_half, wrap_fn)

+def bfloat16_function(fn):
+    wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=True)
+    return _decorator_helper(fn, utils.maybe_bfloat16, wrap_fn)

 def float_function(fn):
     wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=False)

@@ -49,6 +51,11 @@ def register_half_function(module, name):
                          name, module))
     _USER_CAST_REGISTRY.add((module, name, utils.maybe_half))

+def register_bfloat16_function(module, name):
+    if not hasattr(module, name):
+        raise ValueError('No function named {} in module {}.'.format(
+            name, module))
+    _USER_CAST_REGISTRY.add((module, name, utils.maybe_bfloat16))

 def register_float_function(module, name):
     if not hasattr(module, name):

@@ -65,7 +72,7 @@ def register_promote_function(module, name):
 # Top-level function to insert _all_ the hooks.
-def init(enabled=True, loss_scale="dynamic", enable_caching=True, verbose=False, allow_banned=False):
+def init(enabled=True, loss_scale="dynamic", patch_type=torch.float16, enable_caching=True, verbose=False, allow_banned=False):
     global _DECORATOR_HANDLE
     if not enabled:

@@ -87,16 +94,30 @@ def init(enabled=True, loss_scale="dynamic", enable_caching=True, verbose=False,
         wrap.promote(mod, fn, handle, verbose)
     _USER_PROMOTE_REGISTRY.clear()

+    # conditionally choose between fp16 and bfloat16 functions list to cache
+    if patch_type == torch.float16:
+        low_prec_funcs = 'FP16_FUNCS'
+        maybe_low_prec = utils.maybe_half
+        low_prec_tensor = torch.cuda.HalfTensor
+    elif patch_type == torch.bfloat16:
+        low_prec_funcs = 'BFLOAT16_FUNCS'
+        maybe_low_prec = utils.maybe_bfloat16
+        low_prec_tensor = torch.cuda.BFloat16Tensor
+    else:
+        raise RuntimeError("Unsupported patch_torch_functions_type passed to initialize. " +
+                           "Supported types are: torch.float16 and torch.bfloat16.")
+
     # 1) Force-{fp16, fp32} on white- / black-list functions
     override_modules = [functional_overrides,
                         torch_overrides,
                         tensor_overrides]
-    cast_table = [('FP16_FUNCS', utils.maybe_half),
+    cast_table = [(low_prec_funcs, maybe_low_prec),
                   ('FP32_FUNCS', utils.maybe_float)]
     for module, (list_name, cast_fn) in itertools.product(override_modules,
                                                           cast_table):
         for fn in getattr(module, list_name):
-            try_caching = (cast_fn == utils.maybe_half)
+            try_caching = (cast_fn == maybe_low_prec)
             wrap.cached_cast(module.MODULE, fn, cast_fn, handle,
                              try_caching, verbose)

@@ -128,12 +149,12 @@ def init(enabled=True, loss_scale="dynamic", enable_caching=True, verbose=False,
         for fn in getattr(tensor_overrides, list_name):
             promote_fn(cls, fn, handle, verbose)

-    # 3) For any in-place version of a blacklist function, error if any input is fp16.
+    # 3) For any in-place version of a blacklist function, error if any input is fp16/bfloat16.
     # NB: this is overly conservative.
     for fn in utils.as_inplace(torch_overrides.FP32_FUNCS):
         wrap.err_if_any_half(torch_overrides.MODULE, fn, handle)

-    # 3.5) For any in-place blacklist method, error if called on fp16 tensor
+    # 3.5) For any in-place blacklist method, error if called on fp16/bfloat16 tensor
     for fn in utils.as_inplace(tensor_overrides.FP32_FUNCS):
         wrap.err_if_arg0_half(tensor_overrides.MODULE, fn, handle, verbose)

     if compat.tensor_is_float_tensor():

@@ -141,7 +162,7 @@ def init(enabled=True, loss_scale="dynamic", enable_caching=True, verbose=False,
     # 4) For other in-place methods, match the type of self tensor
     for fn in utils.as_inplace(itertools.chain(
-            tensor_overrides.FP16_FUNCS,
+            getattr(tensor_overrides, low_prec_funcs),
             tensor_overrides.CASTS)):
         wrap.promote_match_arg0(tensor_overrides.MODULE, fn, handle, verbose)
     if compat.tensor_is_float_tensor():

@@ -156,10 +177,10 @@ def init(enabled=True, loss_scale="dynamic", enable_caching=True, verbose=False,
             torch.nn.modules.rnn._VF = rnn_compat.VariableFunctionsShim()
         # Wrap all the rnns
         for x in rnn_compat.RNN_NAMES:
-            wrap.new_rnn_cast(x.upper(), handle, verbose)
+            wrap.new_rnn_cast(x.upper(), maybe_low_prec, handle, verbose)

     # Wrap all the RNN cells
-    rnn_compat.whitelist_rnn_cells(handle, verbose)
+    rnn_compat.whitelist_rnn_cells(maybe_low_prec, handle, verbose)

     # 6) Place error+print message on banned functions.
     # Or, if allow_banned, then cast to FP32.
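amp.init now takes a patch_type argument (default torch.float16) and uses it to pick both the override list ('FP16_FUNCS' vs. 'BFLOAT16_FUNCS') and the cast helper (utils.maybe_half vs. utils.maybe_bfloat16). A hedged sketch of calling this low-level entry point directly with bfloat16 patching (most users would go through amp.initialize instead):

    import torch
    from apex import amp

    # Install the patched Torch/Tensor functions with bfloat16 as the
    # low-precision type; loss_scale=1 since bfloat16 needs no scaling.
    handle = amp.init(enabled=True, loss_scale=1, patch_type=torch.bfloat16)

    a = torch.randn(128, 128, device="cuda")
    b = torch.randn(128, 128, device="cuda")
    out = torch.mm(a, b)      # goes through the cached bfloat16 cast wrapper
    print(out.dtype)          # torch.bfloat16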
apex/amp/compat.py

@@ -28,7 +28,8 @@ def is_floating_point(x):
         torch_type = x.type()
         return torch_type.endswith('FloatTensor') or \
             torch_type.endswith('HalfTensor') or \
-            torch_type.endswith('DoubleTensor')
+            torch_type.endswith('DoubleTensor') or \
+            torch_type.endswith('BFloat16Tensor')
     except AttributeError:
         return False
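is_floating_point falls back to string matching on Tensor.type(), so supporting bfloat16 only requires accepting the 'BFloat16Tensor' suffix. A quick illustration of the type strings involved (standard PyTorch behavior, not part of this diff):

    import torch

    print(torch.zeros(1, dtype=torch.bfloat16).type())
    # 'torch.BFloat16Tensor'
    print(torch.zeros(1, dtype=torch.bfloat16, device="cuda").type())
    # 'torch.cuda.BFloat16Tensor'
    # Both end with 'BFloat16Tensor', which is what the updated check matches.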
apex/amp/frontend.py

@@ -16,6 +16,10 @@ class Properties(object):
             "opt_level" : None,
             "cast_model_type" : None,
             "patch_torch_functions" : False,
+            # TODO: patch_torch_functions_type could probably be unified with
+            # patch_torch_functions. Currently introducing a new attribute
+            # to be on the safer side and not break stuff.
+            "patch_torch_functions_type" : None,
             "keep_batchnorm_fp32" : None,
             "master_weights" : None,
             "loss_scale" : 1.0,

@@ -53,7 +57,7 @@ class Properties(object):
         if name in self.options:
             # print("setting {} {}".format(name, value))
             if name == "cast_model_type":
-                if self.opt_level == "O1" and value is not None:
+                if self.opt_level in {"O1", "O4"} and value is not None:
                     if value is not False:
                         if value is not torch.float32:
                             warn_or_err("O1 inserts casts around Torch functions rather than "

@@ -63,13 +67,25 @@ class Properties(object):
                                         "cast_model_type was {}".format(value))
                 self.options[name] = value
             elif name == "patch_torch_functions":
-                if self.opt_level != "O1" and value:
+                if self.opt_level not in {"O1", "O4"} and value:
                     warn_or_err("Currently, patch_torch_functions=True should only be set by "
-                                "selecting opt_level='O1'.")
+                                "selecting opt_level='O1' or 'O4'.")
                 self.options[name] = value
+            elif name == "patch_torch_functions_type":
+                if self.opt_level not in {"O1", "O4"} and value is not None:
+                    warn_or_err("Currently, patch_torch_functions_type should only be set by "
+                                "selecting opt_level='O1' or 'O4'.")
+                elif self.opt_level == "O1" and value != torch.float16:
+                    warn_or_err("patch_torch_functions_type should only be set to torch.float16 "
+                                "for opt_level='O1'.")
+                elif self.opt_level == "O4" and value != torch.bfloat16:
+                    warn_or_err("patch_torch_functions_type should only be set to torch.bfloat16 "
+                                "for opt_level='O4'.")
+                else:
+                    self.options[name] = value
             elif name == "keep_batchnorm_fp32":
-                if self.opt_level == "O1" and value is not None:
-                    warn_or_err("With opt_level O1, batchnorm functions are automatically patched "
+                if self.opt_level in {"O1", "O4"} and value is not None:
+                    warn_or_err("With opt_level O1 or O4, batchnorm functions are automatically patched "
                                 "to run in FP32, so keep_batchnorm_fp32 should be None." +
                                 " keep_batchnorm_fp32 was {}".format(value))
                 if value == "False":

@@ -82,9 +98,9 @@ class Properties(object):
                                "or None, found keep_batchnorm_fp32={}".format(value)
                 self.options[name] = value
             elif name == "master_weights":
-                if self.opt_level == "O1" and value is not None:
-                    warn_or_err("It doesn't make sense to use master_weights with O1. "
-                                "With O1, your model weights themselves should be FP32.")
+                if self.opt_level in {"O1", "O4"} and value is not None:
+                    warn_or_err("It doesn't make sense to use master_weights with O1 and O4. "
+                                "With O1 and O4, your model weights themselves should be FP32.")
                 self.options[name] = value
             elif name == "loss_scale":
                 if value == "dynamic":

@@ -113,6 +129,7 @@ class O3:
         properties.opt_level = "O3"
         properties.cast_model_type = torch.float16
         properties.patch_torch_functions = False
+        properties.patch_torch_functions_type = None
         properties.keep_batchnorm_fp32 = False
         properties.master_weights = False
         properties.loss_scale = 1.0

@@ -136,6 +153,7 @@ class O2:
         properties.opt_level = "O2"
         properties.cast_model_type = torch.float16
         properties.patch_torch_functions = False
+        properties.patch_torch_functions_type = None
         properties.keep_batchnorm_fp32 = True
         properties.master_weights = True
         properties.loss_scale = "dynamic"

@@ -158,6 +176,7 @@ class O1:
         properties.opt_level = "O1"
         properties.cast_model_type = None
         properties.patch_torch_functions = True
+        properties.patch_torch_functions_type = torch.float16
         properties.keep_batchnorm_fp32 = None
         properties.master_weights = None
         properties.loss_scale = "dynamic"

@@ -177,6 +196,7 @@ class O0:
         properties.opt_level = "O0"
         properties.cast_model_type = torch.float32
         properties.patch_torch_functions = False
+        properties.patch_torch_functions_type = None
         properties.keep_batchnorm_fp32 = None
         properties.master_weights = False
         properties.loss_scale = 1.0

@@ -184,11 +204,54 @@ class O0:
         # properties.enable_ddp_interop = False
         return properties  # modified in place so this isn't really necessary

+class O4:
+    brief = "O4:  Insert automatic casts around Pytorch functions and Tensor methods.\n"
+    more = "The type of your model's weights is not altered.  However, internally,\n"\
+        "Pytorch functions are patched to cast any Tensor Core-friendly ops to BFLOAT16 for speed,\n"\
+        "while operations that might benefit from the additional stability of FP32 are patched\n"\
+        "to cast their inputs to fp32.\n"\
+        "Loss scaling is not required in O4 mode since bfloat16 has the same dynamic range as fp32."
+
+    def __call__(self, properties):
+        properties.enabled = True
+        properties.opt_level = "O4"
+        properties.cast_model_type = None
+        properties.patch_torch_functions = True
+        properties.patch_torch_functions_type = torch.bfloat16
+        properties.keep_batchnorm_fp32 = None
+        properties.master_weights = None
+        properties.loss_scale = 1
+        return properties  # modified in place so this isn't really necessary
+
+class O5:
+    brief = "O5:  BFLOAT16 training with FP32 batchnorm and FP32 master weights.\n"
+    more = "Calls .bfloat16() on your model, converting the entire model (except for batchnorms)\n"\
+        "to BFLOAT16.  Batchnorms are retained in FP32 for additional stability.\n"\
+        "The forward pass is patched to cast incoming Tensors to BFLOAT16, so you don't need to change\n"\
+        "your data pipeline.\n"\
+        "O5 creates FP32 master weights outside the model and patches any optimizers to update\n"\
+        "these master weights, then copy the master weights into the BFLOAT16 model weights.\n"\
+        "Master weights can also improve convergence and stability."
+
+    def __call__(self, properties):
+        properties.enabled = True
+        properties.opt_level = "O5"
+        properties.cast_model_type = torch.bfloat16
+        properties.patch_torch_functions = False
+        properties.patch_torch_functions_type = None
+        properties.keep_batchnorm_fp32 = True
+        properties.master_weights = True
+        properties.loss_scale = 1
+        return properties  # modified in place so this isn't really necessary

 opt_levels = {"O3": O3(),
               "O2": O2(),
               "O1": O1(),
-              "O0": O0()}
+              "O0": O0(),
+              "O4": O4(),
+              "O5": O5()}

 # allow user to directly pass Properties struct as well?

@@ -199,6 +262,7 @@ def initialize(
     opt_level="O1",
     cast_model_type=None,
     patch_torch_functions=None,
+    patch_torch_functions_type=None,
     keep_batchnorm_fp32=None,
     master_weights=None,
     loss_scale=None,

@@ -235,10 +299,11 @@ def initialize(
         enabled (bool, optional, default=True):  If False, renders all Amp calls no-ops, so your script
             should run as if Amp were not present.
         opt_level (str, optional, default="O1"):  Pure or mixed precision optimization level.  Accepted values are
-            "O0", "O1", "O2", and "O3", explained in detail above.
+            "O0", "O1", "O2", "O3", "O4" and "O5", explained in detail above.
         cast_model_type (``torch.dtype``, optional, default=None):  Optional property override, see
             above.
         patch_torch_functions (bool, optional, default=None):  Optional property override.
+        patch_torch_functions_type (``torch.dtype``, optional, default=None):  Optional property override.
         keep_batchnorm_fp32 (bool or str, optional, default=None):  Optional property override.  If
             passed as a string, must be the string "True" or "False".
         master_weights (bool, optional, default=None):  Optional property override.

@@ -321,14 +386,14 @@ def initialize(
     if opt_level not in opt_levels:
         raise RuntimeError(
             "Unexpected optimization level {}. ".format(opt_level) +
-            "Options are 'O0', 'O1', 'O2', 'O3'.  Note that in `O0`, `O1`, etc., the prefix O is the letter O, " +
+            "Options are 'O0', 'O1', 'O2', 'O3', 'O4', 'O5'.  Note that in `O0`, `O1`, etc., the prefix O is the letter O, " +
             "not the number zero.")
     else:
         _amp_state.opt_properties = opt_levels[opt_level](_amp_state.opt_properties)
         maybe_print("Selected optimization level {}".format(opt_levels[opt_level].brief), True)
         maybe_print("Defaults for this optimization level are:", True)
         for k, v in _amp_state.opt_properties.options.items():
-            maybe_print("{:22} : {}".format(k, v), True)
+            maybe_print("{:26} : {}".format(k, v), True)

     _amp_state.min_loss_scale = min_loss_scale
     _amp_state.max_loss_scale = max_loss_scale

@@ -344,6 +409,8 @@ def initialize(
         _amp_state.opt_properties.cast_model_type = cast_model_type
     if patch_torch_functions is not None:
         _amp_state.opt_properties.patch_torch_functions = patch_torch_functions
+    if patch_torch_functions_type is not None:
+        _amp_state.opt_properties.patch_torch_functions_type = patch_torch_functions_type
     if keep_batchnorm_fp32 is not None:
         _amp_state.opt_properties.keep_batchnorm_fp32 = keep_batchnorm_fp32
     if master_weights is not None:

@@ -353,7 +420,7 @@ def initialize(
     maybe_print("After processing overrides, optimization options are:", True)
     for k, v in _amp_state.opt_properties.options.items():
-        maybe_print("{:22} : {}".format(k, v), True)
+        maybe_print("{:26} : {}".format(k, v), True)

     return _initialize(models, optimizers, _amp_state.opt_properties, num_losses, cast_model_outputs)
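With O4 and O5 registered in opt_levels, selecting a BFloat16 level looks the same as selecting the existing FP16 levels. A short hedged sketch (illustrative; assumes this commit's apex):

    import torch
    from apex import amp

    model = torch.nn.Sequential(
        torch.nn.Linear(512, 512),
        torch.nn.BatchNorm1d(512),
    ).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # O5: model weights cast to bfloat16, batchnorms kept in FP32, FP32 master
    # weights maintained by the patched optimizer; loss_scale defaults to 1.
    model, optimizer = amp.initialize(model, optimizer, opt_level="O5")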
apex/amp/lists/functional_overrides.py

@@ -26,6 +26,17 @@ FP16_FUNCS = [
     'linear',
 ]

+BFLOAT16_FUNCS = [
+    'conv1d',
+    'conv2d',
+    'conv3d',
+    'conv_transpose1d',
+    'conv_transpose2d',
+    'conv_transpose3d',
+    'conv_tbc',
+    # Undocumented / maybe new?
+    'linear',
+]

 FP32_FUNCS = [
     # Interpolation/Upsampling TODO: Remove for 1.2
apex/amp/lists/tensor_overrides.py

@@ -15,6 +15,10 @@ FP16_FUNCS = [
     '__matmul__',
 ]

+BFLOAT16_FUNCS = [
+    '__matmul__',
+]
+
 FP32_FUNCS = [
     '__ipow__',
     '__pow__',

@@ -56,7 +60,7 @@ SEQUENCE_CASTS = []
 # between `torch` and `torch.Tensor` (and check with `hasattr`,
 # because a few random ones aren't defined on Tensor)
 _self_mod = importlib.import_module(__name__)
-for attrname in ['FP16_FUNCS', 'FP32_FUNCS', 'CASTS', 'SEQUENCE_CASTS']:
+for attrname in ['FP16_FUNCS', 'BFLOAT16_FUNCS', 'FP32_FUNCS', 'CASTS', 'SEQUENCE_CASTS']:
     lst = getattr(_self_mod, attrname)
     for fn in getattr(torch_overrides, attrname):
         if hasattr(MODULE, fn):
apex/amp/lists/torch_overrides.py

@@ -26,6 +26,27 @@ FP16_FUNCS = [
     'mv',
 ]

+BFLOAT16_FUNCS = [
+    # Low level functions wrapped by torch.nn layers.
+    # The wrapper layers contain the weights which are then passed in as a parameter
+    # to these functions.
+    'conv1d',
+    'conv2d',
+    'conv3d',
+    'conv_transpose1d',
+    'conv_transpose2d',
+    'conv_transpose3d',
+    'conv_tbc',
+
+    # BLAS
+    'addmm',
+    'addmv',
+    'addr',
+    'matmul',
+    'mm',
+    'mv',
+]
+
 FP32_FUNCS = [
     # Pointwise
     'acos',
apex/amp/rnn_compat.py

@@ -28,7 +28,7 @@ def has_old_rnns():
     except:
         return False

-def whitelist_rnn_cells(handle, verbose):
+def whitelist_rnn_cells(cast_fn, handle, verbose):
     # Different module + function names in old/new RNN cases
     if has_old_rnns():
         fn_names = ['RNNReLUCell', 'RNNTanhCell', 'LSTMCell', 'GRUCell']

@@ -40,7 +40,7 @@ def whitelist_rnn_cells(handle, verbose):
     # Insert casts on cell functions
     for fn in fn_names:
-        wrap.cached_cast(mod, fn, utils.maybe_half, handle,
+        wrap.cached_cast(mod, fn, cast_fn, handle,
                          try_caching=True, verbose=verbose)

     if has_old_rnns():
apex/amp/utils.py

@@ -62,6 +62,17 @@ def maybe_half(x, name='', verbose=False):
         print('Float->Half ({})'.format(name))
     return x.half()

+def maybe_bfloat16(x, name='', verbose=False):
+    if is_nested(x):
+        return type(x)([maybe_bfloat16(y) for y in x])
+
+    if not x.is_cuda or type_string(x) == 'BFloat16Tensor':
+        return x
+    else:
+        if verbose:
+            print('Float->BFloat16 ({})'.format(name))
+        return x.bfloat16()

 def maybe_float(x, name='', verbose=False):
     if is_nested(x):
         return type(x)([maybe_float(y) for y in x])

@@ -189,22 +200,28 @@ def synthesize_flattened_rnn_weights(fp32_weights,
         fp16_weights.append(fp16_layer_weights)
     return fp16_weights

+def _str_from_dtype(dtype=torch.float16):
+    type_to_str = {
+        torch.float16: 'Half',
+        torch.bfloat16: 'BFloat16'
+    }
+    return type_to_str[dtype]

 # Roughly same as above, just the `fp32_weights` aren't nested.
 # Code kept separate for readability.
 def new_synthesize_flattened_rnn_weights(fp32_weights,
                                          fp16_flat_tensor,
                                          rnn_fn='',
+                                         dtype=torch.float16,
                                          verbose=False):
     fp16_weights = []
     fp32_base_ptr = fp32_weights[0].data_ptr()
     for w_fp32 in fp32_weights:
-        w_fp16 = w_fp32.new().half()
+        w_fp16 = w_fp32.new().to(dtype=dtype)
         offset = (w_fp32.data_ptr() - fp32_base_ptr) // w_fp32.element_size()
         w_fp16.set_(fp16_flat_tensor.storage(),
                     offset,
                     w_fp32.shape)
         w_fp16.copy_(w_fp32)
         if verbose:
-            print('Float->Half ({})'.format(rnn_fn))
+            print('Float->{} ({})'.format(_str_from_dtype(dtype), rnn_fn))
         fp16_weights.append(w_fp16)
     return fp16_weights
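maybe_bfloat16 mirrors maybe_half: it recurses into nested lists/tuples, leaves non-CUDA tensors and tensors that are already bfloat16 untouched, and otherwise calls .bfloat16(). A small illustrative check (assumes a CUDA device):

    import torch
    from apex.amp import utils

    batch = [torch.randn(4, 4, device="cuda"),   # FP32 CUDA tensor -> cast
             torch.randn(4, 4)]                  # CPU tensor -> left alone
    out = utils.maybe_bfloat16(batch)
    print([t.dtype for t in out])                # [torch.bfloat16, torch.float32]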
apex/amp/wrap.py

@@ -51,7 +51,8 @@ def make_promote_wrapper(orig_fn, cast_fn, handle=None):
         if len(types) <= 1:
             return orig_fn(*args, **kwargs)
-        elif len(types) == 2 and types == set(['HalfTensor', 'FloatTensor']):
+        elif len(types) == 2 and (types == set(['HalfTensor', 'FloatTensor']) or
+                                  types == set(['BFloat16Tensor', 'FloatTensor'])):
             new_args = utils.casted_args(cast_fn,
                                          args,
                                          kwargs)

@@ -79,7 +80,8 @@ def sequence_promote(mod, fn, handle, verbose=False):
         types = set([utils.type_string(x) for x in seq])
         if len(types) <= 1:
             return orig_fn(seq, *args, **kwargs)
-        elif types == set(['HalfTensor', 'FloatTensor']):
+        elif (types == set(['HalfTensor', 'FloatTensor']) or
+              types == set(['BFloat16Tensor', 'FloatTensor'])):
             cast_seq = utils.casted_args(maybe_float,
                                          seq, {})
             return orig_fn(cast_seq, *args, **kwargs)

@@ -102,6 +104,8 @@ def promote_match_arg0(mod, fn, handle, verbose=False):
         if utils.type_string(arg0) == 'HalfTensor':
             cast_fn = utils.maybe_half
+        elif utils.type_string(arg0) == 'BFloat16Tensor':
+            cast_fn = utils.maybe_bfloat16
         elif utils.type_string(arg0) == 'FloatTensor':
             cast_fn = utils.maybe_float
         else:

@@ -119,12 +123,12 @@ def err_if_any_half(mod, fn, handle, custom_err_msg=None):
     @functools.wraps(orig_fn)
     def wrapper(*args, **kwargs):
         types = utils.collect_fp_tensor_types(args, kwargs)

-        if 'HalfTensor' in types:
+        if 'HalfTensor' in types or 'BFloat16Tensor' in types:
             if custom_err_msg:
                 raise NotImplementedError(custom_err_msg)
             else:
                 raise NotImplementedError('Cannot call in-place function ' +
-                                          '{} with fp16 arguments.'.format(fn))
+                                          '{} with fp16 or bfloat16 args.'.format(fn))
         else:
             return orig_fn(*args, **kwargs)
     utils.set_func_save(handle, mod, fn, wrapper)

@@ -137,9 +141,9 @@ def err_if_arg0_half(mod, fn, handle, verbose=False):
     @functools.wraps(orig_fn)
     def wrapper(arg0, *args, **kwargs):
         assert compat.is_tensor_like(arg0)
-        if utils.type_string(arg0) == 'HalfTensor':
+        if utils.type_string(arg0) in {'HalfTensor', 'BFloat16Tensor'}:
             raise NotImplementedError('Cannot call in-place method ' +
-                                      '{} on fp16 Tensors.'.format(fn))
+                                      '{} with fp16 or bfloat16 args.'.format(fn))
         else:
             cast_fn = utils.verbosify(utils.maybe_float, fn, verbose)
             new_args = utils.casted_args(cast_fn, args, kwargs)

@@ -219,7 +223,7 @@ def rnn_cast(backend, fn, handle, verbose=False):
         return fwd_wrapper
     utils.set_func_save(handle, backend, fn, rnn_wrapper)

-def new_rnn_cast(fn, handle, verbose=False):
+def new_rnn_cast(fn, cast_fn, handle, verbose=False):
     # Forward+backward compatibility around https://github.com/pytorch/pytorch/pull/15744
     # For rnn backend calls that route through _rnn_impls, we must patch the ref
     # that _rnn_impls stashed. For rnn backend calls that directly invoke

@@ -232,7 +236,7 @@ def new_rnn_cast(fn, handle, verbose=False):
         assert isinstance(mod, rnn_compat.VariableFunctionsShim)
     fn = fn.lower()
     orig_fn = utils.get_func(mod, fn)
-    cast_fn = utils.verbosify(utils.maybe_half, fn, verbose)
+    cast_fn = utils.verbosify(cast_fn, fn, verbose)
     @functools.wraps(orig_fn)
     def wrapper(*args, **kwargs):
         # Exact call signature from modules/rnn.py

@@ -247,14 +251,20 @@ def new_rnn_cast(fn, handle, verbose=False):
         else:
             params_idx = 3 # PackedSequence case

+        if cast_fn == utils.maybe_half:
+            dtype = torch.half
+        elif cast_fn == utils.maybe_bfloat16:
+            dtype = torch.bfloat16
+        else:
+            raise RuntimeError("Unsupported cast_fn passed. Supports only maybe_half and maybe_bfloat16")
+
         new_args = []
         for i, arg in enumerate(args):
             if i == params_idx:
                 num_params = sum([x.numel() for x in arg])
                 fp16_weight_buf = args[0].new_empty((num_params,),
-                                                    dtype=torch.half)
+                                                    dtype=dtype)
                 casted_weights = utils.new_synthesize_flattened_rnn_weights(
-                    arg, fp16_weight_buf, fn, verbose)
+                    arg, fp16_weight_buf, fn, dtype, verbose)
                 new_args.append(casted_weights)
             elif utils.is_fp_tensor(arg):
                 new_args.append(cast_fn(arg))
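With these wrappers installed, mixed {BFloat16Tensor, FloatTensor} inputs are handled the same way mixed {HalfTensor, FloatTensor} inputs already were: promote wrappers cast to a common type, and in-place blacklist ops on low-precision tensors raise. A behavioral sketch (illustrative; assumes bfloat16 patching is active):

    import torch
    from apex import amp

    amp.init(enabled=True, loss_scale=1, patch_type=torch.bfloat16)

    x_bf16 = torch.randn(8, 8, device="cuda").bfloat16()
    x_fp32 = torch.randn(8, 8, device="cuda")

    # sequence_promote now treats {BFloat16Tensor, FloatTensor} like the
    # existing {HalfTensor, FloatTensor} case and casts the sequence to FP32.
    y = torch.cat([x_bf16, x_fp32])
    print(y.dtype)   # torch.float32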
csrc/multi_tensor_adam.cu

@@ -149,7 +149,7 @@ void multi_tensor_adam_cuda(
   }

   // Assume single type across p,g,m1,m2 now
-  DISPATCH_DOUBLE_FLOAT_AND_HALF(
+  DISPATCH_DOUBLE_FLOAT_AND_HALF_AND_BFLOAT16(
     tensor_lists[0][0].scalar_type(), 0, "adam",
       multi_tensor_apply<4>(
         BLOCK_SIZE,
csrc/multi_tensor_axpby_kernel.cu

@@ -138,9 +138,9 @@ void multi_tensor_axpby_cuda(
   // If build times suffer, think about where to put this dispatch,
   // and what logic should be moved out of multi_tensor_apply.
-  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_axpby_cuda",
-    DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_axpby_cuda",
-      DISPATCH_FLOAT_AND_HALF(tensor_lists[2][0].scalar_type(), 2, "multi_tensor_axpby_cuda",
+  DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_axpby_cuda",
+    DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_axpby_cuda",
+      DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[2][0].scalar_type(), 2, "multi_tensor_axpby_cuda",
         multi_tensor_apply<3>(
           BLOCK_SIZE,
           chunk_size,
csrc/multi_tensor_l2norm_kernel.cu

@@ -322,7 +322,7 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
     ret_per_tensor = at::empty({0}, float_options);
   }

-  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
+  DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
     multi_tensor_apply<1>(
       BLOCK_SIZE,
       chunk_size,

@@ -391,7 +391,7 @@ void multi_tensor_norm_out_cuda(
   output_per_tensor = at::zeros({ntensors*max_chunks_per_tensor}, float_options);

   if (norm_type == 0) {
-    DISPATCH_FLOAT_AND_HALF(
+    DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(
       tensor_lists[0][0].scalar_type(), 0, "multi_tensor_maxnorm_cuda",
       multi_tensor_apply<1>(
         BLOCK_SIZE,

@@ -405,7 +405,7 @@ void multi_tensor_norm_out_cuda(
         max_chunks_per_tensor);)
   }
   else {
-    DISPATCH_FLOAT_AND_HALF(
+    DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(
       tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
       multi_tensor_apply<1>(
         BLOCK_SIZE,
csrc/multi_tensor_lamb.cu

@@ -363,7 +363,7 @@ void multi_tensor_lamb_cuda(
   // We now in-place modify grad to store update before compute its norm
   // Generally this is not a issue since people modify grad in step() method all the time
   // We can also grab list of empty tensor to avoid this, but I'd like to save space/cpu code
-  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
+  DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
       multi_tensor_apply<4>(
         BLOCK_SIZE,
         chunk_size,

@@ -386,7 +386,7 @@ void multi_tensor_lamb_cuda(
   std::vector<std::vector<at::Tensor>> grad_param_list(tensor_lists.begin(), tensor_lists.begin()+2);

-  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
+  DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
       multi_tensor_apply<2>(
         BLOCK_SIZE,
         chunk_size,
csrc/multi_tensor_lamb_stage_1.cu

@@ -127,9 +127,9 @@ void multi_tensor_lamb_stage1_cuda(
   float next_step = float(step+1);
   float beta1_correction = 1.0f - std::pow(beta1, next_step);
   float beta2_correction = 1.0f - std::pow(beta2, next_step);
-  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
-    DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_1",
-      DISPATCH_FLOAT_AND_HALF(tensor_lists[4][0].scalar_type(), 2, "lamb_stage_1",
+  DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
+    DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_1",
+      DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[4][0].scalar_type(), 2, "lamb_stage_1",
         multi_tensor_apply<5>(
           BLOCK_SIZE,
           chunk_size,
csrc/multi_tensor_lamb_stage_2.cu

@@ -91,8 +91,8 @@ void multi_tensor_lamb_stage2_cuda(
 {
   using namespace at;

-  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
-    DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_2",
+  DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
+    DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_2",
       multi_tensor_apply<2>(
         BLOCK_SIZE,
         chunk_size,
csrc/multi_tensor_novograd.cu

@@ -164,7 +164,7 @@ void multi_tensor_novograd_cuda(
   multi_tensor_norm_out_cuda(chunk_size, noop_flag, grad_list, grad_norms, beta2, (1.0f - beta2), norm_type);

   // Assume single type across p,g,m1,m2 now
-  DISPATCH_DOUBLE_FLOAT_AND_HALF(
+  DISPATCH_DOUBLE_FLOAT_AND_HALF_AND_BFLOAT16(
     tensor_lists[0][0].scalar_type(), 0, "novograd",
       multi_tensor_apply<3>(
         BLOCK_SIZE,
csrc/multi_tensor_scale_kernel.cu

@@ -121,8 +121,8 @@ void multi_tensor_scale_cuda(
   // If build times suffer, think about where to put this dispatch,
   // and what logic should be moved out of multi_tensor_apply.
-  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_scale_cuda",
-    DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_scale_cuda",
+  DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_scale_cuda",
+    DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_scale_cuda",
       multi_tensor_apply<2>(
         BLOCK_SIZE,
        chunk_size,
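On the Python side these kernels are reached through multi_tensor_applier; widening the dispatch macros is what lets bfloat16 tensor lists through. A hedged sketch using the scale kernel (assumes apex was built with the amp_C extension from this commit):

    import torch
    import amp_C
    from apex.multi_tensor_apply import multi_tensor_applier

    overflow_buf = torch.zeros(1, dtype=torch.int, device="cuda")
    src = [torch.randn(1024, device="cuda", dtype=torch.bfloat16) for _ in range(4)]
    dst = [torch.empty_like(t, dtype=torch.float32) for t in src]

    # Copies src into dst while multiplying by the scale factor (1.0 here);
    # the added BFLOAT16 dispatch is what accepts the bfloat16 inputs.
    multi_tensor_applier(amp_C.multi_tensor_scale, overflow_buf, [src, dst], 1.0)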