Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
4b9858ec
Commit
4b9858ec
authored
Jun 21, 2019
by
Michael Carilli
Browse files
Don't need to blacklist mean for pytorch >= 1.1
parent
90e5b05a
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
17 additions
and
6 deletions
+17
-6
apex/__init__.py
apex/__init__.py
+3
-0
apex/amp/lists/functional_overrides.py
apex/amp/lists/functional_overrides.py
+1
-1
apex/amp/lists/tensor_overrides.py
apex/amp/lists/tensor_overrides.py
+4
-4
apex/amp/lists/torch_overrides.py
apex/amp/lists/torch_overrides.py
+9
-1
No files found.
apex/__init__.py
View file @
4b9858ec
# May help avoid undefined symbol errors https://pytorch.org/cppdocs/notes/faq.html#undefined-symbol-errors-from-pytorch-aten
import
torch
from
.
import
parallel
from
.
import
amp
from
.
import
fp16_utils
...
...
apex/amp/lists/functional_overrides.py
View file @
4b9858ec
...
...
@@ -28,7 +28,7 @@ FP16_FUNCS = [
FP32_FUNCS
=
[
# Interpolation/Upsampling
# Interpolation/Upsampling
TODO: Remove for 1.2
'interpolate'
,
# Pointwise
...
...
apex/amp/lists/tensor_overrides.py
View file @
4b9858ec
...
...
@@ -5,10 +5,10 @@ import importlib
import
torch
if
compat
.
variable_is_tensor
()
and
not
compat
.
tensor_is_variable
():
MODULE
=
torch
.
Tensor
else
:
MODULE
=
torch
.
autograd
.
Variable
#
if compat.variable_is_tensor() and not compat.tensor_is_variable():
MODULE
=
torch
.
Tensor
#
else:
#
MODULE = torch.autograd.Variable
FP16_FUNCS
=
[
...
...
apex/amp/lists/torch_overrides.py
View file @
4b9858ec
...
...
@@ -49,7 +49,7 @@ FP32_FUNCS = [
'cumprod'
,
'cumsum'
,
'dist'
,
'mean'
,
#
'mean',
'norm'
,
'prod'
,
'std'
,
...
...
@@ -60,6 +60,14 @@ FP32_FUNCS = [
'renorm'
]
version_strings
=
torch
.
__version__
.
split
(
'.'
)
version_major
=
version_strings
[
0
]
version_minor
=
version_strings
[
1
]
version_num
=
float
(
version_major
+
"."
+
version_minor
)
# Before torch 1.1, mean must be blacklisted.
if
version_num
<
1.1
:
FP32_FUNCS
.
append
(
'mean'
)
# Before CUDA 9.1, batched matmul was missing fast FP16 kernels. We
# check the CUDA version -- if at least 9.1, then put the bmm
# functions on the fp16 list. Otherwise, put them on the fp32 list.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment