OpenDAS / apex · Commit c7fd532c

basic enablement for O4 and O5 opt levels

Authored May 08, 2020 by rohithkrn
Parent: 8124df13

Showing 10 changed files with 170 additions and 41 deletions (+170 -41)
apex/amp/_initialize.py                 +7   -5
apex/amp/_process_optimizer.py          +13  -13
apex/amp/amp.py                         +23  -10
apex/amp/compat.py                      +2   -1
apex/amp/frontend.py                    +75  -11
apex/amp/lists/functional_overrides.py  +11  -0
apex/amp/lists/tensor_overrides.py      +5   -1
apex/amp/lists/torch_overrides.py       +21  -0
apex/amp/utils.py                       +11  -0
apex/amp/wrap.py                        +2   -0
apex/amp/_initialize.py

@@ -80,10 +80,10 @@ def check_params_fp32(models):
     for model in models:
         for name, param in model.named_parameters():
             if param.is_floating_point():
-                if 'Half' in param.type():
+                if 'Half' in param.type() or 'BFloat16' in param.type():
                     warn_or_err("Found param {} with type {}, expected torch.cuda.FloatTensor.\n"
-                        "When using amp.initialize, you do not need to call .half() on your model\n"
-                        "before passing it, no matter what optimization level you choose.".format(
+                        "When using amp.initialize, you do not need to call .half() or .bfloat16()\n"
+                        "on your model before passing it, no matter what optimization level you choose.".format(
                         name, param.type()))
                 elif not param.is_cuda:
                     warn_or_err("Found param {} with type {}, expected torch.cuda.FloatTensor.\n"

@@ -137,7 +137,7 @@ class O2StateDictHook(object):
     def __call__(self, module, state_dict, prefix, local_metadata):
         for key in state_dict:
             param = state_dict[key]
-            if 'Half' in param.type():
+            if 'Half' in param.type() or 'BFloat16' in param.type():
                 param = param.to(torch.float32)
                 state_dict[key] = param

@@ -232,7 +232,9 @@ def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs
     if properties.patch_torch_functions:
         # handle is unused here. It's accessible later through a global value anyway.
-        handle = amp_init(loss_scale=properties.loss_scale, verbose=(_amp_state.verbosity == 2))
+        handle = amp_init(loss_scale=properties.loss_scale,
+                          patch_type=properties.patch_torch_functions_type,
+                          verbose=(_amp_state.verbosity == 2))
         for optimizer in optimizers:
             # Disable Amp casting for the optimizer step, because it should only be
             # applied to FP32 master params anyway.
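A hedged usage sketch of what this check now covers (the toy model and optimizer below are illustrative, not from the commit): amp.initialize still expects FP32 CUDA parameters, and pre-casting to bfloat16 now draws the same warning that pre-casting to half already did.

    import torch
    from apex import amp

    model = torch.nn.Linear(512, 512).cuda()    # leave weights in FP32
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    # Do NOT call model.half() or model.bfloat16() before initialize();
    # check_params_fp32 now warns on BFloat16 params just as it does for Half.
    model, optimizer = amp.initialize(model, optimizer, opt_level="O4")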
apex/amp/_process_optimizer.py

 import types
 from ..fp16_utils import master_params_to_model_params
 from ..multi_tensor_apply import multi_tensor_applier
-from ._amp_state import maybe_print
+from ._amp_state import maybe_print, _amp_state
 import torch
 from ..optimizers import FusedSGD

@@ -13,7 +13,7 @@ class AmpOptimizerState(object):
 def _master_params_to_model_params(self):
     stash = self._amp_stash
-    if multi_tensor_applier.available:
+    if multi_tensor_applier.available and not _amp_state.opt_properties.opt_level not in {"O4", "O5"}:
         if len(stash.all_fp16_params) > 0:
             multi_tensor_applier(
                 stash.multi_tensor_scale,

@@ -37,7 +37,7 @@ def lazy_init_with_master_weights(self):
         fp32_from_fp16_params_this_group = []
         for i, param in enumerate(param_group['params']):
             if param.requires_grad:
-                if param.type() == 'torch.cuda.HalfTensor':
+                if param.type() in {'torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor'}:
                     # maybe_print("FP16_Optimizer received torch.cuda.HalfTensor with {}"
                     #             .format(param.size()))
                     fp16_params_this_group.append(param)

@@ -55,8 +55,8 @@ def lazy_init_with_master_weights(self):
                     fp32_params_this_group.append(param)
                     param_group['params'][i] = param
                 else:
-                    raise TypeError("Optimizer's parameters must be either "
-                                    "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
+                    raise TypeError("Optimizer's parameters must one of "
+                                    "torch.cuda.FloatTensor, torch.cuda.HalfTensor, torch.cuda.BFloat16Tensor. "
                                     "Received {}".format(param.type()))

         stash.fp16_groups.append(fp16_params_this_group)

@@ -208,7 +208,7 @@ def lazy_init_no_master_weights(self):
     stash.all_fp32_params = []
     for i, param_group in enumerate(self.param_groups):
         for i, param in enumerate(param_group['params']):
-            if param.type() == 'torch.cuda.HalfTensor':
+            if param.type() in {'torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor'}:
                 stash.all_fp16_params.append(param)
             elif param.type() == 'torch.cuda.FloatTensor':
                 stash.all_fp32_params.append(param)

@@ -337,7 +337,7 @@ def _process_optimizer(optimizer, properties):
             raise RuntimeError("Incoming optimizer already has {} defined.".format(name))

     # TODO:  Centralize exposure and import error checking for the C backend.
-    if multi_tensor_applier.available:
+    if multi_tensor_applier.available and not properties.opt_level in {"O4", "O5"}:
         import amp_C
         optimizer._amp_stash.multi_tensor_scale = amp_C.multi_tensor_scale
         optimizer._amp_stash.multi_tensor_l2norm = amp_C.multi_tensor_l2norm

@@ -435,7 +435,7 @@ def _process_optimizer(optimizer, properties):
             fp32_from_fp16_params_this_group = []
             for i, param in enumerate(new_group['params']):
                 if param.requires_grad:
-                    if param.type() == 'torch.cuda.HalfTensor':
+                    if param.type() in {'torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor'}:
                         fp16_params_this_group.append(param)
                         master_param = param.detach().clone().float()
                         master_param.requires_grad = True

@@ -445,8 +445,8 @@ def _process_optimizer(optimizer, properties):
                         fp32_params_this_group.append(param)
                         new_group['params'][i] = param
                     else:
-                        raise TypeError("Optimizer's parameters must be either "
-                                        "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
+                        raise TypeError("Optimizer's parameters must be one of "
+                                        "torch.cuda.FloatTensor, torch.cuda.HalfTensor, torch.cuda.BFloat16Tensor. "
                                         "Received {}".format(param.type()))

             stash.fp16_groups.append(fp16_params_this_group)

@@ -471,15 +471,15 @@ def _process_optimizer(optimizer, properties):
                     # param.grad = None
             else:
                 for param in new_group['params']:
-                    if param.type() == 'torch.cuda.HalfTensor':
+                    if param.type() in {'torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor'}:
                         stash.all_fp16_params.append(param)
                         stash.all_fp16_grad_stash.append(None)
                     elif param.type() == 'torch.cuda.FloatTensor':
                         stash.all_fp32_params.append(param)
                         stash.all_fp32_grad_stash.append(None)
                     else:
-                        raise TypeError("Optimizer's parameters must be either "
-                                        "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
+                        raise TypeError("Optimizer's parameters must one of "
+                                        "torch.cuda.FloatTensor, torch.cuda.HalfTensor, torch.cuda.BFloat16Tensor. "
                                         "Received {}".format(param.type()))

             old_add_param_group(new_group)
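For orientation, a small standalone sketch (tensor construction is illustrative, not from the commit) of the type strings this optimizer processing now dispatches on:

    import torch

    fp32 = torch.zeros(4, device="cuda")                        # 'torch.cuda.FloatTensor'
    fp16 = torch.zeros(4, device="cuda", dtype=torch.float16)   # 'torch.cuda.HalfTensor'
    bf16 = torch.zeros(4, device="cuda", dtype=torch.bfloat16)  # 'torch.cuda.BFloat16Tensor'

    low_precision = {'torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor'}
    for p in (fp16, bf16, fp32):
        # bfloat16 params now take the same master-weight branch as fp16 params
        print(p.type(), p.type() in low_precision)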
apex/amp/amp.py

@@ -9,7 +9,6 @@ import itertools
 import torch

 _DECORATOR_HANDLE = None
 _USER_CAST_REGISTRY = set()
 _USER_PROMOTE_REGISTRY = set()

@@ -65,7 +64,7 @@ def register_promote_function(module, name):
 # Top-level function to insert _all_ the hooks.
-def init(enabled=True, loss_scale="dynamic", enable_caching=True, verbose=False, allow_banned=False):
+def init(enabled=True, loss_scale="dynamic", patch_type=torch.float16, enable_caching=True, verbose=False, allow_banned=False):
     global _DECORATOR_HANDLE

     if not enabled:

@@ -87,27 +86,41 @@ def init(enabled=True, loss_scale="dynamic", enable_caching=True, verbose=False,
         wrap.promote(mod, fn, handle, verbose)
     _USER_PROMOTE_REGISTRY.clear()

+    # conditionally choose between fp16 and bfloat16 functions list to cache
+    if patch_type == torch.float16:
+        low_prec_funcs = 'FP16_FUNCS'
+        maybe_low_prec = utils.maybe_half
+        low_prec_tensor = torch.cuda.HalfTensor
+    elif patch_type == torch.bfloat16:
+        low_prec_funcs = 'BFLOAT16_FUNCS'
+        maybe_low_prec = utils.maybe_bfloat16
+        low_prec_tensor = torch.cuda.BFloat16Tensor
+    else:
+        raise RuntimeError("Unsupported patch_torch_functions_type passed to initialize." +
+                           "Supported types are: torch.float16 and torch.bfloat16.")
+
     # 1) Force-{fp16, fp32} on white- / black-list functions
     override_modules = [functional_overrides,
                         torch_overrides,
                         tensor_overrides]
-    cast_table = [('FP16_FUNCS', utils.maybe_half),
+    cast_table = [(low_prec_funcs, maybe_low_prec),
                   ('FP32_FUNCS', utils.maybe_float)]
     for module, (list_name, cast_fn) in itertools.product(override_modules,
                                                           cast_table):
         for fn in getattr(module, list_name):
-            try_caching = (cast_fn == utils.maybe_half)
+            try_caching = (cast_fn == maybe_low_prec)
             wrap.cached_cast(module.MODULE, fn, cast_fn, handle,
                              try_caching, verbose)

     # 1.5) Pre-0.4, put the blacklist methods on HalfTensor and whitelist
     # methods on FloatTensor, since they're distinct types.
     if compat.tensor_is_float_tensor():
-        for fn in tensor_overrides.FP16_FUNCS:
-            wrap.cached_cast(torch.cuda.FloatTensor, fn, utils.maybe_half,
+        for fn in getattr(tensor_overrides, low_prec_funcs):
+            wrap.cached_cast(torch.cuda.FloatTensor, fn, utils.maybe_low_prec,
                              handle, try_caching=True, verbose=verbose)
         for fn in tensor_overrides.FP32_FUNCS:
-            wrap.cached_cast(torch.cuda.HalfTensor, fn, utils.maybe_float,
+            wrap.cached_cast(low_prec_tensor, fn, utils.maybe_float,
                              handle, try_caching=False, verbose=verbose)

     # 2) Enable type-promotion on multi-arg functions and methods.

@@ -123,7 +136,7 @@ def init(enabled=True, loss_scale="dynamic", enable_caching=True, verbose=False,
     # 2.5) Pre-0.4, add blacklist methods directly to HalfTensor and FloatTensor types
     if compat.tensor_is_float_tensor():
         for cls, (list_name, promote_fn) in itertools.product([torch.cuda.FloatTensor,
-                                                                torch.cuda.HalfTensor],
+                                                                low_prec_tensor],
                                                               promote_table):
             for fn in getattr(tensor_overrides, list_name):
                 promote_fn(cls, fn, handle, verbose)

@@ -141,11 +154,11 @@ def init(enabled=True, loss_scale="dynamic", enable_caching=True, verbose=False,
     # 4) For other in-place methods, match the type of self tensor
     for fn in utils.as_inplace(itertools.chain(
-            tensor_overrides.FP16_FUNCS,
+            getattr(tensor_overrides, low_prec_funcs),
             tensor_overrides.CASTS)):
         wrap.promote_match_arg0(tensor_overrides.MODULE, fn, handle, verbose)
         if compat.tensor_is_float_tensor():
-            wrap.promote_match_arg0(torch.cuda.HalfTensor, fn, handle, verbose)
+            wrap.promote_match_arg0(low_prec_tensor, fn, handle, verbose)
             wrap.promote_match_arg0(torch.cuda.FloatTensor, fn, handle, verbose)

     # 5) RNNs + RNN cells are whitelisted specially
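Read in isolation, the dispatch added at the top of init() amounts to the following standalone sketch (it only reuses names introduced by this commit; the helper function is illustrative):

    import torch
    from apex.amp import utils

    def select_low_precision(patch_type):
        # Mirrors the fp16-vs-bfloat16 selection added to amp.init().
        if patch_type == torch.float16:
            return 'FP16_FUNCS', utils.maybe_half, torch.cuda.HalfTensor
        elif patch_type == torch.bfloat16:
            return 'BFLOAT16_FUNCS', utils.maybe_bfloat16, torch.cuda.BFloat16Tensor
        raise RuntimeError("Supported types are: torch.float16 and torch.bfloat16.")

    # amp.init(patch_type=torch.bfloat16, ...) would then cache casts via utils.maybe_bfloat16
    funcs_list, cast_fn, tensor_cls = select_low_precision(torch.bfloat16)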
apex/amp/compat.py

@@ -28,7 +28,8 @@ def is_floating_point(x):
         torch_type = x.type()
         return torch_type.endswith('FloatTensor') or \
             torch_type.endswith('HalfTensor') or \
-            torch_type.endswith('DoubleTensor')
+            torch_type.endswith('DoubleTensor') or \
+            torch_type.endswith('BFloat16Tensor')
     except AttributeError:
         return False
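A quick sanity check of the widened predicate (a sketch; it assumes a CUDA build whose bfloat16 tensors report the 'BFloat16Tensor' type string, as the commit relies on):

    import torch
    from apex.amp import compat

    x = torch.zeros(2, dtype=torch.bfloat16, device="cuda")
    print(compat.is_floating_point(x))   # True after this change; previously False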
apex/amp/frontend.py

@@ -16,6 +16,7 @@ class Properties(object):
                         "opt_level" : None,
                         "cast_model_type" : None,
                         "patch_torch_functions" : False,
+                        "patch_torch_functions_type" : None,
                         "keep_batchnorm_fp32" : None,
                         "master_weights" : None,
                         "loss_scale" : 1.0,

@@ -53,7 +54,7 @@ class Properties(object):
         if name in self.options:
             # print("setting {} {}".format(name, value))
             if name == "cast_model_type":
-                if self.opt_level == "O1" and value is not None:
+                if self.opt_level in {"O1", "O4"} and value is not None:
                     if value is not False:
                         if value is not torch.float32:
                             warn_or_err("O1 inserts casts around Torch functions rather than "

@@ -63,13 +64,25 @@ class Properties(object):
                                 "cast_model_type was {}".format(value))
                 self.options[name] = value
             elif name == "patch_torch_functions":
-                if self.opt_level != "O1" and value:
+                if self.opt_level not in {"O1", "O4"} and value:
                     warn_or_err("Currently, patch_torch_functions=True should only be set by "
-                                "selecting opt_level='O1'.")
+                                "selecting opt_level='O1' or 'O4'.")
                 self.options[name] = value
+            elif name == "patch_torch_functions_type":
+                if self.opt_level not in {"O1", "O4"} and value is not None:
+                    warn_or_err("Currently, patch_torch_functions_type should only be set by "
+                                "selecting opt_level='O1' or 'O4'.")
+                elif self.opt_level == "O1" and value != torch.float16:
+                    warn_or_err("patch_torch_functions_type should only be set to torch.float16 "
+                                "for opt_level='O1.")
+                elif self.opt_level == "O4" and value != torch.bfloat16:
+                    warn_or_err("patch_torch_functions_type should only be set to torch.bfloat16 "
+                                "for opt_level='O4.")
+                else:
+                    self.options[name] = value
             elif name == "keep_batchnorm_fp32":
-                if self.opt_level == "O1" and value is not None:
-                    warn_or_err("With opt_level O1, batchnorm functions are automatically patched "
+                if self.opt_level in {"O1", "O4"} and value is not None:
+                    warn_or_err("With opt_level O1 or O4, batchnorm functions are automatically patched "
                                 "to run in FP32, so keep_batchnorm_fp32 should be None." +
                                 " keep_batchnorm_fp32 was {}".format(value))
                 if value == "False":

@@ -82,9 +95,9 @@ class Properties(object):
                                 "or None, found keep_batchnorm_fp32={}".format(value)
                 self.options[name] = value
             elif name == "master_weights":
-                if self.opt_level == "O1" and value is not None:
-                    warn_or_err("It doesn't make sense to use master_weights with O1. "
-                                "With O1, your model weights themselves should be FP32.")
+                if self.opt_level in {"O1", "O4"} and value is not None:
+                    warn_or_err("It doesn't make sense to use master_weights with O1 and O4. "
+                                "With O1 and O4, your model weights themselves should be FP32.")
                 self.options[name] = value
             elif name == "loss_scale":
                 if value == "dynamic":

@@ -113,6 +126,7 @@ class O3:
         properties.opt_level = "O3"
         properties.cast_model_type = torch.float16
         properties.patch_torch_functions = False
+        properties.patch_torch_functions_type = None
         properties.keep_batchnorm_fp32 = False
         properties.master_weights = False
         properties.loss_scale = 1.0

@@ -136,6 +150,7 @@ class O2:
         properties.opt_level = "O2"
         properties.cast_model_type = torch.float16
         properties.patch_torch_functions = False
+        properties.patch_torch_functions_type = None
         properties.keep_batchnorm_fp32 = True
         properties.master_weights = True
         properties.loss_scale = "dynamic"

@@ -158,6 +173,7 @@ class O1:
         properties.opt_level = "O1"
         properties.cast_model_type = None
         properties.patch_torch_functions = True
+        properties.patch_torch_functions_type = torch.float16
         properties.keep_batchnorm_fp32 = None
         properties.master_weights = None
         properties.loss_scale = "dynamic"

@@ -177,6 +193,7 @@ class O0:
         properties.opt_level = "O0"
         properties.cast_model_type = torch.float32
         properties.patch_torch_functions = False
+        properties.patch_torch_functions_type = None
         properties.keep_batchnorm_fp32 = None
         properties.master_weights = False
         properties.loss_scale = 1.0

@@ -184,11 +201,54 @@ class O0:
         # properties.enable_ddp_interop = False
         return properties # modified in place so this isn't really necessary

+
+class O4:
+    brief = "O4:  Insert automatic casts around Pytorch functions and Tensor methods.\n"
+    more = "The type of your model's weights is not altered.  However, internally,\n"\
+        "Pytorch functions are patched to cast any Tensor Core-friendly ops to BFLOAT16 for speed,\n"\
+        "while operations that might benefit from the additional stability of FP32 are patched\n"\
+        "to cast their inputs to fp32.\n"\
+        "Loss scaling is not required in O4 mode since bflaot16 has the same dynamic range as fp32."
+
+    def __call__(self, properties):
+        properties.enabled = True
+        properties.opt_level = "O4"
+        properties.cast_model_type = None
+        properties.patch_torch_functions = True
+        properties.patch_torch_functions_type = torch.bfloat16
+        properties.keep_batchnorm_fp32 = None
+        properties.master_weights = None
+        properties.loss_scale = 1
+        return properties # modified in place so this isn't really necessary
+
+
+class O5:
+    brief = "O5:  BFLOAT16 training with FP32 batchnorm and FP32 master weights.\n"
+    more = "Calls .bfloat16() on your model, converting the entire model (except for batchnorms)\n"\
+        "to BFLOAT16.  Batchnorms are retained in FP32 for additional stability.\n"\
+        "The forward pass is patched to cast incoming Tensors to BFLOAT16, so you don't need to change\n"\
+        "your data pipeline.\n"\
+        "O5 creates FP32 master weights outside the model and patches any optimizers to update\n"\
+        "these master weights, then copy the master weights into the BFLOAT16 model weights.\n"\
+        "Master weights can also improve convergence and stability."
+
+    def __call__(self, properties):
+        properties.enabled = True
+        properties.opt_level = "O5"
+        properties.cast_model_type = torch.bfloat16
+        properties.patch_torch_functions = False
+        properties.patch_torch_functions = None
+        properties.patch_torch_functions_type = None
+        properties.keep_batchnorm_fp32 = True
+        properties.master_weights = True
+        properties.loss_scale = 1
+        return properties # modified in place so this isn't really necessary
+
+
 opt_levels = {"O3": O3(),
               "O2": O2(),
               "O1": O1(),
-              "O0": O0()}
+              "O0": O0(),
+              "O4": O4(),
+              "O5": O5()}

 # allow user to directly pass Properties struct as well?

@@ -199,6 +259,7 @@ def initialize(
     opt_level="O1",
     cast_model_type=None,
     patch_torch_functions=None,
+    patch_torch_functions_type=None,
     keep_batchnorm_fp32=None,
     master_weights=None,
     loss_scale=None,

@@ -235,10 +296,11 @@ def initialize(
         enabled (bool, optional, default=True):  If False, renders all Amp calls no-ops, so your script
             should run as if Amp were not present.
         opt_level (str, optional, default="O1"):  Pure or mixed precision optimization level.  Accepted values are
-            "O0", "O1", "O2", and "O3", explained in detail above.
+            "O0", "O1", "O2", "O3", "O4" and "O5", explained in detail above.
         cast_model_type (``torch.dtype``, optional, default=None):  Optional property override, see
             above.
         patch_torch_functions (bool, optional, default=None):  Optional property override.
+        patch_torch_functions_type (``torch.dtype``, optional, default=None):  Optional property override
         keep_batchnorm_fp32 (bool or str, optional, default=None):  Optional property override.  If
             passed as a string, must be the string "True" or "False".
         master_weights (bool, optional, default=None):  Optional property override.

@@ -321,7 +383,7 @@ def initialize(
     if opt_level not in opt_levels:
         raise RuntimeError(
             "Unexpected optimization level {}. ".format(opt_level) +
-            "Options are 'O0', 'O1', 'O2', 'O3'.  Note that in `O0`, `O1`, etc., the prefix O is the letter O, " +
+            "Options are 'O0', 'O1', 'O2', 'O3', 'O4', 'O5'.  Note that in `O0`, `O1`, etc., the prefix O is the letter O, " +
             "not the number zero.")
     else:
         _amp_state.opt_properties = opt_levels[opt_level](_amp_state.opt_properties)

@@ -344,6 +406,8 @@ def initialize(
         _amp_state.opt_properties.cast_model_type = cast_model_type
     if patch_torch_functions is not None:
         _amp_state.opt_properties.patch_torch_functions = patch_torch_functions
+    if patch_torch_functions_type is not None:
+        _amp_state.opt_properties.patch_torch_functions_type = patch_torch_functions_type
     if keep_batchnorm_fp32 is not None:
         _amp_state.opt_properties.keep_batchnorm_fp32 = keep_batchnorm_fp32
     if master_weights is not None:
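Putting the new levels together, a hedged end-to-end sketch (the toy model, optimizer and data below are placeholders, not from the commit) of selecting O4 or O5 through the extended frontend:

    import torch
    from apex import amp

    model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU()).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # O4: weights stay FP32, Torch functions are patched to bfloat16, loss_scale defaults to 1.
    model, optimizer = amp.initialize(model, optimizer, opt_level="O4")
    # O5 instead: model cast to bfloat16 (batchnorms kept FP32) with FP32 master weights.
    # model, optimizer = amp.initialize(model, optimizer, opt_level="O5")

    loss = model(torch.randn(8, 64, device="cuda")).sum()
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()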
apex/amp/lists/functional_overrides.py

@@ -26,6 +26,17 @@ FP16_FUNCS = [
     'linear',
 ]

+BFLOAT16_FUNCS = [
+    'conv1d',
+    'conv2d',
+    'conv3d',
+    'conv_transpose1d',
+    'conv_transpose2d',
+    'conv_transpose3d',
+    'conv_tbc',
+    # Undocumented / maybe new?
+    'linear',
+]
+
 FP32_FUNCS = [
     # Interpolation/Upsampling TODO: Remove for 1.2
apex/amp/lists/tensor_overrides.py

@@ -15,6 +15,10 @@ FP16_FUNCS = [
     '__matmul__',
 ]

+BFLOAT16_FUNCS = [
+    '__matmul__',
+]
+
 FP32_FUNCS = [
     '__ipow__',
     '__pow__',

@@ -56,7 +60,7 @@ SEQUENCE_CASTS = []
 # between `torch` and `torch.Tensor` (and check with `hasattr`,
 # because a few random ones aren't defined on Tensor)
 _self_mod = importlib.import_module(__name__)
-for attrname in ['FP16_FUNCS', 'FP32_FUNCS', 'CASTS', 'SEQUENCE_CASTS']:
+for attrname in ['FP16_FUNCS', 'BFLOAT16_FUNCS', 'FP32_FUNCS', 'CASTS', 'SEQUENCE_CASTS']:
     lst = getattr(_self_mod, attrname)
     for fn in getattr(torch_overrides, attrname):
         if hasattr(MODULE, fn):
apex/amp/lists/torch_overrides.py

@@ -26,6 +26,27 @@ FP16_FUNCS = [
     'mv',
 ]

+BFLOAT16_FUNCS = [
+    # Low level functions wrapped by torch.nn layers.
+    # The wrapper layers contain the weights which are then passed in as a parameter
+    # to these functions.
+    'conv1d',
+    'conv2d',
+    'conv3d',
+    'conv_transpose1d',
+    'conv_transpose2d',
+    'conv_transpose3d',
+    'conv_tbc',
+
+    # BLAS
+    'addmm',
+    'addmv',
+    'addr',
+    'matmul',
+    'mm',
+    'mv',
+]
+
 FP32_FUNCS = [
     # Pointwise
     'acos',
apex/amp/utils.py

@@ -62,6 +62,17 @@ def maybe_half(x, name='', verbose=False):
             print('Float->Half ({})'.format(name))
         return x.half()

+def maybe_bfloat16(x, name='', verbose=False):
+    if is_nested(x):
+        return type(x)([maybe_bfloat16(y) for y in x])
+
+    if not x.is_cuda or type_string(x) == 'BFloat16Tensor':
+        return x
+    else:
+        if verbose:
+            print('Float->BFloat16 ({})'.format(name))
+        return x.bfloat16()
+
 def maybe_float(x, name='', verbose=False):
     if is_nested(x):
         return type(x)([maybe_float(y) for y in x])
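For reference, a small usage sketch of the new helper (tensor construction is illustrative):

    import torch
    from apex.amp.utils import maybe_bfloat16

    x = torch.randn(4, device="cuda")               # FP32 CUDA tensor
    y = maybe_bfloat16(x, name="x", verbose=True)   # prints 'Float->BFloat16 (x)', returns a bfloat16 copy
    z = maybe_bfloat16(torch.randn(4))              # non-CUDA tensors are returned unchanged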
apex/amp/wrap.py

@@ -102,6 +102,8 @@ def promote_match_arg0(mod, fn, handle, verbose=False):
         if utils.type_string(arg0) == 'HalfTensor':
             cast_fn = utils.maybe_half
+        if utils.type_string(arg0) == 'BFloat16Tensor':
+            cast_fn = utils.maybe_bfloat16
         elif utils.type_string(arg0) == 'FloatTensor':
             cast_fn = utils.maybe_float
         else: