Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
ec431d33
Commit
ec431d33
authored
May 30, 2018
by
Carl Case
Browse files
WIP: better annotation / user function registry support
parent
614b11ff
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
83 additions
and
32 deletions
+83
-32
apex/amp/__init__.py
apex/amp/__init__.py
+2
-1
apex/amp/amp.py
apex/amp/amp.py
+49
-18
apex/amp/handle.py
apex/amp/handle.py
+10
-1
apex/amp/wrap.py
apex/amp/wrap.py
+22
-12
No files found.
apex/amp/__init__.py
View file @
ec431d33
from
.amp
import
build
,
register_half
,
register_float
,
register_promote
from
.amp
import
build
,
half_function
,
float_function
,
promote_function
,
\
register_half_function
,
register_float_function
,
register_promote_function
apex/amp/amp.py
View file @
ec431d33
...
...
@@ -2,37 +2,67 @@ from . import compat, utils, wrap
from
.handle
import
AmpHandle
,
NoOpHandle
from
.lists
import
functional_overrides
,
torch_overrides
,
tensor_overrides
import
inspect
import
functools
import
itertools
import
torch
_DECORATOR_HANDLE
=
None
_USER_CAST_REGISTRY
=
set
()
_USER_PROMOTE_REGISTRY
=
set
()
# Can be used as a @decorator directly on the fn
# or called w/ arg by user before `build()`
def
register_half
(
fn
):
mod
=
inspect
.
getmodule
(
fn
)
_USER_CAST_REGISTRY
.
add
((
mod
,
fn
.
__name__
,
utils
.
maybe_half
))
return
fn
def
register_float
(
fn
):
mod
=
inspect
.
getmodule
(
fn
)
_USER_CAST_REGISTRY
.
add
((
mod
,
fn
.
__name__
,
utils
.
maybe_float
))
return
fn
def
register_promote
(
fn
):
mod
=
inspect
.
getmodule
(
fn
)
def
_decorator_helper
(
orig_fn
,
cast_fn
,
wrap_fn
):
def
wrapper
(
*
args
,
**
kwargs
):
handle
=
_DECORATOR_HANDLE
if
handle
is
None
or
not
handle
.
is_active
():
return
orig_fn
(
*
args
,
**
kwargs
)
inner_cast_fn
=
utils
.
verbosify
(
cast_fn
,
orig_fn
.
__name__
,
handle
.
verbose
)
return
wrap_fn
(
orig_fn
,
inner_cast_fn
,
handle
)(
*
args
,
**
kwargs
)
return
wrapper
# Decorator form
def
half_function
(
fn
):
wrap_fn
=
functools
.
partial
(
wrap
.
make_cast_wrapper
,
try_caching
=
True
)
return
_decorator_helper
(
fn
,
utils
.
maybe_half
,
wrap_fn
)
def
float_function
(
fn
):
wrap_fn
=
functools
.
partial
(
wrap
.
make_cast_wrapper
,
try_caching
=
False
)
return
_decorator_helper
(
fn
,
utils
.
maybe_float
,
wrap_fn
)
def
promote_function
(
fn
):
wrap_fn
=
functools
.
partial
(
wrap
.
make_promote_wrapper
)
return
_decorator_helper
(
fn
,
utils
.
maybe_float
,
wrap_fn
)
# Registry form
def
register_half_function
(
module
,
name
):
if
not
hasattr
(
module
,
name
):
raise
ValueError
(
'No function named {} in module {}.'
.
format
(
name
,
module
))
_USER_CAST_REGISTRY
.
add
((
module
,
name
,
utils
.
maybe_half
))
def
register_float_function
(
module
,
name
):
if
not
hasattr
(
module
,
name
):
raise
ValueError
(
'No function named {} in module {}.'
.
format
(
name
,
module
))
_USER_CAST_REGISTRY
.
add
((
module
,
name
,
utils
.
maybe_float
))
def
register_promote_function
(
module
,
name
):
if
not
hasattr
(
module
,
name
):
raise
ValueError
(
'No function named {} in module {}.'
.
format
(
name
,
module
))
_USER_PROMOTE_REGISTRY
.
add
((
mod
,
fn
.
__name__
))
return
fn
# Top-level function to insert _all_ the hooks.
def
build
(
enabled
=
True
,
enable_caching
=
True
,
verbose
=
False
):
global
_DECORATOR_HANDLE
if
not
enabled
:
return
NoOpHandle
()
handle
=
NoOpHandle
()
_DECORATOR_HANDLE
=
handle
return
handle
handle
=
AmpHandle
(
enable_caching
)
handle
=
AmpHandle
(
enable_caching
,
verbose
)
# 0) Force-{fp16, fp32} for user-annotated functions
for
mod
,
fn
,
cast_fn
in
_USER_CAST_REGISTRY
:
...
...
@@ -115,4 +145,5 @@ def build(enabled=True, enable_caching=True, verbose=False):
# 5.5) Extra-special handling of RNN backend
wrap
.
rnn_cast
(
torch
.
nn
.
backends
.
thnn
.
backend
,
'RNN'
,
verbose
)
_DECORATOR_HANDLE
=
handle
return
handle
apex/amp/handle.py
View file @
ec431d33
...
...
@@ -6,8 +6,9 @@ from .opt import OptimWrapper
from
.scaler
import
LossScaler
class
AmpHandle
(
object
):
def
__init__
(
self
,
enable_caching
=
True
):
def
__init__
(
self
,
enable_caching
=
True
,
verbose
=
False
):
self
.
_enable_caching
=
enable_caching
self
.
_verbose
=
verbose
self
.
_cache
=
dict
()
self
.
_default_scaler
=
LossScaler
()
...
...
@@ -67,6 +68,10 @@ class AmpHandle(object):
if
self
.
has_cache
and
param
in
self
.
cache
:
del
self
.
cache
[
param
]
@
property
def
verbose
(
self
):
return
self
.
_verbose
class
NoOpHandle
(
object
):
def
is_active
(
self
):
return
False
...
...
@@ -81,3 +86,7 @@ class NoOpHandle(object):
@
property
def
has_cache
(
self
):
return
False
@
property
def
verbose
(
self
):
return
False
apex/amp/wrap.py
View file @
ec431d33
...
...
@@ -5,13 +5,8 @@ import functools
import
torch
def
cached_cast
(
mod
,
fn
,
cast_fn
,
handle
,
try_caching
=
False
,
verbose
=
False
):
if
not
utils
.
has_func
(
mod
,
fn
):
return
orig_fn
=
utils
.
get_func
(
mod
,
fn
)
cast_fn
=
utils
.
verbosify
(
cast_fn
,
fn
,
verbose
)
def
make_cast_wrapper
(
orig_fn
,
cast_fn
,
handle
,
try_caching
=
False
):
@
functools
.
wraps
(
orig_fn
)
def
wrapper
(
*
args
,
**
kwargs
):
if
try_caching
and
handle
.
has_cache
:
...
...
@@ -26,18 +21,27 @@ def cached_cast(mod, fn, cast_fn, handle,
args
,
kwargs
)
return
orig_fn
(
*
new_args
,
**
kwargs
)
utils
.
set_func
(
mod
,
fn
,
wrapper
)
return
wrapper
def
cached_cast
(
mod
,
fn
,
cast_fn
,
handle
,
try_caching
=
False
,
verbose
=
False
):
if
not
utils
.
has_func
(
mod
,
fn
):
return
def
promote
(
mod
,
fn
,
verbose
=
False
):
orig_fn
=
utils
.
get_func
(
mod
,
fn
)
maybe_float
=
utils
.
verbosify
(
utils
.
maybe_float
,
fn
,
verbose
)
cast_fn
=
utils
.
verbosify
(
cast_fn
,
fn
,
verbose
)
wrapper
=
make_cast_wrapper
(
orig_fn
,
cast_fn
,
handle
,
try_caching
)
utils
.
set_func
(
mod
,
fn
,
wrapper
)
# `handle` arg is unused, but simplifies API to make `make_cast_wrapper`
def
make_promote_wrapper
(
orig_fn
,
cast_fn
,
handle
=
None
):
@
functools
.
wraps
(
orig_fn
)
def
wrapper
(
*
args
,
**
kwargs
):
types
=
utils
.
collect_fp_tensor_types
(
args
,
kwargs
)
if
len
(
types
)
<=
1
:
return
orig_fn
(
*
args
,
**
kwargs
)
elif
len
(
types
)
==
2
and
types
==
set
([
'HalfTensor'
,
'FloatTensor'
]):
new_args
=
utils
.
casted_args
(
maybe_float
,
new_args
=
utils
.
casted_args
(
cast_fn
,
args
,
kwargs
)
return
orig_fn
(
*
new_args
,
**
kwargs
)
...
...
@@ -45,8 +49,14 @@ def promote(mod, fn, verbose=False):
raise
NotImplementedError
(
'Do not know how to handle '
+
'these types to promote: {}'
.
format
(
types
))
return
wrapper
def
promote
(
mod
,
fn
,
verbose
=
False
):
orig_fn
=
utils
.
get_func
(
mod
,
fn
)
maybe_float
=
utils
.
verbosify
(
utils
.
maybe_float
,
fn
,
verbose
)
wrapper
=
make_promote_wrapper
(
orig_fn
,
maybe_float
)
utils
.
set_func
(
mod
,
fn
,
wrapper
)
def
sequence_promote
(
mod
,
fn
,
verbose
=
False
):
orig_fn
=
utils
.
get_func
(
mod
,
fn
)
maybe_float
=
utils
.
verbosify
(
utils
.
maybe_float
,
fn
,
verbose
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment