OpenDAS / apex
Commit 614b11ff, authored May 29, 2018 by Carl Case

    support multi-loss scaling per-optimizer correctly

Parent: 8be1b053
Showing 2 changed files with 31 additions and 20 deletions:

    apex/amp/opt.py     +21 -14
    apex/amp/scaler.py  +10 -6
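The problem this commit addresses: when one optimizer owns several losses, each loss has its own dynamic loss scale, and gradients can only be divided by a given scale while that loss's contribution sits alone in the .grad buffers. Once two differently scaled backward passes have been summed into the same buffers, no single factor recovers the true gradients. Below is a minimal, self-contained sketch of the resulting order of operations in plain PyTorch (not the apex API; the parameter, the two toy losses, and the fixed scales are invented for illustration, and overflow checking / scale updates are omitted):

    import torch

    p = torch.nn.Parameter(torch.tensor([1.0, 2.0]))
    opt = torch.optim.SGD([p], lr=0.1)

    # Two losses that both backprop into the same parameter, each with its own scale.
    losses_and_scales = [((p * 3.0).sum(), 2.0 ** 8),
                         ((p * 5.0).sum(), 2.0 ** 12)]

    cached = None
    for i, (loss, scale) in enumerate(losses_and_scales):
        if i > 0:
            # Save the already-unscaled accumulation and clear the buffers, so the
            # next scaled backward can be unscaled in isolation (cf. cached_grads).
            cached = p.grad.detach().clone()
            opt.zero_grad()
        (loss * scale).backward()          # scaled backward for this loss only
        p.grad.data.mul_(1.0 / scale)      # unscale while these grads are alone
        if cached is not None:
            p.grad.data.add_(cached)       # re-accumulate earlier losses' grads

    print(p.grad)   # tensor([8., 8.]) == 3.0 + 5.0, independent of either scale
    opt.step()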
apex/amp/opt.py
 import contextlib
+import logging
 import warnings
-from .scaler import LossScaler
+from .scaler import LossScaler, iter_params
 import numpy as np
@@ -28,27 +29,33 @@ class OptimWrapper(object):
             loss_backward()
         loss.backward = warning_wrapper

-        # if loss_idx > 0:
-        #   save out current grads to buffers
-        #   keep some group caches
-        #   .detach().clone()
-        #   zero grads
+        # When there are multiple losses per-optimizer, we need
+        # to save out current grad accumulation, since we won't be
+        # able to unscale this particulare loss once the grads are
+        # all mixed together.
+        cached_grads = []
+        if self._loss_idx > 0:
+            for p in iter_params(self._optimizer.param_groups):
+                if p.grad is not None:
+                    cached_grads.append(p.grad.data.detach().clone())
+                else:
+                    cached_grads.append(None)
+            self._optimizer.zero_grad()

         loss_scale = self._cur_loss_scaler().loss_scale()
-        print('Loss scale (log): {}'.format(np.log2(loss_scale)))
         yield loss * loss_scale
         loss.backward = loss_backward

         self._skip_next[self._loss_idx] = self._cur_loss_scaler().unscale_and_update(
             self._optimizer.param_groups, loss_scale)
-        print('GOT SKIP NEXT: {}'.format(self._skip_next[self._loss_idx]))
         self._loss_idx += 1

-        # if loss_idx > 0:
-        # += saved state into grads
+        if len(cached_grads) > 0:
+            for p, cached_grad in zip(iter_params(self._optimizer.param_groups), cached_grads):
+                if cached_grad is not None:
+                    p.grad.data.add_(cached_grad)
+            cached_grads = []

     def _cur_loss_scaler(self):
         assert 0 <= self._loss_idx < self._num_loss
@@ -69,8 +76,8 @@ class OptimWrapper(object):
                 'The `closure` argument is unsupported by the amp ' +
                 'optimizer wrapper.')
         if any(self._skip_next):
+            logging.info('Gradient overflow, skipping update')
             self._skip_next = [False] * self._num_loss
-            print('SKIP')
         else:
             return self._optimizer.step(closure=closure)
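For the skip path in step() above, a condensed stand-in for what unscale_and_update reports per loss, and how the wrapper reacts, might look like the following. unscale_and_check is a hypothetical stand-in, not apex's LossScaler, and the inf gradient is injected by hand to trigger the skip branch:

    import logging
    import torch

    def unscale_and_check(params, scale):
        # Divide grads by the scale and report whether any value is inf/nan.
        overflow = False
        for p in params:
            if p.grad is not None:
                p.grad.data.mul_(1.0 / scale)
                if not torch.isfinite(p.grad.data).all():
                    overflow = True
        return overflow

    params = [torch.nn.Parameter(torch.ones(2))]
    params[0].grad = torch.tensor([float('inf'), 1.0])   # pretend loss 0 overflowed

    # One skip flag per loss, mirroring self._skip_next.
    skip_next = [unscale_and_check(params, 2.0 ** 16), False]

    if any(skip_next):
        logging.info('Gradient overflow, skipping update')
        skip_next = [False] * len(skip_next)
    else:
        pass  # optimizer.step() would run here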
apex/amp/scaler.py
@@ -15,12 +15,11 @@ class LossScaler(object):
     def unscale_and_update(self, param_groups, scale):
         self._overflow_buf.zero_()
-        for group in param_groups:
-            for p in group['params']:
-                if p.grad is not None:
-                    scale_lib.scale_check_overflow(p.grad.data,
-                                                   1. / scale,
-                                                   self._overflow_buf)
+        for p in iter_params(param_groups):
+            if p.grad is not None:
+                scale_lib.scale_check_overflow(p.grad.data,
+                                               1. / scale,
+                                               self._overflow_buf)
         if self._overflow_buf.any():
             should_skip = True
@@ -35,3 +34,8 @@ class LossScaler(object):
         self._unskipped = 0
         return should_skip

+
+def iter_params(param_groups):
+    for group in param_groups:
+        for p in group['params']:
+            yield p
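The new iter_params helper simply flattens the optimizer's param_groups structure (a list of dicts, each with a 'params' list) into a single stream of parameters; both opt.py (caching and restoring grads) and scaler.py (unscaling) now walk parameters through it. A small usage sketch, where the two-group SGD optimizer is only an illustrative input:

    import torch

    def iter_params(param_groups):
        for group in param_groups:
            for p in group['params']:
                yield p

    w1 = torch.nn.Parameter(torch.ones(2))
    w2 = torch.nn.Parameter(torch.ones(3))
    opt = torch.optim.SGD([{'params': [w1]},
                           {'params': [w2], 'lr': 0.01}], lr=0.1)

    flat = list(iter_params(opt.param_groups))
    print(len(flat))    # 2, one entry per parameter, regardless of grouping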