Unverified Commit 8e6635b3 authored by Matthias Fey's avatar Matthias Fey Committed by GitHub
Browse files

fix test (#340)

parent 003abd58
...@@ -3,8 +3,7 @@ from itertools import product ...@@ -3,8 +3,7 @@ from itertools import product
import pytest import pytest
import torch import torch
from torch_scatter import scatter from torch_scatter import scatter
from torch_scatter.testing import devices, reductions
from .utils import reductions, devices
@pytest.mark.parametrize('reduce,device', product(reductions, devices)) @pytest.mark.parametrize('reduce,device', product(reductions, devices))
......
...@@ -3,9 +3,8 @@ from itertools import product ...@@ -3,9 +3,8 @@ from itertools import product
import pytest import pytest
import torch import torch
from torch.autograd import gradcheck from torch.autograd import gradcheck
from torch_scatter import gather_csr, gather_coo from torch_scatter import gather_coo, gather_csr
from torch_scatter.testing import devices, dtypes, tensor
from .utils import tensor, dtypes, devices
tests = [ tests = [
{ {
......
...@@ -3,8 +3,7 @@ from itertools import product ...@@ -3,8 +3,7 @@ from itertools import product
import pytest import pytest
import torch import torch
import torch_scatter import torch_scatter
from torch_scatter.testing import dtypes, reductions, tensor
from .utils import reductions, tensor, dtypes
tests = [ tests = [
{ {
......
...@@ -4,8 +4,7 @@ import pytest ...@@ -4,8 +4,7 @@ import pytest
import torch import torch
import torch_scatter import torch_scatter
from torch.autograd import gradcheck from torch.autograd import gradcheck
from torch_scatter.testing import devices, dtypes, reductions, tensor
from .utils import devices, dtypes, reductions, tensor
reductions = reductions + ['mul'] reductions = reductions + ['mul']
......
...@@ -2,10 +2,9 @@ from itertools import product ...@@ -2,10 +2,9 @@ from itertools import product
import pytest import pytest
import torch import torch
from torch.autograd import gradcheck
import torch_scatter import torch_scatter
from torch.autograd import gradcheck
from .utils import reductions, tensor, dtypes, devices from torch_scatter.testing import devices, dtypes, reductions, tensor
tests = [ tests = [
{ {
......
...@@ -2,10 +2,9 @@ from itertools import product ...@@ -2,10 +2,9 @@ from itertools import product
import pytest import pytest
import torch import torch
from torch_scatter import scatter, segment_coo, gather_coo from torch_scatter import (gather_coo, gather_csr, scatter, segment_coo,
from torch_scatter import segment_csr, gather_csr segment_csr)
from torch_scatter.testing import devices, grad_dtypes, reductions, tensor
from .utils import reductions, tensor, grad_dtypes, devices
@pytest.mark.parametrize('reduce,dtype,device', @pytest.mark.parametrize('reduce,dtype,device',
......
from typing import Any
import torch import torch
reductions = ['sum', 'add', 'mean', 'min', 'max'] reductions = ['sum', 'add', 'mean', 'min', 'max']
dtypes = [torch.half, torch.bfloat16, torch.float, torch.double, dtypes = [
torch.int, torch.long] torch.half, torch.bfloat16, torch.float, torch.double, torch.int,
torch.long
]
grad_dtypes = [torch.float, torch.double] grad_dtypes = [torch.float, torch.double]
devices = [torch.device('cpu')] devices = [torch.device('cpu')]
if torch.cuda.is_available(): if torch.cuda.is_available():
devices += [torch.device(f'cuda:{torch.cuda.current_device()}')] devices += [torch.device('cuda:0')]
def tensor(x, dtype, device): def tensor(x: Any, dtype: torch.dtype, device: torch.device):
return None if x is None else torch.tensor(x, device=device).to(dtype) return None if x is None else torch.tensor(x, device=device).to(dtype)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment