Commit 03421e87 authored Oct 01, 2019 by Timothee Cour, committed by mcarilli on Oct 01, 2019

fix https://github.com/facebookresearch/maskrcnn-benchmark/issues/802 (#516)

parent 3ae89c75
Showing 2 changed files with 10 additions and 4 deletions

apex/amp/lists/torch_overrides.py   +7 -4
apex/amp/utils.py                   +3 -0
apex/amp/lists/torch_overrides.py

```diff
@@ -74,9 +74,12 @@ if version_num < 1.1:
 _bmms = ['addbmm',
          'baddbmm',
          'bmm']
-if utils.get_cuda_version() >= (9, 1, 0):
-    FP16_FUNCS.extend(_bmms)
-else:
-    FP32_FUNCS.extend(_bmms)
+if utils.is_cuda_enabled():
+    # workaround https://github.com/facebookresearch/maskrcnn-benchmark/issues/802
+    if utils.get_cuda_version() >= (9, 1, 0):
+        FP16_FUNCS.extend(_bmms)
+    else:
+        FP32_FUNCS.extend(_bmms)

 # Multi-tensor fns that may need type promotion
```
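For background (my reading of the linked issue, not stated in the diff itself): on a CPU-only PyTorch build `torch.version.cuda` is `None`, so the old code called `utils.get_cuda_version()` unconditionally and failed while building the override lists. Below is a minimal, self-contained sketch of the new control flow; the `utils` stand-in class and the empty `FP16_FUNCS`/`FP32_FUNCS` lists are simplifications for illustration, not the real module contents.

```python
# Sketch only; the real logic lives in apex/amp/lists/torch_overrides.py
# and apex/amp/utils.py.
import torch

class utils:  # stand-in for apex.amp.utils
    @staticmethod
    def is_cuda_enabled():
        # CPU-only builds report torch.version.cuda as None.
        return torch.version.cuda is not None

    @staticmethod
    def get_cuda_version():
        # e.g. "10.1" -> (10, 1); raises AttributeError if torch.version.cuda is None.
        return tuple(int(x) for x in torch.version.cuda.split('.'))

FP16_FUNCS = []  # simplification: the real lists are pre-populated
FP32_FUNCS = []
_bmms = ['addbmm', 'baddbmm', 'bmm']

# New logic: only consult the CUDA version when a CUDA runtime is present,
# so CPU-only builds never dereference torch.version.cuda.
if utils.is_cuda_enabled():
    # workaround https://github.com/facebookresearch/maskrcnn-benchmark/issues/802
    if utils.get_cuda_version() >= (9, 1, 0):
        FP16_FUNCS.extend(_bmms)
    else:
        FP32_FUNCS.extend(_bmms)
```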
apex/amp/utils.py

```diff
@@ -5,6 +5,9 @@ import itertools
 import torch

+def is_cuda_enabled():
+    return torch.version.cuda is not None
+
 def get_cuda_version():
     return tuple(int(x) for x in torch.version.cuda.split('.'))
```
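As a quick usage sketch (hypothetical values, not part of the commit): `torch.version.cuda` is a version string such as `"10.1"` on CUDA builds and `None` on CPU-only builds, and the resulting tuple compares element-wise against `(9, 1, 0)`.

```python
import torch

def is_cuda_enabled():
    return torch.version.cuda is not None

def get_cuda_version():
    return tuple(int(x) for x in torch.version.cuda.split('.'))

if is_cuda_enabled():
    # e.g. "10.1" -> (10, 1); tuple comparison is element-wise, so (10, 1) >= (9, 1, 0).
    print(get_cuda_version(), get_cuda_version() >= (9, 1, 0))
else:
    # Calling get_cuda_version() here would raise AttributeError on None.
    print("CPU-only build:", torch.version.cuda)
```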