"...en/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "3b2830618ddff967a1f3a1307a15e24a75c7ae6e"
Commit 4b9858ec authored by Michael Carilli

Don't need to blacklist mean for pytorch >= 1.1

parent 90e5b05a
+# May help avoid undefined symbol errors https://pytorch.org/cppdocs/notes/faq.html#undefined-symbol-errors-from-pytorch-aten
+import torch
 from . import parallel
 from . import amp
 from . import fp16_utils
...
@@ -28,7 +28,7 @@ FP16_FUNCS = [
 FP32_FUNCS = [
-# Interpolation/Upsampling
+# Interpolation/Upsampling TODO: Remove for 1.2
 'interpolate',
 # Pointwise
...
@@ -5,10 +5,10 @@ import importlib
 import torch
-if compat.variable_is_tensor() and not compat.tensor_is_variable():
+# if compat.variable_is_tensor() and not compat.tensor_is_variable():
 MODULE = torch.Tensor
-else:
-    MODULE = torch.autograd.Variable
+# else:
+# MODULE = torch.autograd.Variable
 FP16_FUNCS = [
...
@@ -49,7 +49,7 @@ FP32_FUNCS = [
 'cumprod',
 'cumsum',
 'dist',
-'mean',
+# 'mean',
 'norm',
 'prod',
 'std',
@@ -60,6 +60,14 @@ FP32_FUNCS = [
 'renorm'
 ]
+version_strings = torch.__version__.split('.')
+version_major = version_strings[0]
+version_minor = version_strings[1]
+version_num = float(version_major + "." + version_minor)
+# Before torch 1.1, mean must be blacklisted.
+if version_num < 1.1:
+    FP32_FUNCS.append('mean')
 # Before CUDA 9.1, batched matmul was missing fast FP16 kernels. We
 # check the CUDA version -- if at least 9.1, then put the bmm
 # functions on the fp16 list. Otherwise, put them on the fp32 list.
...
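
For reference, the version gate added in the last hunk can be exercised on its own. The following standalone sketch is not part of the commit; the helper name and the sample version strings are illustrative assumptions. It mirrors the parsing logic and shows when 'mean' ends up appended to the FP32 blacklist.

# Standalone sketch of the version gate above; helper name and sample inputs
# are illustrative, not taken from the commit.
def blacklist_mean_if_needed(torch_version, fp32_funcs):
    # Parse "major.minor" out of a version string such as "1.0.1" or "1.1.0".
    version_strings = torch_version.split('.')
    version_num = float(version_strings[0] + "." + version_strings[1])
    # Before torch 1.1, mean must be blacklisted.
    if version_num < 1.1:
        fp32_funcs.append('mean')
    return fp32_funcs

print(blacklist_mean_if_needed("1.0.1", ['norm', 'prod']))  # ['norm', 'prod', 'mean']
print(blacklist_mean_if_needed("1.1.0", ['norm', 'prod']))  # ['norm', 'prod']

Note that comparing major.minor as a float is fine for the versions current at the time, although a later "1.10" would compare equal to 1.1 under this scheme.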
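The comment in the final hunk describes routing the batched-matmul functions by CUDA version, but the code that does so is elided from the diff. Below is a hypothetical sketch of that kind of check, not the file's actual contents: the names in BMM_FUNCS and the list handling are assumptions. torch.version.cuda is the documented attribute holding the CUDA version PyTorch was built against (None for CPU-only builds).

# Hypothetical illustration of a CUDA-version gate for batched matmul;
# the function names below are assumed, not taken from the elided diff.
import torch

BMM_FUNCS = ['bmm', 'matmul', 'baddbmm']  # assumed names for illustration
FP16_FUNCS = []
FP32_FUNCS = []

cuda_version = torch.version.cuda  # e.g. "9.2"; None on CPU-only builds
if cuda_version is not None:
    major, minor = (int(x) for x in cuda_version.split('.')[:2])
    if (major, minor) >= (9, 1):
        # CUDA >= 9.1 has fast FP16 kernels for batched matmul.
        FP16_FUNCS.extend(BMM_FUNCS)
    else:
        FP32_FUNCS.extend(BMM_FUNCS)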