Unverified Commit 8e6635b3 authored by Matthias Fey's avatar Matthias Fey Committed by GitHub
Browse files

fix test (#340)

parent 003abd58
......@@ -3,8 +3,7 @@ from itertools import product
import pytest
import torch
from torch_scatter import scatter
from .utils import reductions, devices
from torch_scatter.testing import devices, reductions
@pytest.mark.parametrize('reduce,device', product(reductions, devices))
......
......@@ -3,9 +3,8 @@ from itertools import product
import pytest
import torch
from torch.autograd import gradcheck
from torch_scatter import gather_csr, gather_coo
from .utils import tensor, dtypes, devices
from torch_scatter import gather_coo, gather_csr
from torch_scatter.testing import devices, dtypes, tensor
tests = [
{
......
......@@ -3,8 +3,7 @@ from itertools import product
import pytest
import torch
import torch_scatter
from .utils import reductions, tensor, dtypes
from torch_scatter.testing import dtypes, reductions, tensor
tests = [
{
......
......@@ -4,8 +4,7 @@ import pytest
import torch
import torch_scatter
from torch.autograd import gradcheck
from .utils import devices, dtypes, reductions, tensor
from torch_scatter.testing import devices, dtypes, reductions, tensor
reductions = reductions + ['mul']
......
......@@ -2,10 +2,9 @@ from itertools import product
import pytest
import torch
from torch.autograd import gradcheck
import torch_scatter
from .utils import reductions, tensor, dtypes, devices
from torch.autograd import gradcheck
from torch_scatter.testing import devices, dtypes, reductions, tensor
tests = [
{
......
......@@ -2,10 +2,9 @@ from itertools import product
import pytest
import torch
from torch_scatter import scatter, segment_coo, gather_coo
from torch_scatter import segment_csr, gather_csr
from .utils import reductions, tensor, grad_dtypes, devices
from torch_scatter import (gather_coo, gather_csr, scatter, segment_coo,
segment_csr)
from torch_scatter.testing import devices, grad_dtypes, reductions, tensor
@pytest.mark.parametrize('reduce,dtype,device',
......
from typing import Any
import torch
reductions = ['sum', 'add', 'mean', 'min', 'max']
dtypes = [torch.half, torch.bfloat16, torch.float, torch.double,
torch.int, torch.long]
dtypes = [
torch.half, torch.bfloat16, torch.float, torch.double, torch.int,
torch.long
]
grad_dtypes = [torch.float, torch.double]
devices = [torch.device('cpu')]
if torch.cuda.is_available():
devices += [torch.device(f'cuda:{torch.cuda.current_device()}')]
devices += [torch.device('cuda:0')]
def tensor(x, dtype, device):
def tensor(x: Any, dtype: torch.dtype, device: torch.device):
return None if x is None else torch.tensor(x, device=device).to(dtype)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment