Commit d1dd9466 authored by rusty1s's avatar rusty1s
Browse files

support more types

parent 0807f87f
......@@ -5,12 +5,11 @@ import torch
from torch.autograd import gradcheck
from torch_scatter import segment_coo, segment_csr
from .utils import tensor
from .utils import tensor, dtypes
reductions = ['add', 'mean', 'min', 'max']
grad_reductions = ['add', 'mean']
dtypes = [torch.float]
devices = [torch.device('cuda')]
tests = [
......
import torch
from torch.testing import get_all_dtypes
dtypes = get_all_dtypes()
dtypes.remove(torch.half)
dtypes.remove(torch.short) # PyTorch scatter does not work on short types.
dtypes.remove(torch.bool)
if hasattr(torch, 'bfloat16'):
dtypes.remove(torch.bfloat16)
dtypes = [torch.float, torch.double, torch.int, torch.long]
grad_dtypes = [torch.float, torch.double]
devices = [torch.device('cpu')]
......
......@@ -4,12 +4,12 @@ import torch
def min_value(dtype):
try:
return torch.finfo(dtype).min
except AttributeError:
return torch.info(dtype).min
except TypeError:
return torch.iinfo(dtype).min
def max_value(dtype):
try:
return torch.finfo(dtype).max
except AttributeError:
return torch.info(dtype).max
except TypeError:
return torch.iinfo(dtype).max
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment