OpenDAS / apex / Commits / ae921de2

Commit ae921de2, authored Jul 24, 2018 by Michael Carilli

Fixing FP16_Optimizer handling of LBFGS

Parent: d695b68b
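Context for the fix (an inference from the diff below, not stated in the commit message): torch.optim.LBFGS caches a reference to param_groups[0]['params'] as self._params when it is constructed. FP16_Optimizer previously rebound param_group['params'] to a freshly built list of master params, leaving LBFGS stepping through the stale FP16 list. The change swaps each param in place instead, so any alias of the list sees the FP32 masters. A minimal sketch of the aliasing difference, with illustrative names only (not apex code):

    # Illustrative names only; this mimics the aliasing, not apex itself.
    old_list = ['fp16_a', 'fp16_b']          # stands in for param_group['params']
    cached = old_list                        # LBFGS-style cached alias (self._params)

    # Old behavior: rebind the key to a fresh master list.
    # The cached alias still points at the stale FP16 list.
    rebound = ['master_a', 'master_b']
    assert cached == ['fp16_a', 'fp16_b']    # stale

    # New behavior: replace elements in place, so every alias sees the masters.
    for i, master in enumerate(['master_a', 'master_b']):
        old_list[i] = master
    assert cached == ['master_a', 'master_b']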
Showing 1 changed file with 3 additions and 6 deletions:

apex/fp16_utils/fp16_optimizer.py (+3, -6)
@@ -121,9 +121,8 @@ class FP16_Optimizer(object):
             print("FP16_Optimizer processing param group {}:".format(i))
             fp16_params_this_group = []
             fp32_params_this_group = []
-            master_params_this_group = []
             fp32_from_fp16_params_this_group = []
-            for param in param_group['params']:
+            for i, param in enumerate(param_group['params']):
                 if param.requires_grad:
                     if param.type() == 'torch.cuda.HalfTensor':
                         print("FP16_Optimizer received torch.cuda.HalfTensor with {}"
@@ -131,7 +130,7 @@ class FP16_Optimizer(object):
                         fp16_params_this_group.append(param)
                         master_param = param.detach().clone().float()
                         master_param.requires_grad = True
-                        master_params_this_group.append(master_param)
+                        param_group['params'][i] = master_param
                         fp32_from_fp16_params_this_group.append(master_param)
                         # Reset existing state dict key to the new master param.
                         # We still need to recast per-param state tensors, if any, to FP32.
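The trailing comments refer to re-keying the optimizer's per-param state, which torch.optim keeps in a dict keyed by the parameter object itself; the lines that actually do it are elided from this hunk. A hedged, self-contained sketch of such a re-key (SGD with momentum stands in here for any stateful optimizer):

    # Hedged sketch: swapping in a master param means re-keying its state.
    import torch

    param = torch.randn(3, requires_grad=True)
    opt = torch.optim.SGD([param], lr=0.1, momentum=0.9)
    param.grad = torch.ones_like(param)
    opt.step()                                   # populates opt.state[param]

    master_param = param.detach().clone().float()
    master_param.requires_grad = True
    if param in opt.state:
        opt.state[master_param] = opt.state.pop(param)

    print(master_param in opt.state)             # True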
@@ -141,14 +140,12 @@ class FP16_Optimizer(object):
                         print("FP16_Optimizer received torch.cuda.FloatTensor with {}"
                               .format(param.size()))
                         fp32_params_this_group.append(param)
-                        master_params_this_group.append(param)
+                        param_group['params'][i] = param
                     else:
                         raise TypeError("Wrapped parameters must be either "
                                         "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
                                         "Received {}".format(param.type()))

-            param_group['params'] = master_params_this_group
             self.fp16_groups.append(fp16_params_this_group)
             self.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)
             self.fp32_from_fp32_groups.append(fp32_params_this_group)
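For reference, a hedged usage sketch of the case this commit targets: wrapping torch.optim.LBFGS, which steps through a closure, in FP16_Optimizer. It assumes a CUDA device and apex's fp16_utils as of this era; the loss-scale value and model shapes are illustrative:

    # Hedged usage sketch; not taken from the commit itself.
    import torch
    from apex.fp16_utils import FP16_Optimizer

    model = torch.nn.Linear(16, 1).cuda().half()
    optimizer = FP16_Optimizer(torch.optim.LBFGS(model.parameters(), lr=0.1),
                               static_loss_scale=128.0)

    x = torch.randn(8, 16).cuda().half()
    y = torch.randn(8, 1).cuda().half()

    def closure():
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x).float(), y.float())
        optimizer.backward(loss)        # FP16_Optimizer scales and backprops
        return loss

    optimizer.step(closure)             # LBFGS may call the closure repeatedly

Because LBFGS re-evaluates the closure during a single step, it depends on its cached self._params staying in sync with the optimizer's param groups, which is exactly what the in-place swap above preserves.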