Commit 4b9858ec authored by Michael Carilli's avatar Michael Carilli
Browse files

Don't need to blacklist mean for pytorch >= 1.1

parent 90e5b05a
# May help avoid undefined symbol errors https://pytorch.org/cppdocs/notes/faq.html#undefined-symbol-errors-from-pytorch-aten
import torch
from . import parallel
from . import amp
from . import fp16_utils
......
......@@ -28,7 +28,7 @@ FP16_FUNCS = [
FP32_FUNCS = [
# Interpolation/Upsampling
# Interpolation/Upsampling TODO: Remove for 1.2
'interpolate',
# Pointwise
......
......@@ -5,10 +5,10 @@ import importlib
import torch
if compat.variable_is_tensor() and not compat.tensor_is_variable():
MODULE = torch.Tensor
else:
MODULE = torch.autograd.Variable
# if compat.variable_is_tensor() and not compat.tensor_is_variable():
MODULE = torch.Tensor
# else:
# MODULE = torch.autograd.Variable
FP16_FUNCS = [
......
......@@ -49,7 +49,7 @@ FP32_FUNCS = [
'cumprod',
'cumsum',
'dist',
'mean',
# 'mean',
'norm',
'prod',
'std',
......@@ -60,6 +60,14 @@ FP32_FUNCS = [
'renorm'
]
# Parse the running torch version once so version-dependent function
# lists can be assembled below. The individual pieces are kept as
# module-level names because later checks (e.g. the CUDA/bmm check)
# may read them.
version_strings = torch.__version__.split('.')
version_major = version_strings[0]
version_minor = version_strings[1]
# NOTE: retained for backward compatibility with any external readers,
# but float-based version comparison is unreliable in general
# (float("1.10") == 1.1 and float("1.9") > float("1.10")), so the
# actual check below compares an integer tuple instead.
version_num = float(version_major + "." + version_minor)
# Before torch 1.1, 'mean' must be blacklisted (kept on the FP32 list).
# Tuple comparison orders versions correctly: (1, 10) >= (1, 1).
if (int(version_major), int(version_minor)) < (1, 1):
    FP32_FUNCS.append('mean')
# Before CUDA 9.1, batched matmul was missing fast FP16 kernels. We
# check the CUDA version -- if at least 9.1, then put the bmm
# functions on the fp16 list. Otherwise, put them on the fp32 list.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment