Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
f1123e32
".github/vscode:/vscode.git/clone" did not exist on "191d94289d016b59c0553b14d299d1bac07a7fcd"
Commit
f1123e32
authored
Mar 27, 2019
by
Carl Case
Browse files
Conditionally run bmm functions in fp16 based on cuda version
parent
f5cd5ae9
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
20 additions
and
10 deletions
+20
-10
apex/amp/lists/torch_overrides.py
apex/amp/lists/torch_overrides.py
+17
-10
apex/amp/utils.py
apex/amp/utils.py
+3
-0
No files found.
apex/amp/lists/torch_overrides.py
View file @
f1123e32
import
torch
from
..
import
utils
MODULE
=
torch
FP16_FUNCS
=
[
...
...
@@ -20,10 +22,8 @@ FP16_FUNCS = [
'matmul'
,
'mm'
,
'mv'
,
]
# TODO: ban in-place versions of these in fp16
FP32_FUNCS
=
[
# Pointwise
'acos'
,
...
...
@@ -54,15 +54,21 @@ FP32_FUNCS = [
'sum'
,
'var'
,
# Special reduction-like BLAS
'addbmm'
,
'baddbmm'
,
'bmm'
,
# Misc
'renorm'
]
# Batched matmul (bmm/addbmm/baddbmm) only gained fast FP16 kernels in
# CUDA 9.1.  Register the three functions on the FP16 cast list when the
# toolkit is new enough, and keep them on the FP32 list otherwise.
_bmms = ['addbmm', 'baddbmm', 'bmm']
_bmm_list = (FP16_FUNCS
             if utils.get_cuda_version() >= (9, 1, 0)
             else FP32_FUNCS)
_bmm_list.extend(_bmms)
# Multi-tensor fns that may need type promotion
CASTS
=
[
# Multi-tensor math
...
...
@@ -87,8 +93,9 @@ CASTS = [
'ne'
]
# Will possibly need to promote *all* elements of `seq`
# Functions that take sequence arguments. We need to inspect the whole
# sequence and cast to the widest type.
SEQUENCE_CASTS
=
[
'cat'
,
# torch.cat(seq, dim=0, out=None)
'stack'
# torch.stack(seq, dim=0, out=None)
'cat'
,
'stack'
]
apex/amp/utils.py
View file @
f1123e32
...
...
@@ -5,6 +5,9 @@ import itertools
import
torch
def get_cuda_version():
    """Return the CUDA version torch was built against as a tuple of ints.

    E.g. a ``torch.version.cuda`` of ``'9.1.85'`` becomes ``(9, 1, 85)``,
    so callers can compare versions with ordinary tuple ordering.
    """
    version_string = torch.version.cuda
    return tuple(map(int, version_string.split('.')))
def
is_fp_tensor
(
x
):
if
is_nested
(
x
):
# Fast-fail version of all(is_fp_tensor)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment