Commit 9cc74429 (OpenDAS/apex)
Authored May 24, 2018 by Carl Case
Parent: ea93767d

Optimizer wrapper; loss scaling class; no-op handle; start multi-loss

Showing 6 changed files with 207 additions and 36 deletions (+207 -36):
  .gitignore             +2    -1
  apex/amp/__init__.py   +1    -1
  apex/amp/amp.py        +22   -8
  apex/amp/handle.py     +44   -26
  apex/amp/opt.py        +101  -0
  apex/amp/scaler.py     +37   -0
.gitignore

```diff
 apex.egg-info
 dist
 build
-docs/build
\ No newline at end of file
+docs/build
+*~
\ No newline at end of file
```
apex/amp/__init__.py

```diff
-from .amp import enable, register_half, register_float
+from .amp import build, register_half, register_float, register_promote
```
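The package's public entry point is renamed from `enable` to `build`, and `register_promote` is newly exported. A minimal sketch of the corresponding call-site change (argument values are just the defaults from `amp.py`):

```python
from apex import amp

# Before this commit:
#   amp.enable(enable_caching=True, verbose=False)
# After this commit; build() returns an AmpHandle, or a NoOpHandle when
# enabled=False, so mixed precision can be switched off without code changes.
handle = amp.build(enabled=True, enable_caching=True, verbose=False)
```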
apex/amp/amp.py

```diff
 from . import compat, utils, wrap
-from .handle import AmpHandle
+from .handle import AmpHandle, NoOpHandle
 from .lists import functional_overrides, torch_overrides, tensor_overrides
 
 import inspect
...
@@ -7,30 +7,44 @@ import itertools
 import torch
 
-_USER_REGISTRY = set()
+_USER_CAST_REGISTRY = set()
+_USER_PROMOTE_REGISTRY = set()
 
 # Can be used as a @decorator directly on the fn
-# or called w/ arg by user before `enable()`
+# or called w/ arg by user before `build()`
 def register_half(fn):
     mod = inspect.getmodule(fn)
-    _USER_REGISTRY.add((mod, fn.__name__, utils.maybe_half))
+    _USER_CAST_REGISTRY.add((mod, fn.__name__, utils.maybe_half))
     return fn
 
 def register_float(fn):
     mod = inspect.getmodule(fn)
-    _USER_REGISTRY.add((mod, fn.__name__, utils.maybe_float))
+    _USER_CAST_REGISTRY.add((mod, fn.__name__, utils.maybe_float))
+    return fn
+
+def register_promote(fn):
+    mod = inspect.getmodule(fn)
+    _USER_PROMOTE_REGISTRY.add((mod, fn.__name__))
     return fn
 
 # Top-level function to insert _all_ the hooks.
-def enable(enable_caching=True, verbose=False):
+def build(enabled=True, enable_caching=True, verbose=False):
+    if not enabled:
+        return NoOpHandle()
+
     handle = AmpHandle(enable_caching)
 
     # 0) Force-{fp16, fp32} for user-annotated functions
-    for mod, fn, cast_fn in _USER_REGISTRY:
+    for mod, fn, cast_fn in _USER_CAST_REGISTRY:
         try_caching = (cast_fn == utils.maybe_half)
         wrap.cached_cast(mod, fn, cast_fn, handle,
                          try_caching, verbose)
-    _USER_REGISTRY.clear()
+    _USER_CAST_REGISTRY.clear()
+
+    # 0.5) Force-promote for user-annotated functions
+    for mod, fn in _USER_PROMOTE_REGISTRY:
+        wrap.promote(mod, fn, verbose)
+    _USER_PROMOTE_REGISTRY.clear()
 
     # 1) Force-{fp16, fp32} on white- / black-list functions
     override_modules = [functional_overrides,
...
```
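Per the updated comment, the registration decorators are applied to user functions before `build()` installs the hooks. A sketch of that pattern follows; the decorated functions are hypothetical examples, and the exact casting behavior (`maybe_half`, `maybe_float`, promotion) lives in `utils` and `wrap`, which this commit does not touch:

```python
import torch
from apex import amp

# Hypothetical user-defined ops, shown only to illustrate where the
# decorators attach. register_half/register_float request fp16/fp32 casts
# of the inputs; register_promote requests promotion of mixed-precision
# inputs before the call.
@amp.register_half
def attention_scores(q, k):
    return torch.matmul(q, k.transpose(-1, -2))

@amp.register_float
def stable_loss(logits, targets):
    return torch.nn.functional.cross_entropy(logits.float(), targets)

@amp.register_promote
def residual_add(x, y):
    return x + y

# The registries are consumed (and cleared) when build() installs the hooks.
handle = amp.build()
```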
apex/amp/handle.py

```diff
@@ -2,54 +2,53 @@ import contextlib
 import logging
 import warnings
 
-import torch
-
-from ._C import scale_lib
+from .opt import OptimWrapper
+from .scaler import LossScaler
 
 class AmpHandle(object):
     def __init__(self, enable_caching=True):
         self._enable_caching = enable_caching
         self._cache = dict()
-        self._loss_scale = 2.**16
-        self._max_loss_scale = 2.**24
-        self._scale_seq_len = 2000
-        self._unskipped = 0
-        self._overflow_buf = torch.cuda.ByteTensor(1024,)
+        self._default_scaler = LossScaler()
+
+    def is_active(self):
+        return True
+
+    def wrap_optimizer(self, optimizer, num_loss=1):
+        self._default_scaler = None
+        return OptimWrapper(optimizer, self, num_loss)
 
     @contextlib.contextmanager
     def scale_loss(self, loss, optimizer):
+        if not self.is_active():
+            yield loss
+            return
+
+        if self._default_scaler is None:
+            raise RuntimeError(
+                'After calling `handle.wrap_optimizer()`, you must explicitly ' +
+                'use `optimizer.scale_loss(loss)`.')
+
+        # TODO: this code block is duplicated here and `opt.py`. Unify.
         loss_backward = loss.backward
         def warning_wrapper():
             warnings.warn("You called .backward() on the unscaled loss "
                           "inside a scale_loss block. This is almost "
                           "certainly an error.", stacklevel=2)
             loss_backward()
         loss.backward = warning_wrapper
-        yield loss * self._loss_scale
+
+        loss_scale = self._default_scaler.loss_scale()
+        yield loss * loss_scale
         loss.backward = loss_backward
 
-        self._overflow_buf.zero_()
-        for group in optimizer.param_groups:
-            for p in group['params']:
-                if p.grad is not None:
-                    scale_lib.scale_check_overflow(p.grad.data,
-                                                   1. / self._loss_scale,
-                                                   self._overflow_buf)
-        if self._overflow_buf.any():
-            self._loss_scale /= 2.
+        should_skip = self._default_scaler.unscale_and_update(
+            optimizer.param_groups, loss_scale)
+        if should_skip:
             optimizer_step = optimizer.step
             def skip_step():
                 logging.info('Gradient overflow, skipping update')
                 optimizer.step = optimizer_step
             optimizer.step = skip_step
-            self._unskipped = 0
-        else:
-            self._unskipped += 1
-
-        if self._unskipped == self._scale_seq_len:
-            self._loss_scale = min(self._max_loss_scale, self._loss_scale * 2.)
-            self._unskipped = 0
 
         self._clear_cache()
@@ -63,3 +62,22 @@ class AmpHandle(object):
     @property
     def cache(self):
         return self._cache
+
+    def remove_cache(self, param):
+        if self.has_cache and param in self.cache:
+            del self.cache[param]
+
+class NoOpHandle(object):
+    def is_active(self):
+        return False
+
+    def wrap_optimizer(self, optimizer, num_loss=1):
+        return OptimWrapper(optimizer, self, num_loss)
+
+    @contextlib.contextmanager
+    def scale_loss(self, loss, optimizer):
+        yield loss
+
+    @property
+    def has_cache(self):
+        return False
```
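Together with `build()`, the handle's `scale_loss` context manager gives the single-loss training-loop pattern sketched below. The toy model, optimizer, and data are hypothetical placeholders, not part of this commit:

```python
import torch
import torch.nn.functional as F
from apex import amp

# Hypothetical toy setup, only to make the loop concrete.
model = torch.nn.Linear(32, 8).cuda().half()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
handle = amp.build()  # amp.build(enabled=False) would return a NoOpHandle

for step in range(10):
    inputs = torch.randn(16, 32, device='cuda', dtype=torch.half)
    targets = torch.randint(0, 8, (16,), device='cuda')

    optimizer.zero_grad()
    loss = F.cross_entropy(model(inputs).float(), targets)

    # The handle yields loss * loss_scale; backpropagating the yielded value
    # produces scaled gradients, which the handle's LossScaler unscales in
    # place on exit while checking for inf/NaN overflow.
    with handle.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()

    # On overflow, optimizer.step has been patched to a logging no-op for
    # this iteration (and the loss scale halved); otherwise the update runs.
    optimizer.step()
```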
apex/amp/opt.py (new file)

```python
import contextlib
import warnings

from .scaler import LossScaler

import numpy as np

class OptimWrapper(object):
    def __init__(self, optimizer, amp_handle, num_loss):
        self._optimizer = optimizer
        self._amp_handle = amp_handle
        self._num_loss = num_loss
        self._loss_idx = 0
        self._skip_next = [False] * num_loss
        self._loss_scaler = [LossScaler() for _ in range(num_loss)]

    @contextlib.contextmanager
    def scale_loss(self, loss):
        if not self._amp_handle.is_active():
            yield loss
            return

        loss_backward = loss.backward
        def warning_wrapper():
            warnings.warn("You called .backward() on the unscaled loss "
                          "inside a scale_loss block. This is almost "
                          "certainly an error.", stacklevel=2)
            loss_backward()
        loss.backward = warning_wrapper

        # if loss_idx > 0:
        #     save out current grads to buffers
        #       keep some group caches
        #       .detach().clone()
        #     zero grads
        loss_scale = self._cur_loss_scaler().loss_scale()
        print('Loss scale (log): {}'.format(np.log2(loss_scale)))
        yield loss * loss_scale
        loss.backward = loss_backward

        self._skip_next[self._loss_idx] = self._cur_loss_scaler().unscale_and_update(
            self._optimizer.param_groups, loss_scale)
        print('GOT SKIP NEXT: {}'.format(self._skip_next[self._loss_idx]))
        self._loss_idx += 1

        # if loss_idx > 0:
        #     += saved state into grads

    def _cur_loss_scaler(self):
        assert 0 <= self._loss_idx < self._num_loss
        return self._loss_scaler[self._loss_idx]

    def step(self, closure=None):
        if not self._amp_handle.is_active():
            return self._optimizer.step(closure=closure)

        self._loss_idx = 0

        for group in self._optimizer.param_groups:
            for p in group['params']:
                self._amp_handle.remove_cache(p)

        if closure is not None:
            raise NotImplementedError(
                'The `closure` argument is unsupported by the amp ' +
                'optimizer wrapper.')

        if any(self._skip_next):
            self._skip_next = [False] * self._num_loss
            print('SKIP')
        else:
            return self._optimizer.step(closure=closure)

    # Forward any attribute lookups
    def __getattr__(self, attr):
        return getattr(self._optimizer, attr)

    # Forward all torch.optim.Optimizer methods
    def __getstate__(self):
        return self._optimizer.__getstate__()

    def __setstate__(self):
        return self._optimizer.__setstate__()

    def __repr__(self):
        return self._optimizer.__repr__()

    def state_dict(self):
        return self._optimizer.state_dict()

    def load_state_dict(self, state_dict):
        return self._optimizer.load_state_dict(state_dict)

    def zero_grad(self):
        return self._optimizer.zero_grad()

    def add_param_group(self, param_group):
        return self._optimizer.add_param_group(param_group)
```
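The multi-loss path is explicitly a work in progress here (the commit message says "start multi-loss", the gradient save/restore between losses is still commented out, and debug `print`s remain), but the intended call pattern is already visible from `wrap_optimizer` and `OptimWrapper.scale_loss`. A sketch with a hypothetical model and two hypothetical losses:

```python
import torch
from apex import amp

# Hypothetical two-loss setup, only to illustrate the wrapper's intended use.
model = torch.nn.Linear(32, 8).cuda().half()
base_optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

handle = amp.build()
# wrap_optimizer disables the handle's default scaler and returns an
# OptimWrapper that owns one LossScaler per loss.
optimizer = handle.wrap_optimizer(base_optimizer, num_loss=2)

inputs = torch.randn(16, 32, device='cuda', dtype=torch.half)
targets = torch.randint(0, 8, (16,), device='cuda')

optimizer.zero_grad()
out = model(inputs).float()
loss_a = torch.nn.functional.cross_entropy(out, targets)
loss_b = out.pow(2).mean()  # hypothetical auxiliary loss

# Each scale_loss block uses, then advances past, its own LossScaler.
with optimizer.scale_loss(loss_a) as scaled_a:
    scaled_a.backward(retain_graph=True)
with optimizer.scale_loss(loss_b) as scaled_b:
    scaled_b.backward()

# step() resets the per-loss index, clears the handle's cast cache for the
# updated params, and skips the update if any loss scaler saw an overflow.
optimizer.step()
```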
apex/amp/scaler.py (new file)

```python
import torch

from ._C import scale_lib

class LossScaler(object):
    def __init__(self):
        self._loss_scale = 2.**16
        self._max_loss_scale = 2.**24
        self._scale_seq_len = 2000
        self._unskipped = 0
        self._overflow_buf = torch.cuda.ByteTensor(1024,)

    def loss_scale(self):
        return self._loss_scale

    def unscale_and_update(self, param_groups, scale):
        self._overflow_buf.zero_()
        for group in param_groups:
            for p in group['params']:
                if p.grad is not None:
                    scale_lib.scale_check_overflow(p.grad.data,
                                                   1. / scale,
                                                   self._overflow_buf)

        if self._overflow_buf.any():
            should_skip = True
            self._loss_scale /= 2.
            self._unskipped = 0
        else:
            should_skip = False
            self._unskipped += 1

        if self._unskipped == self._scale_seq_len:
            self._loss_scale = min(self._max_loss_scale, self._loss_scale * 2.)
            self._unskipped = 0

        return should_skip
```
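`LossScaler` centralizes the dynamic loss-scaling policy that previously lived inline in `AmpHandle`: start at 2**16, halve the scale (and skip the step) whenever `scale_check_overflow` flags an inf/NaN gradient, and double it again, capped at 2**24, after 2000 consecutive non-skipped steps, so `unscale_and_update` returns True exactly when the caller should skip the update. The fused `scale_lib.scale_check_overflow` kernel comes from the compiled `apex._C` extension and is not part of this diff; the snippet below is only a rough pure-PyTorch stand-in for the behavior it appears to implement (in-place unscale plus an overflow flag), not the actual kernel:

```python
import torch

def scale_check_overflow_py(grad_data, scale_factor, overflow_buf):
    # Illustrative stand-in: multiply the gradient in place by scale_factor
    # (1 / loss_scale at the call site) and raise the overflow flag if any
    # element is non-finite. The real extension fuses this work on the GPU.
    grad_data.mul_(scale_factor)
    if not torch.isfinite(grad_data).all():
        overflow_buf.fill_(1)
```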