OpenDAS / apex · Commits

Commit 5c6144e6, authored Jul 09, 2018 by Michael Carilli

FP16_Optimizer now preserves param order and casts per-param state tensors to FP32

Parent: 4a8cf7ad
Showing 2 changed files with 23 additions and 51 deletions.

    apex/fp16_utils/__init__.py        +1   -1
    apex/fp16_utils/fp16_optimizer.py  +22  -50
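Both behaviors named in the commit message are visible in the fp16_optimizer.py hunk further down: fp32 master parameters now replace their fp16 counterparts in-place rather than being grouped in front of the fp32 params, and any preexisting per-param optimizer state is re-keyed to the master params and recast to FP32. Order matters because torch.optim.Optimizer.state_dict() keys per-param state by each param's position within param_groups. A sketch of the ordering difference (illustrative only, not commit code):

    # Positions are what Optimizer.state_dict() uses as state keys.
    mixed_group = ['fp16_a', 'fp32_b', 'fp16_c']   # params as they arrive

    # Old constructor: all fp16-derived masters first, then the fp32 params.
    old_order = ([p + '_master' for p in mixed_group if p.startswith('fp16')]
                 + [p for p in mixed_group if p.startswith('fp32')])
    # New constructor: each master takes its source param's original slot.
    new_order = [p + '_master' if p.startswith('fp16') else p for p in mixed_group]

    print(old_order)  # ['fp16_a_master', 'fp16_c_master', 'fp32_b'] -- reordered
    print(new_order)  # ['fp16_a_master', 'fp32_b', 'fp16_c_master'] -- order kept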
apex/fp16_utils/__init__.py

@@ -13,7 +13,7 @@ from .fp16util import (
 from .fused_weight_norm import Fused_Weight_Norm
 
-from .fp16_optimizer import fp32_to_fp16, fp16_to_fp32, FP16_Module, FP16_Optimizer
+from .fp16_optimizer import FP16_Optimizer
 
 from .loss_scaler import LossScaler, DynamicLossScaler
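After this change the package re-exports only FP16_Optimizer from fp16_optimizer; fp32_to_fp16, fp16_to_fp32, and FP16_Module drop off the import line because the next file deletes them from fp16_optimizer.py. A sketch of the resulting import surface (assumes an installed apex package):

    from apex.fp16_utils import FP16_Optimizer, LossScaler, DynamicLossScaler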
apex/fp16_utils/fp16_optimizer.py

@@ -7,48 +7,6 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 from .loss_scaler import DynamicLossScaler, LossScaler
 from .fp16util import model_grads_to_master_grads, master_params_to_model_params, clip_grad_norm
 
-FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
-HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
-
-def conversion_helper(val, conversion):
-    """Apply conversion to val. Recursively apply conversion if `val` is a nested tuple/list structure."""
-    if not isinstance(val, (tuple, list)):
-        return conversion(val)
-    rtn = [conversion_helper(v, conversion) for v in val]
-    if isinstance(val, tuple):
-        rtn = tuple(rtn)
-    return rtn
-
-def fp32_to_fp16(val):
-    """Convert fp32 `val` to fp16"""
-    def half_conversion(val):
-        val_typecheck = val
-        if isinstance(val_typecheck, (Parameter, Variable)):
-            val_typecheck = val.data
-        if isinstance(val_typecheck, FLOAT_TYPES):
-            val = val.half()
-        return val
-    return conversion_helper(val, half_conversion)
-
-def fp16_to_fp32(val):
-    """Convert fp16 `val` to fp32"""
-    def float_conversion(val):
-        val_typecheck = val
-        if isinstance(val_typecheck, (Parameter, Variable)):
-            val_typecheck = val.data
-        if isinstance(val_typecheck, HALF_TYPES):
-            val = val.float()
-        return val
-    return conversion_helper(val, float_conversion)
-
-class FP16_Module(nn.Module):
-    def __init__(self, module):
-        super(FP16_Module, self).__init__()
-        self.add_module('module', module.half())
-
-    def forward(self, *inputs, **kwargs):
-        return fp16_to_fp32(self.module(*(fp32_to_fp16(inputs)), **kwargs))
-
 # TODO: Update overflow check + downscale to use Carl's fused kernel.
 class FP16_Optimizer(object):
     """
...

@@ -151,40 +109,54 @@ class FP16_Optimizer(object):
         if not torch.cuda.is_available:
             raise SystemError("Cannot use fp16 without CUDA.")
 
+        self.optimizer = init_optimizer
+        # init_state_dict sets up an alternative way to cast per-param state tensors.
+        # Stashing here in case https://github.com/pytorch/pytorch/issues/7733 makes it necessary.
+        # init_state_dict = init_optimizer.state_dict()
+
         self.fp16_groups = []
         self.fp32_from_fp16_groups = []
         self.fp32_from_fp32_groups = []
-        for i, param_group in enumerate(init_optimizer.param_groups):
+        for i, param_group in enumerate(self.optimizer.param_groups):
             print("FP16_Optimizer processing param group {}:".format(i))
             fp16_params_this_group = []
             fp32_params_this_group = []
+            master_params_this_group = []
+            fp32_from_fp16_params_this_group = []
             for param in param_group['params']:
                 if param.requires_grad:
                     if param.type() == 'torch.cuda.HalfTensor':
                         print("FP16_Optimizer received torch.cuda.HalfTensor with {}"
                               .format(param.size()))
                         fp16_params_this_group.append(param)
+                        master_param = param.detach().clone().float()
+                        master_param.requires_grad = True
+                        master_params_this_group.append(master_param)
+                        fp32_from_fp16_params_this_group.append(master_param)
+                        # Reset existing state dict key to the new master param.
+                        # We still need to recast per-param state tensors, if any, to FP32.
+                        if param in self.optimizer.state:
+                            self.optimizer.state[master_param] = self.optimizer.state.pop(param)
                     elif param.type() == 'torch.cuda.FloatTensor':
                         print("FP16_Optimizer received torch.cuda.FloatTensor with {}"
                               .format(param.size()))
                         fp32_params_this_group.append(param)
+                        master_params_this_group.append(param)
                     else:
                         raise TypeError("Wrapped parameters must be either "
                                         "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
                                         "Received {}".format(param.type()))
 
-            fp32_from_fp16_params_this_group = [param.detach().clone().float()
-                                                for param in fp16_params_this_group]
-            for param in fp32_from_fp16_params_this_group:
-                param.requires_grad = True
-
-            param_group['params'] = fp32_from_fp16_params_this_group + fp32_params_this_group
+            param_group['params'] = master_params_this_group
 
             self.fp16_groups.append(fp16_params_this_group)
             self.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)
             self.fp32_from_fp32_groups.append(fp32_params_this_group)
 
-        self.optimizer = init_optimizer
+        # Leverage state_dict() and load_state_dict() to recast preexisting per-param state tensors
+        self.optimizer.load_state_dict(self.optimizer.state_dict())
+        # alternative way to cast per-param state tensors:
+        # self.optimizer.load_state_dict(init_state_dict)
 
         if dynamic_loss_scale:
             self.dynamic_loss_scale = True
...
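The closing load_state_dict(state_dict()) round trip relies on torch.optim.Optimizer.load_state_dict casting floating-point state tensors to the dtype and device of the param they belong to. A minimal sketch of that mechanism (illustrative, not commit code; assumes a CUDA device and torch.optim.Adam):

    import torch

    # Build fp16 state the way a wrapped optimizer would have it.
    model_param = torch.ones(4, dtype=torch.float16, device='cuda', requires_grad=True)
    opt = torch.optim.Adam([model_param], lr=1e-3)
    model_param.grad = torch.zeros_like(model_param)
    opt.step()  # materializes fp16 per-param state (exp_avg, exp_avg_sq)

    # Swap in an fp32 master copy at the same position and re-key its state,
    # mirroring what the constructor above does.
    master_param = model_param.detach().clone().float()
    master_param.requires_grad = True
    opt.param_groups[0]['params'] = [master_param]
    opt.state[master_param] = opt.state.pop(model_param)

    # The round trip recasts exp_avg / exp_avg_sq to match the fp32 master param.
    opt.load_state_dict(opt.state_dict())
    print(opt.state[master_param]['exp_avg'].dtype)  # torch.float32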