OpenDAS / apex

Commit c7fd532c, authored May 08, 2020 by rohithkrn
Parent: 8124df13

basic enablement for O4 and O5 opt levels

Showing 10 changed files with 170 additions and 41 deletions (+170 / -41):

  apex/amp/_initialize.py                  +7   -5
  apex/amp/_process_optimizer.py           +13  -13
  apex/amp/amp.py                          +23  -10
  apex/amp/compat.py                       +2   -1
  apex/amp/frontend.py                     +75  -11
  apex/amp/lists/functional_overrides.py   +11  -0
  apex/amp/lists/tensor_overrides.py       +5   -1
  apex/amp/lists/torch_overrides.py        +21  -0
  apex/amp/utils.py                        +11  -0
  apex/amp/wrap.py                         +2   -0
apex/amp/_initialize.py

@@ -80,10 +80,10 @@ def check_params_fp32(models):
     for model in models:
         for name, param in model.named_parameters():
             if param.is_floating_point():
-                if 'Half' in param.type():
+                if 'Half' in param.type() or 'BFloat16' in param.type():
                     warn_or_err("Found param {} with type {}, expected torch.cuda.FloatTensor.\n"
-                                "When using amp.initialize, you do not need to call .half() on your model\n"
-                                "before passing it, no matter what optimization level you choose."
+                                "When using amp.initialize, you do not need to call .half() or .bfloat16()\n"
+                                "on your model before passing it, no matter what optimization level you choose."
                                 .format(name, param.type()))
                 elif not param.is_cuda:
                     warn_or_err("Found param {} with type {}, expected torch.cuda.FloatTensor.\n"

@@ -137,7 +137,7 @@ class O2StateDictHook(object):
     def __call__(self, module, state_dict, prefix, local_metadata):
         for key in state_dict:
             param = state_dict[key]
-            if 'Half' in param.type():
+            if 'Half' in param.type() or 'BFloat16' in param.type():
                 param = param.to(torch.float32)
                 state_dict[key] = param

@@ -232,7 +232,9 @@ def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs
     if properties.patch_torch_functions:
         # handle is unused here. It's accessible later through a global value anyway.
-        handle = amp_init(loss_scale=properties.loss_scale, verbose=(_amp_state.verbosity == 2))
+        handle = amp_init(loss_scale=properties.loss_scale,
+                          patch_type=properties.patch_torch_functions_type,
+                          verbose=(_amp_state.verbosity == 2))

     for optimizer in optimizers:
         # Disable Amp casting for the optimizer step, because it should only be
         # applied to FP32 master params anyway.
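For reference, a minimal standalone sketch of the widened parameter check (not the apex code itself; check_params_fp32_sketch is an illustrative name and warnings.warn stands in for apex's warn_or_err):

    import warnings
    import torch

    def check_params_fp32_sketch(model):
        # Same idea as the patched check: flag any low-precision parameter,
        # whether float16 ('Half' in the type string) or bfloat16 ('BFloat16').
        for name, param in model.named_parameters():
            if param.is_floating_point():
                if 'Half' in param.type() or 'BFloat16' in param.type():
                    warnings.warn(
                        "Found param {} with type {}; with amp.initialize you do not "
                        "need to call .half() or .bfloat16() on your model.".format(
                            name, param.type()))

    check_params_fp32_sketch(torch.nn.Linear(4, 4).bfloat16())  # warns for weight and bias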
apex/amp/_process_optimizer.py

 import types
 from ..fp16_utils import master_params_to_model_params
 from ..multi_tensor_apply import multi_tensor_applier
-from ._amp_state import maybe_print
+from ._amp_state import maybe_print, _amp_state
 import torch
 from ..optimizers import FusedSGD

@@ -13,7 +13,7 @@ class AmpOptimizerState(object):
 def _master_params_to_model_params(self):
     stash = self._amp_stash
-    if multi_tensor_applier.available:
+    if multi_tensor_applier.available and not _amp_state.opt_properties.opt_level in {"O4", "O5"}:
         if len(stash.all_fp16_params) > 0:
             multi_tensor_applier(
                 stash.multi_tensor_scale,

@@ -37,7 +37,7 @@ def lazy_init_with_master_weights(self):
         fp32_from_fp16_params_this_group = []
         for i, param in enumerate(param_group['params']):
             if param.requires_grad:
-                if param.type() == 'torch.cuda.HalfTensor':
+                if param.type() in {'torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor'}:
                     # maybe_print("FP16_Optimizer received torch.cuda.HalfTensor with {}"
                     #             .format(param.size()))
                     fp16_params_this_group.append(param)

@@ -55,8 +55,8 @@ def lazy_init_with_master_weights(self):
                     fp32_params_this_group.append(param)
                     param_group['params'][i] = param
                 else:
-                    raise TypeError("Optimizer's parameters must be either "
-                                    "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
+                    raise TypeError("Optimizer's parameters must be one of "
+                                    "torch.cuda.FloatTensor, torch.cuda.HalfTensor, torch.cuda.BFloat16Tensor. "
                                     "Received {}".format(param.type()))

         stash.fp16_groups.append(fp16_params_this_group)

@@ -208,7 +208,7 @@ def lazy_init_no_master_weights(self):
     stash.all_fp32_params = []
     for i, param_group in enumerate(self.param_groups):
         for i, param in enumerate(param_group['params']):
-            if param.type() == 'torch.cuda.HalfTensor':
+            if param.type() in {'torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor'}:
                 stash.all_fp16_params.append(param)
             elif param.type() == 'torch.cuda.FloatTensor':
                 stash.all_fp32_params.append(param)

@@ -337,7 +337,7 @@ def _process_optimizer(optimizer, properties):
             raise RuntimeError("Incoming optimizer already has {} defined.".format(name))

     # TODO:  Centralize exposure and import error checking for the C backend.
-    if multi_tensor_applier.available:
+    if multi_tensor_applier.available and not properties.opt_level in {"O4", "O5"}:
         import amp_C
         optimizer._amp_stash.multi_tensor_scale = amp_C.multi_tensor_scale
         optimizer._amp_stash.multi_tensor_l2norm = amp_C.multi_tensor_l2norm

@@ -435,7 +435,7 @@ def _process_optimizer(optimizer, properties):
             fp32_from_fp16_params_this_group = []
             for i, param in enumerate(new_group['params']):
                 if param.requires_grad:
-                    if param.type() == 'torch.cuda.HalfTensor':
+                    if param.type() in {'torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor'}:
                         fp16_params_this_group.append(param)
                         master_param = param.detach().clone().float()
                         master_param.requires_grad = True

@@ -445,8 +445,8 @@ def _process_optimizer(optimizer, properties):
                         fp32_params_this_group.append(param)
                         new_group['params'][i] = param
                     else:
-                        raise TypeError("Optimizer's parameters must be either "
-                                        "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
+                        raise TypeError("Optimizer's parameters must be one of "
+                                        "torch.cuda.FloatTensor, torch.cuda.HalfTensor, torch.cuda.BFloat16Tensor. "
                                         "Received {}".format(param.type()))

             stash.fp16_groups.append(fp16_params_this_group)

@@ -471,15 +471,15 @@ def _process_optimizer(optimizer, properties):
                         # param.grad = None
             else:
                 for param in new_group['params']:
-                    if param.type() == 'torch.cuda.HalfTensor':
+                    if param.type() in {'torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor'}:
                         stash.all_fp16_params.append(param)
                         stash.all_fp16_grad_stash.append(None)
                     elif param.type() == 'torch.cuda.FloatTensor':
                         stash.all_fp32_params.append(param)
                         stash.all_fp32_grad_stash.append(None)
                     else:
-                        raise TypeError("Optimizer's parameters must be either "
-                                        "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
+                        raise TypeError("Optimizer's parameters must be one of "
+                                        "torch.cuda.FloatTensor, torch.cuda.HalfTensor, torch.cuda.BFloat16Tensor. "
                                        "Received {}".format(param.type()))

             old_add_param_group(new_group)
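The same bucketing logic, shown as a self-contained sketch; split_param_groups and LOW_PREC_TYPES are illustrative names, not part of apex, and the usage at the bottom assumes a CUDA device:

    import torch

    LOW_PREC_TYPES = {'torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor'}

    def split_param_groups(optimizer):
        # Mirrors lazy_init_no_master_weights: bucket parameters by their type
        # string, treating bfloat16 parameters exactly like float16 ones.
        low_prec, fp32 = [], []
        for group in optimizer.param_groups:
            for param in group['params']:
                if param.type() in LOW_PREC_TYPES:
                    low_prec.append(param)
                elif param.type() == 'torch.cuda.FloatTensor':
                    fp32.append(param)
                else:
                    raise TypeError("Optimizer's parameters must be one of "
                                    "torch.cuda.FloatTensor, torch.cuda.HalfTensor, "
                                    "torch.cuda.BFloat16Tensor. Received {}".format(param.type()))
        return low_prec, fp32

    if torch.cuda.is_available():
        opt = torch.optim.SGD(torch.nn.Linear(4, 4).cuda().bfloat16().parameters(), lr=0.1)
        low_prec, fp32 = split_param_groups(opt)
        print(len(low_prec), len(fp32))  # 2 0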
apex/amp/amp.py

@@ -9,7 +9,6 @@ import itertools
 import torch

 _DECORATOR_HANDLE = None
 _USER_CAST_REGISTRY = set()
 _USER_PROMOTE_REGISTRY = set()

@@ -65,7 +64,7 @@ def register_promote_function(module, name):
 # Top-level function to insert _all_ the hooks.
-def init(enabled=True, loss_scale="dynamic", enable_caching=True, verbose=False, allow_banned=False):
+def init(enabled=True, loss_scale="dynamic", patch_type=torch.float16, enable_caching=True, verbose=False, allow_banned=False):
     global _DECORATOR_HANDLE

     if not enabled:

@@ -87,27 +86,41 @@ def init(enabled=True, loss_scale="dynamic", enable_caching=True, verbose=False,
         wrap.promote(mod, fn, handle, verbose)
     _USER_PROMOTE_REGISTRY.clear()

+    # conditionally choose between fp16 and bfloat16 functions list to cache
+    if patch_type == torch.float16:
+        low_prec_funcs = 'FP16_FUNCS'
+        maybe_low_prec = utils.maybe_half
+        low_prec_tensor = torch.cuda.HalfTensor
+    elif patch_type == torch.bfloat16:
+        low_prec_funcs = 'BFLOAT16_FUNCS'
+        maybe_low_prec = utils.maybe_bfloat16
+        low_prec_tensor = torch.cuda.BFloat16Tensor
+    else:
+        raise RuntimeError("Unsupported patch_torch_functions_type passed to initialize. " +
+                           "Supported types are: torch.float16 and torch.bfloat16.")
+
     # 1) Force-{fp16, fp32} on white- / black-list functions
     override_modules = [functional_overrides,
                         torch_overrides,
                         tensor_overrides]
-    cast_table = [('FP16_FUNCS', utils.maybe_half),
+    cast_table = [(low_prec_funcs, maybe_low_prec),
                   ('FP32_FUNCS', utils.maybe_float)]
     for module, (list_name, cast_fn) in itertools.product(override_modules,
                                                           cast_table):
         for fn in getattr(module, list_name):
-            try_caching = (cast_fn == utils.maybe_half)
+            try_caching = (cast_fn == maybe_low_prec)
             wrap.cached_cast(module.MODULE, fn, cast_fn, handle,
                              try_caching, verbose)

     # 1.5) Pre-0.4, put the blacklist methods on HalfTensor and whitelist
     #      methods on FloatTensor, since they're distinct types.
     if compat.tensor_is_float_tensor():
-        for fn in tensor_overrides.FP16_FUNCS:
-            wrap.cached_cast(torch.cuda.FloatTensor, fn, utils.maybe_half,
+        for fn in getattr(tensor_overrides, low_prec_funcs):
+            wrap.cached_cast(torch.cuda.FloatTensor, fn, maybe_low_prec,
                              handle, try_caching=True, verbose=verbose)
         for fn in tensor_overrides.FP32_FUNCS:
-            wrap.cached_cast(torch.cuda.HalfTensor, fn, utils.maybe_float,
+            wrap.cached_cast(low_prec_tensor, fn, utils.maybe_float,
                              handle, try_caching=False, verbose=verbose)

     # 2) Enable type-promotion on multi-arg functions and methods.

@@ -123,7 +136,7 @@ def init(enabled=True, loss_scale="dynamic", enable_caching=True, verbose=False,
     # 2.5) Pre-0.4, add blacklist methods directly to HalfTensor and FloatTensor types
     if compat.tensor_is_float_tensor():
         for cls, (list_name, promote_fn) in itertools.product([torch.cuda.FloatTensor,
-                                                               torch.cuda.HalfTensor],
+                                                               low_prec_tensor],
                                                               promote_table):
             for fn in getattr(tensor_overrides, list_name):
                 promote_fn(cls, fn, handle, verbose)

@@ -141,11 +154,11 @@ def init(enabled=True, loss_scale="dynamic", enable_caching=True, verbose=False,
     # 4) For other in-place methods, match the type of self tensor
     for fn in utils.as_inplace(itertools.chain(
-            tensor_overrides.FP16_FUNCS,
+            getattr(tensor_overrides, low_prec_funcs),
             tensor_overrides.CASTS)):
         wrap.promote_match_arg0(tensor_overrides.MODULE, fn, handle, verbose)
         if compat.tensor_is_float_tensor():
-            wrap.promote_match_arg0(torch.cuda.HalfTensor, fn, handle, verbose)
+            wrap.promote_match_arg0(low_prec_tensor, fn, handle, verbose)
             wrap.promote_match_arg0(torch.cuda.FloatTensor, fn, handle, verbose)

     # 5) RNNs + RNN cells are whitelisted specially
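A rough sketch of the patch_type dispatch in isolation (illustrative names only; the real code also selects the legacy tensor class used by the pre-0.4 branches):

    import torch

    def select_low_precision(patch_type):
        # Mirrors amp.init's new dispatch: choose which override list to read
        # (FP16_FUNCS vs BFLOAT16_FUNCS) and which cast to apply to inputs.
        if patch_type == torch.float16:
            return 'FP16_FUNCS', lambda t: t.half()
        if patch_type == torch.bfloat16:
            return 'BFLOAT16_FUNCS', lambda t: t.bfloat16()
        raise RuntimeError("Unsupported patch_torch_functions_type passed to initialize. "
                           "Supported types are: torch.float16 and torch.bfloat16.")

    list_name, cast = select_low_precision(torch.bfloat16)
    print(list_name, cast(torch.ones(2)).dtype)  # BFLOAT16_FUNCS torch.bfloat16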
apex/amp/compat.py

@@ -28,7 +28,8 @@ def is_floating_point(x):
         torch_type = x.type()
         return torch_type.endswith('FloatTensor') or \
             torch_type.endswith('HalfTensor') or \
-            torch_type.endswith('DoubleTensor')
+            torch_type.endswith('DoubleTensor') or \
+            torch_type.endswith('BFloat16Tensor')
     except AttributeError:
         return False
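The extended check can be exercised on its own; the snippet below is a copy of the same logic for illustration, not an import from apex:

    import torch

    def is_floating_point_compat(x):
        # Same string-suffix test as compat.is_floating_point, now also
        # accepting the bfloat16 type string.
        try:
            torch_type = x.type()
            return torch_type.endswith('FloatTensor') or \
                torch_type.endswith('HalfTensor') or \
                torch_type.endswith('DoubleTensor') or \
                torch_type.endswith('BFloat16Tensor')
        except AttributeError:
            return False

    print(is_floating_point_compat(torch.zeros(1, dtype=torch.bfloat16)))  # True
    print(is_floating_point_compat(torch.zeros(1, dtype=torch.int32)))     # False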
apex/amp/frontend.py

@@ -16,6 +16,7 @@ class Properties(object):
                         "opt_level" : None,
                         "cast_model_type" : None,
                         "patch_torch_functions" : False,
+                        "patch_torch_functions_type" : None,
                         "keep_batchnorm_fp32" : None,
                         "master_weights" : None,
                         "loss_scale" : 1.0,

@@ -53,7 +54,7 @@ class Properties(object):
         if name in self.options:
             # print("setting {} {}".format(name, value))
             if name == "cast_model_type":
-                if self.opt_level == "O1" and value is not None:
+                if self.opt_level in {"O1", "O4"} and value is not None:
                     if value is not False:
                         if value is not torch.float32:
                             warn_or_err("O1 inserts casts around Torch functions rather than "

@@ -63,13 +64,25 @@ class Properties(object):
                                 "cast_model_type was {}".format(value))
                 self.options[name] = value
             elif name == "patch_torch_functions":
-                if self.opt_level != "O1" and value:
+                if self.opt_level not in {"O1", "O4"} and value:
                     warn_or_err("Currently, patch_torch_functions=True should only be set by "
-                                "selecting opt_level='O1'.")
+                                "selecting opt_level='O1' or 'O4'.")
                 self.options[name] = value
+            elif name == "patch_torch_functions_type":
+                if self.opt_level not in {"O1", "O4"} and value is not None:
+                    warn_or_err("Currently, patch_torch_functions_type should only be set by "
+                                "selecting opt_level='O1' or 'O4'.")
+                elif self.opt_level == "O1" and value != torch.float16:
+                    warn_or_err("patch_torch_functions_type should only be set to torch.float16 "
+                                "for opt_level='O1'.")
+                elif self.opt_level == "O4" and value != torch.bfloat16:
+                    warn_or_err("patch_torch_functions_type should only be set to torch.bfloat16 "
+                                "for opt_level='O4'.")
+                else:
+                    self.options[name] = value
             elif name == "keep_batchnorm_fp32":
-                if self.opt_level == "O1" and value is not None:
-                    warn_or_err("With opt_level O1, batchnorm functions are automatically patched "
+                if self.opt_level in {"O1", "O4"} and value is not None:
+                    warn_or_err("With opt_level O1 or O4, batchnorm functions are automatically patched "
                                 "to run in FP32, so keep_batchnorm_fp32 should be None." +
                                 " keep_batchnorm_fp32 was {}".format(value))
                 if value == "False":

@@ -82,9 +95,9 @@ class Properties(object):
                                 "or None, found keep_batchnorm_fp32={}".format(value)
                 self.options[name] = value
             elif name == "master_weights":
-                if self.opt_level == "O1" and value is not None:
-                    warn_or_err("It doesn't make sense to use master_weights with O1. "
-                                "With O1, your model weights themselves should be FP32.")
+                if self.opt_level in {"O1", "O4"} and value is not None:
+                    warn_or_err("It doesn't make sense to use master_weights with O1 or O4. "
+                                "With O1 or O4, your model weights themselves should be FP32.")
                 self.options[name] = value
             elif name == "loss_scale":
                 if value == "dynamic":

@@ -113,6 +126,7 @@ class O3:
        properties.opt_level = "O3"
        properties.cast_model_type = torch.float16
        properties.patch_torch_functions = False
+       properties.patch_torch_functions_type = None
        properties.keep_batchnorm_fp32 = False
        properties.master_weights = False
        properties.loss_scale = 1.0

@@ -136,6 +150,7 @@ class O2:
        properties.opt_level = "O2"
        properties.cast_model_type = torch.float16
        properties.patch_torch_functions = False
+       properties.patch_torch_functions_type = None
        properties.keep_batchnorm_fp32 = True
        properties.master_weights = True
        properties.loss_scale = "dynamic"

@@ -158,6 +173,7 @@ class O1:
        properties.opt_level = "O1"
        properties.cast_model_type = None
        properties.patch_torch_functions = True
+       properties.patch_torch_functions_type = torch.float16
        properties.keep_batchnorm_fp32 = None
        properties.master_weights = None
        properties.loss_scale = "dynamic"

@@ -177,6 +193,7 @@ class O0:
        properties.opt_level = "O0"
        properties.cast_model_type = torch.float32
        properties.patch_torch_functions = False
+       properties.patch_torch_functions_type = None
        properties.keep_batchnorm_fp32 = None
        properties.master_weights = False
        properties.loss_scale = 1.0

@@ -184,11 +201,54 @@ class O0:
        # properties.enable_ddp_interop = False
        return properties # modified in place so this isn't really necessary

+class O4:
+    brief = "O4:  Insert automatic casts around Pytorch functions and Tensor methods.\n"
+    more = "The type of your model's weights is not altered.  However, internally,\n"\
+           "Pytorch functions are patched to cast any Tensor Core-friendly ops to BFLOAT16 for speed,\n"\
+           "while operations that might benefit from the additional stability of FP32 are patched\n"\
+           "to cast their inputs to fp32.\n"\
+           "Loss scaling is not required in O4 mode since bfloat16 has the same dynamic range as fp32."
+
+    def __call__(self, properties):
+        properties.enabled = True
+        properties.opt_level = "O4"
+        properties.cast_model_type = None
+        properties.patch_torch_functions = True
+        properties.patch_torch_functions_type = torch.bfloat16
+        properties.keep_batchnorm_fp32 = None
+        properties.master_weights = None
+        properties.loss_scale = 1
+        return properties # modified in place so this isn't really necessary
+
+
+class O5:
+    brief = "O5:  BFLOAT16 training with FP32 batchnorm and FP32 master weights.\n"
+    more = "Calls .bfloat16() on your model, converting the entire model (except for batchnorms)\n"\
+           "to BFLOAT16.  Batchnorms are retained in FP32 for additional stability.\n"\
+           "The forward pass is patched to cast incoming Tensors to BFLOAT16, so you don't need to change\n"\
+           "your data pipeline.\n"\
+           "O5 creates FP32 master weights outside the model and patches any optimizers to update\n"\
+           "these master weights, then copy the master weights into the BFLOAT16 model weights.\n"\
+           "Master weights can also improve convergence and stability."
+
+    def __call__(self, properties):
+        properties.enabled = True
+        properties.opt_level = "O5"
+        properties.cast_model_type = torch.bfloat16
+        properties.patch_torch_functions = False
+        properties.patch_torch_functions_type = None
+        properties.keep_batchnorm_fp32 = True
+        properties.master_weights = True
+        properties.loss_scale = 1
+        return properties # modified in place so this isn't really necessary
+
+
 opt_levels = {"O3": O3(),
               "O2": O2(),
               "O1": O1(),
-              "O0": O0()}
+              "O0": O0(),
+              "O4": O4(),
+              "O5": O5()}

 # allow user to directly pass Properties struct as well?

@@ -199,6 +259,7 @@ def initialize(
     opt_level="O1",
     cast_model_type=None,
     patch_torch_functions=None,
+    patch_torch_functions_type=None,
     keep_batchnorm_fp32=None,
     master_weights=None,
     loss_scale=None,

@@ -235,10 +296,11 @@ def initialize(
         enabled (bool, optional, default=True):  If False, renders all Amp calls no-ops, so your script
             should run as if Amp were not present.
         opt_level (str, optional, default="O1"):  Pure or mixed precision optimization level.  Accepted values are
-            "O0", "O1", "O2", and "O3", explained in detail above.
+            "O0", "O1", "O2", "O3", "O4", and "O5", explained in detail above.
         cast_model_type (``torch.dtype``, optional, default=None):  Optional property override, see
             above.
         patch_torch_functions (bool, optional, default=None):  Optional property override.
+        patch_torch_functions_type (``torch.dtype``, optional, default=None):  Optional property override.
         keep_batchnorm_fp32 (bool or str, optional, default=None):  Optional property override.  If
             passed as a string, must be the string "True" or "False".
         master_weights (bool, optional, default=None):  Optional property override.

@@ -321,7 +383,7 @@ def initialize(
     if opt_level not in opt_levels:
         raise RuntimeError(
             "Unexpected optimization level {}. ".format(opt_level) +
-            "Options are 'O0', 'O1', 'O2', 'O3'.  Note that in `O0`, `O1`, etc., the prefix O is the letter O, " +
+            "Options are 'O0', 'O1', 'O2', 'O3', 'O4', 'O5'.  Note that in `O0`, `O1`, etc., the prefix O is the letter O, " +
             "not the number zero.")
     else:
         _amp_state.opt_properties = opt_levels[opt_level](_amp_state.opt_properties)

@@ -344,6 +406,8 @@ def initialize(
         _amp_state.opt_properties.cast_model_type = cast_model_type
     if patch_torch_functions is not None:
         _amp_state.opt_properties.patch_torch_functions = patch_torch_functions
+    if patch_torch_functions_type is not None:
+        _amp_state.opt_properties.patch_torch_functions_type = patch_torch_functions_type
     if keep_batchnorm_fp32 is not None:
         _amp_state.opt_properties.keep_batchnorm_fp32 = keep_batchnorm_fp32
     if master_weights is not None:
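Putting the new opt levels together, a hypothetical usage sketch, assuming this branch of apex is installed and a CUDA device is available (the snippet is not taken from the commit):

    import torch
    from apex import amp

    model = torch.nn.Linear(64, 64).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    # O4: weights stay FP32; torch functions are patched to cast to bfloat16
    # (patch_torch_functions_type defaults to torch.bfloat16 for O4).
    model, optimizer = amp.initialize(model, optimizer, opt_level="O4")

    # O5: the model itself is cast to bfloat16, with FP32 batchnorms and
    # FP32 master weights maintained by the patched optimizer.
    # model, optimizer = amp.initialize(model, optimizer, opt_level="O5")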
apex/amp/lists/functional_overrides.py

@@ -26,6 +26,17 @@ FP16_FUNCS = [
     'linear',
 ]

+BFLOAT16_FUNCS = [
+    'conv1d',
+    'conv2d',
+    'conv3d',
+    'conv_transpose1d',
+    'conv_transpose2d',
+    'conv_transpose3d',
+    'conv_tbc',
+    # Undocumented / maybe new?
+    'linear',
+]
+
 FP32_FUNCS = [
     # Interpolation/Upsampling TODO: Remove for 1.2
apex/amp/lists/tensor_overrides.py

@@ -15,6 +15,10 @@ FP16_FUNCS = [
     '__matmul__',
 ]

+BFLOAT16_FUNCS = [
+    '__matmul__',
+]
+
 FP32_FUNCS = [
     '__ipow__',
     '__pow__',

@@ -56,7 +60,7 @@ SEQUENCE_CASTS = []
 # between `torch` and `torch.Tensor` (and check with `hasattr`,
 # because a few random ones aren't defined on Tensor)
 _self_mod = importlib.import_module(__name__)
-for attrname in ['FP16_FUNCS', 'FP32_FUNCS', 'CASTS', 'SEQUENCE_CASTS']:
+for attrname in ['FP16_FUNCS', 'BFLOAT16_FUNCS', 'FP32_FUNCS', 'CASTS', 'SEQUENCE_CASTS']:
     lst = getattr(_self_mod, attrname)
     for fn in getattr(torch_overrides, attrname):
         if hasattr(MODULE, fn):
View file @
c7fd532c
...
@@ -26,6 +26,27 @@ FP16_FUNCS = [
...
@@ -26,6 +26,27 @@ FP16_FUNCS = [
'mv'
,
'mv'
,
]
]
BFLOAT16_FUNCS
=
[
# Low level functions wrapped by torch.nn layers.
# The wrapper layers contain the weights which are then passed in as a parameter
# to these functions.
'conv1d'
,
'conv2d'
,
'conv3d'
,
'conv_transpose1d'
,
'conv_transpose2d'
,
'conv_transpose3d'
,
'conv_tbc'
,
# BLAS
'addmm'
,
'addmv'
,
'addr'
,
'matmul'
,
'mm'
,
'mv'
,
]
FP32_FUNCS
=
[
FP32_FUNCS
=
[
# Pointwise
# Pointwise
'acos'
,
'acos'
,
...
...
apex/amp/utils.py

@@ -62,6 +62,17 @@ def maybe_half(x, name='', verbose=False):
             print('Float->Half ({})'.format(name))
         return x.half()

+def maybe_bfloat16(x, name='', verbose=False):
+    if is_nested(x):
+        return type(x)([maybe_bfloat16(y) for y in x])
+
+    if not x.is_cuda or type_string(x) == 'BFloat16Tensor':
+        return x
+    else:
+        if verbose:
+            print('Float->BFloat16 ({})'.format(name))
+        return x.bfloat16()
+
 def maybe_float(x, name='', verbose=False):
     if is_nested(x):
         return type(x)([maybe_float(y) for y in x])
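A self-contained sketch of the same casting helper; is_nested and type_string below are stand-ins for the helpers utils.py already defines:

    import torch

    def is_nested(x):
        return isinstance(x, (tuple, list))

    def type_string(x):
        # Stand-in: apex keys off the tensor type name, e.g. 'BFloat16Tensor'.
        return x.type().split('.')[-1]

    def maybe_bfloat16(x, name='', verbose=False):
        # Same shape as the new utils.maybe_bfloat16: recurse into containers,
        # leave CPU tensors and tensors that are already bfloat16 untouched.
        if is_nested(x):
            return type(x)([maybe_bfloat16(y) for y in x])
        if not x.is_cuda or type_string(x) == 'BFloat16Tensor':
            return x
        if verbose:
            print('Float->BFloat16 ({})'.format(name))
        return x.bfloat16()

    print(maybe_bfloat16(torch.ones(2)).dtype)  # torch.float32: CPU tensors pass through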
apex/amp/wrap.py

@@ -102,6 +102,8 @@ def promote_match_arg0(mod, fn, handle, verbose=False):
         if utils.type_string(arg0) == 'HalfTensor':
             cast_fn = utils.maybe_half
+        elif utils.type_string(arg0) == 'BFloat16Tensor':
+            cast_fn = utils.maybe_bfloat16
         elif utils.type_string(arg0) == 'FloatTensor':
             cast_fn = utils.maybe_float
         else:
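The dispatch on the first argument's type, sketched standalone with illustrative names (the real code assigns utils.maybe_half / utils.maybe_bfloat16 / utils.maybe_float):

    import torch

    def pick_cast_for_arg0(arg0):
        # Follow the type of the first tensor argument, now recognizing
        # BFloat16Tensor alongside HalfTensor and FloatTensor.
        tail = arg0.type().split('.')[-1]
        if tail == 'HalfTensor':
            return lambda t: t.half()
        elif tail == 'BFloat16Tensor':
            return lambda t: t.bfloat16()
        elif tail == 'FloatTensor':
            return lambda t: t.float()
        else:
            raise NotImplementedError("Unsupported type {}".format(arg0.type()))

    cast = pick_cast_for_arg0(torch.zeros(2, dtype=torch.bfloat16))
    print(cast(torch.ones(2)).dtype)  # torch.bfloat16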