Unverified Commit c01f9bae authored by Matthias Fey's avatar Matthias Fey Committed by GitHub
Browse files

Merge pull request #105 from rusty1s/traceable

[WIP] traceable functions
parents 2520670a 02a47c46
...@@ -11,6 +11,7 @@ extensions = [ ...@@ -11,6 +11,7 @@ extensions = [
'sphinx.ext.napoleon', 'sphinx.ext.napoleon',
'sphinx.ext.viewcode', 'sphinx.ext.viewcode',
'sphinx.ext.githubpages', 'sphinx.ext.githubpages',
'sphinx_autodoc_typehints',
] ]
source_suffix = '.rst' source_suffix = '.rst'
......
Scatter Add Scatter
=========== =======
.. automodule:: torch_scatter .. automodule:: torch_scatter
:noindex: :noindex:
.. autofunction:: scatter_add .. autofunction:: scatter
Scatter Div
===========
.. automodule:: torch_scatter
:noindex:
.. autofunction:: scatter_div
Scatter LogSumExp
=================
.. automodule:: torch_scatter
:noindex:
.. autofunction:: scatter_logsumexp
Scatter Max
===========
.. automodule:: torch_scatter
:noindex:
.. autofunction:: scatter_max
Scatter Mean
============
.. automodule:: torch_scatter
:noindex:
.. autofunction:: scatter_mean
Scatter Min
===========
.. automodule:: torch_scatter
.. autofunction:: scatter_min
Scatter Mul
===========
.. automodule:: torch_scatter
:noindex:
.. autofunction:: scatter_mul
Scatter Std
===========
.. automodule:: torch_scatter
:noindex:
.. autofunction:: scatter_std
Scatter Sub
===========
.. automodule:: torch_scatter
:noindex:
.. autofunction:: scatter_sub
...@@ -7,7 +7,7 @@ This package consists of a small extension library of highly optimized sparse up ...@@ -7,7 +7,7 @@ This package consists of a small extension library of highly optimized sparse up
Scatter and segment operations can be roughly described as reduce operations based on a given "group-index" tensor. Scatter and segment operations can be roughly described as reduce operations based on a given "group-index" tensor.
Segment operations require the "group-index" tensor to be sorted, whereas scatter operations are not subject to these requirements. Segment operations require the "group-index" tensor to be sorted, whereas scatter operations are not subject to these requirements.
All included operations are broadcastable, work on varying data types, and are implemented both for CPU and GPU with corresponding backward implementations. All included operations are broadcastable, work on varying data types, are implemented both for CPU and GPU with corresponding backward implementations, and are fully traceable.
.. toctree:: .. toctree::
:glob: :glob:
...@@ -15,7 +15,6 @@ All included operations are broadcastable, work on varying data types, and are i ...@@ -15,7 +15,6 @@ All included operations are broadcastable, work on varying data types, and are i
:caption: Package reference :caption: Package reference
functions/* functions/*
composite/*
Indices and tables Indices and tables
================== ==================
......
import platform import os
import os.path as osp import os.path as osp
from glob import glob import sys
import glob
from setuptools import setup, find_packages from setuptools import setup, find_packages
from sys import argv
import torch import torch
from torch.utils.cpp_extension import BuildExtension from torch.utils.cpp_extension import BuildExtension
from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME
# Windows users: Edit both of these to contain your VS include path, i.e.: WITH_CUDA = torch.cuda.is_available() and CUDA_HOME is not None
# cxx_extra_compile_args = ['-I{VISUAL_STUDIO_DIR}\\include'] if os.getenv('FORCE_CUDA', '0') == '1':
# nvcc_extra_compile_args = [..., '-I{VISUAL_STUDIO_DIR}\\include'] WITH_CUDA = True
cxx_extra_compile_args = [] if os.getenv('FORCE_NON_CUDA', '0') == '1':
nvcc_extra_compile_args = ['-arch=sm_35', '--expt-relaxed-constexpr'] WITH_CUDA = False
# Windows users: Edit both of these to contain your VS library path, i.e.:
# cxx_extra_link_args = ['/LIBPATH:{VISUAL_STUDIO_DIR}\\lib\\{x86|x64}']
# nvcc_extra_link_args = ['/LIBPATH:{VISUAL_STUDIO_DIR}\\lib\\{x86|x64}']
cxx_extra_link_args = []
nvcc_extra_link_args = []
if platform.system() != 'Windows': def get_extensions():
cxx_extra_compile_args += ['-Wno-unused-variable'] Extension = CppExtension
TORCH_MAJOR = int(torch.__version__.split('.')[0]) define_macros = []
TORCH_MINOR = int(torch.__version__.split('.')[1]) extra_compile_args = {'cxx': [], 'nvcc': []}
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
cxx_extra_compile_args += ['-DVERSION_GE_1_3']
nvcc_extra_compile_args += ['-DVERSION_GE_1_3']
cmdclass = {
'build_ext': BuildExtension.with_options(no_python_abi_suffix=True)
}
ext_modules = [] # Windows users: Edit both of these to contain your VS include path, i.e.:
exts = [e.split(osp.sep)[-1][:-4] for e in glob(osp.join('cpu', '*.cpp'))] # extra_compile_args['cxx'] += ['-I{VISUAL_STUDIO_DIR}\\include']
ext_modules += [ # extra_compile_args['nvcc'] += ['-I{VISUAL_STUDIO_DIR}\\include']
CppExtension(f'torch_scatter.{ext}_cpu', [f'cpu/{ext}.cpp'],
extra_compile_args=cxx_extra_compile_args,
extra_link_args=cxx_extra_link_args) for ext in exts
]
if CUDA_HOME is not None and '--cpu' not in argv: if WITH_CUDA:
exts = [e.split(osp.sep)[-1][:-4] for e in glob(osp.join('cuda', '*.cpp'))] Extension = CUDAExtension
ext_modules += [ define_macros += [('WITH_CUDA', None)]
CUDAExtension( nvcc_flags = os.getenv('NVCC_FLAGS', '')
f'torch_scatter.{ext}_cuda', nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ')
[f'cuda/{ext}.cpp', f'cuda/{ext}_kernel.cu'], extra_compile_args={ nvcc_flags += ['-arch=sm_35', '--expt-relaxed-constexpr']
'cxx': cxx_extra_compile_args, extra_compile_args['cxx'] += ['-O0']
'nvcc': nvcc_extra_compile_args, extra_compile_args['nvcc'] += nvcc_flags
}, extra_link_args=nvcc_extra_link_args) for ext in exts
] if sys.platform == 'win32':
if '--cpu' in argv: extra_compile_args['cxx'] += ['/MP']
argv.remove('--cpu')
extensions_dir = osp.join(osp.dirname(osp.abspath(__file__)), 'csrc')
main_files = glob.glob(osp.join(extensions_dir, '*.cpp'))
extensions = []
for main in main_files:
name = main.split(os.sep)[-1][:-4]
sources = [main, osp.join(extensions_dir, 'cpu', f'{name}_cpu.cpp')]
if WITH_CUDA:
sources += [osp.join(extensions_dir, 'cuda', f'{name}_cuda.cu')]
extension = Extension(
f'torch_scatter._{name}',
sources,
include_dirs=[extensions_dir],
define_macros=define_macros,
extra_compile_args=extra_compile_args,
)
extensions += [extension]
return extensions
__version__ = '1.5.0'
url = 'https://github.com/rusty1s/pytorch_scatter'
install_requires = [] install_requires = []
setup_requires = ['pytest-runner'] setup_requires = ['pytest-runner']
...@@ -61,17 +64,19 @@ tests_require = ['pytest', 'pytest-cov'] ...@@ -61,17 +64,19 @@ tests_require = ['pytest', 'pytest-cov']
setup( setup(
name='torch_scatter', name='torch_scatter',
version=__version__, version='2.0.0',
description='PyTorch Extension Library of Optimized Scatter Operations',
author='Matthias Fey', author='Matthias Fey',
author_email='matthias.fey@tu-dortmund.de', author_email='matthias.fey@tu-dortmund.de',
url=url, url='https://github.com/rusty1s/pytorch_scatter',
download_url='{}/archive/{}.tar.gz'.format(url, __version__), description='PyTorch Extension Library of Optimized Scatter Operations',
keywords=['pytorch', 'scatter', 'segment'], keywords=['pytorch', 'scatter', 'segment', 'gather'],
license='MIT',
install_requires=install_requires, install_requires=install_requires,
setup_requires=setup_requires, setup_requires=setup_requires,
tests_require=tests_require, tests_require=tests_require,
ext_modules=ext_modules, ext_modules=get_extensions(),
cmdclass=cmdclass, cmdclass={
'build_ext': BuildExtension.with_options(no_python_abi_suffix=True)
},
packages=find_packages(), packages=find_packages(),
) )
import torch
from torch_scatter import scatter_logsumexp
def test_logsumexp():
    """scatter_logsumexp matches per-group torch.logsumexp references."""
    # Entries fall into groups 0, 1, 2 and 4; group 3 stays empty.
    inputs = torch.tensor([0.5, 0, 0.5, -2.1, 3.2, 7, -1, -100])
    groups = torch.tensor([0, 1, 0, 1, 1, 2, 4, 4])

    result = scatter_logsumexp(inputs, groups)

    # Per-group reference values; group 4's second entry (-100) is
    # negligible next to -1 under allclose, and the empty group 3 is
    # compared against logsumexp over an empty tensor.
    refs = [
        torch.logsumexp(torch.tensor([0.5, 0.5]), dim=-1),
        torch.logsumexp(torch.tensor([0, -2.1, 3.2]), dim=-1),
        torch.logsumexp(torch.tensor(7, dtype=torch.float), dim=-1),
        torch.logsumexp(torch.tensor([], dtype=torch.float), dim=-1),
        torch.tensor(-1, dtype=torch.float),
    ]
    assert torch.allclose(result, torch.stack(refs, dim=0))
from itertools import product
import pytest
import torch import torch
from torch_scatter.composite import scatter_log_softmax, scatter_softmax from torch_scatter import scatter_log_softmax, scatter_softmax
from test.utils import devices, tensor, grad_dtypes
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices)) def test_softmax():
def test_softmax(dtype, device): src = torch.tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')])
src = tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')], dtype, device) index = torch.tensor([0, 1, 0, 1, 1, 2, 4, 4])
index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)
out = scatter_softmax(src, index) out = scatter_softmax(src, index)
out0 = torch.softmax(torch.tensor([0.2, 0.2], dtype=dtype), dim=-1) out0 = torch.softmax(torch.tensor([0.2, 0.2]), dim=-1)
out1 = torch.softmax(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1) out1 = torch.softmax(torch.tensor([0, -2.1, 3.2]), dim=-1)
out2 = torch.softmax(torch.tensor([7], dtype=dtype), dim=-1) out2 = torch.softmax(torch.tensor([7], dtype=torch.float), dim=-1)
out4 = torch.softmax(torch.tensor([-1, float('-inf')], dtype=dtype), out4 = torch.softmax(torch.tensor([-1, float('-inf')]), dim=-1)
dim=-1)
expected = torch.stack([ expected = torch.stack([
out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1] out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1]
], dim=0).to(device) ], dim=0)
assert torch.allclose(out, expected) assert torch.allclose(out, expected)
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices)) def test_log_softmax():
def test_softmax_broadcasting(dtype, device): src = torch.tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')])
src = torch.randn(10, 5, dtype=dtype, device=device) index = torch.tensor([0, 1, 0, 1, 1, 2, 4, 4])
index = tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4], torch.long, device)
out = scatter_softmax(src, index, dim=0).view(5, 2, 5)
out = out.sum(dim=1)
assert torch.allclose(out, torch.ones_like(out))
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_log_softmax(dtype, device):
src = tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)
out = scatter_log_softmax(src, index) out = scatter_log_softmax(src, index)
out0 = torch.log_softmax(torch.tensor([0.2, 0.2], dtype=dtype), dim=-1) out0 = torch.log_softmax(torch.tensor([0.2, 0.2]), dim=-1)
out1 = torch.log_softmax(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1) out1 = torch.log_softmax(torch.tensor([0, -2.1, 3.2]), dim=-1)
out2 = torch.log_softmax(torch.tensor([7], dtype=dtype), dim=-1) out2 = torch.log_softmax(torch.tensor([7], dtype=torch.float), dim=-1)
out4 = torch.log_softmax(torch.tensor([-1, float('-inf')], dtype=dtype), out4 = torch.log_softmax(torch.tensor([-1, float('-inf')]), dim=-1)
dim=-1)
expected = torch.stack([ expected = torch.stack([
out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1] out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1]
], dim=0).to(device) ], dim=0)
assert torch.allclose(out, expected) assert torch.allclose(out, expected)
import torch
from torch_scatter import scatter_std
def test_std():
    """scatter_std over whole rows agrees with Tensor.std on those rows."""
    values = torch.tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
                          dtype=torch.float)
    # Row 0 scatters entirely into slot 0, row 1 entirely into slot 1.
    groups = torch.tensor([[0, 0, 0, 0, 0], [1, 1, 1, 1, 1]],
                          dtype=torch.long)

    result = scatter_std(values, groups, dim=-1, unbiased=True)

    # Both rows are permutations of 0..4, so a single row std suffices.
    row_std = values.std(dim=-1, unbiased=True)[0]
    reference = torch.tensor([[row_std, 0], [0, row_std]])
    assert torch.allclose(result, reference)
from itertools import product
import pytest
import torch
from torch.autograd import gradcheck
import torch_scatter
from .utils import grad_dtypes as dtypes, devices, tensor
# Reductions whose dense backward pass is exercised by ``test_backward``
# (max/min have a separate arg-based backward test below).
funcs = ['add', 'sub', 'mul', 'div', 'mean']
# Shared group assignment for the backward tests: five sources, three groups.
indices = [2, 0, 1, 1, 0]
@pytest.mark.parametrize('func,device', product(funcs, devices))
def test_backward(func, device):
    """Gradcheck each basic scatter reduction in double precision."""
    idx = torch.tensor(indices, dtype=torch.long, device=device)
    inp = torch.rand((idx.size(0), 2), dtype=torch.double, device=device)
    inp.requires_grad_()
    scatter_op = getattr(torch_scatter, 'scatter_{}'.format(func))
    assert gradcheck(scatter_op, (inp, idx, 0), eps=1e-6, atol=1e-4) is True
# Fixtures for the backward pass of the arg-returning reductions
# (scatter_max / scatter_min): 'grad' is the upstream gradient on the
# reduced output and 'expected' is the gradient routed back onto 'src' —
# per the fixtures, only arg-selected source positions receive gradient,
# every other position stays zero.
tests = [{
    'name': 'max',
    'src': [[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]],
    'index': [2, 0, 1, 1, 0],
    'dim': 0,
    'fill_value': 0,
    'grad': [[4, 4], [8, 8], [6, 6]],
    'expected': [[6, 6], [0, 0], [0, 0], [8, 8], [4, 4]],
}, {
    'name': 'min',
    'src': [[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]],
    'index': [2, 0, 1, 1, 0],
    'dim': 0,
    'fill_value': 3,
    'grad': [[4, 4], [8, 8], [6, 6]],
    'expected': [[6, 6], [4, 4], [8, 8], [0, 0], [0, 0]],
}]
@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
def test_arg_backward(test, dtype, device):
    """Gradients flow only through the arg-selected source entries."""
    inp = tensor(test['src'], dtype, device)
    inp.requires_grad_()
    idx = tensor(test['index'], torch.long, device)
    upstream = tensor(test['grad'], dtype, device)

    scatter_op = getattr(torch_scatter, 'scatter_{}'.format(test['name']))
    result, _ = scatter_op(inp, idx, test['dim'],
                           fill_value=test['fill_value'])
    result.backward(upstream)

    assert inp.grad.tolist() == test['expected']
...@@ -14,16 +14,6 @@ def test_broadcasting(device): ...@@ -14,16 +14,6 @@ def test_broadcasting(device):
out = scatter_add(src, index, dim=2, dim_size=H) out = scatter_add(src, index, dim=2, dim_size=H)
assert out.size() == (B, C, H, W) assert out.size() == (B, C, H, W)
src = torch.randn((B, 1, H, W), device=device)
index = torch.randint(0, H, (B, C, H, W)).to(device, torch.long)
out = scatter_add(src, index, dim=2, dim_size=H)
assert out.size() == (B, C, H, W)
src = torch.randn((B, 1, H, W), device=device)
index = torch.randint(0, H, (B, 1, H, W)).to(device, torch.long)
out = scatter_add(src, index, dim=2, dim_size=H)
assert out.size() == (B, 1, H, W)
src = torch.randn((B, C, H, W), device=device) src = torch.randn((B, C, H, W), device=device)
index = torch.randint(0, H, (H, )).to(device, torch.long) index = torch.randint(0, H, (H, )).to(device, torch.long)
out = scatter_add(src, index, dim=2, dim_size=H) out = scatter_add(src, index, dim=2, dim_size=H)
......
from itertools import product
import pytest
import torch
import torch_scatter
from .utils import dtypes, devices, tensor
# Forward-pass fixtures for ``test_forward``: each entry names a scatter
# reduction and supplies a 1-D or 2-D 'index', the scatter 'dim', the
# 'fill_value' used for output slots no source maps to, and the 'expected'
# result. The max/min entries additionally carry 'expected_arg' for the
# returned argmax/argmin indices (-1 marks slots with no contributing
# source element).
tests = [{
    'name': 'add',
    'src': [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
    'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
    'dim': -1,
    'fill_value': 0,
    'expected': [[0, 0, 4, 3, 3, 0], [2, 4, 4, 0, 0, 0]],
}, {
    'name': 'add',
    'src': [[5, 2], [2, 5], [4, 3], [1, 3]],
    'index': [0, 1, 1, 0],
    'dim': 0,
    'fill_value': 0,
    'expected': [[6, 5], [6, 8]],
}, {
    'name': 'sub',
    'src': [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
    'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
    'dim': -1,
    'fill_value': 9,
    'expected': [[9, 9, 5, 6, 6, 9], [7, 5, 5, 9, 9, 9]],
}, {
    'name': 'sub',
    'src': [[5, 2], [2, 2], [4, 2], [1, 3]],
    'index': [0, 1, 1, 0],
    'dim': 0,
    'fill_value': 9,
    'expected': [[3, 4], [3, 5]],
}, {
    'name': 'mul',
    'src': [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
    'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
    'dim': -1,
    'fill_value': 1,
    'expected': [[1, 1, 4, 3, 2, 0], [0, 4, 3, 1, 1, 1]],
}, {
    'name': 'mul',
    'src': [[5, 2], [2, 5], [4, 3], [1, 3]],
    'index': [0, 1, 1, 0],
    'dim': 0,
    'fill_value': 1,
    'expected': [[5, 6], [8, 15]],
}, {
    'name': 'div',
    'src': [[2, 1, 1, 4, 2], [1, 2, 1, 2, 4]],
    'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
    'dim': -1,
    'fill_value': 1,
    'expected': [[1, 1, 0.25, 0.5, 0.5, 1], [0.5, 0.25, 0.5, 1, 1, 1]],
}, {
    'name': 'div',
    'src': [[4, 2], [2, 1], [4, 2], [1, 2]],
    'index': [0, 1, 1, 0],
    'dim': 0,
    'fill_value': 1,
    'expected': [[0.25, 0.25], [0.125, 0.5]],
}, {
    'name': 'mean',
    'src': [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
    'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
    'dim': -1,
    'fill_value': 0,
    'expected': [[0, 0, 4, 3, 1.5, 0], [1, 4, 2, 0, 0, 0]],
}, {
    'name': 'mean',
    'src': [[5, 2], [2, 5], [4, 3], [1, 3]],
    'index': [0, 1, 1, 0],
    'dim': 0,
    'fill_value': 0,
    'expected': [[3, 2.5], [3, 4]],
}, {
    'name': 'max',
    'src': [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
    'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
    'dim': -1,
    'fill_value': 0,
    'expected': [[0, 0, 4, 3, 2, 0], [2, 4, 3, 0, 0, 0]],
    'expected_arg': [[-1, -1, 3, 4, 0, 1], [1, 4, 3, -1, -1, -1]],
}, {
    'name': 'max',
    'src': [[5, 2], [2, 5], [4, 3], [1, 3]],
    'index': [0, 1, 1, 0],
    'dim': 0,
    'fill_value': 0,
    'expected': [[5, 3], [4, 5]],
    'expected_arg': [[0, 3], [2, 1]],
}, {
    'name': 'min',
    'src': [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
    'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
    'dim': -1,
    'fill_value': 9,
    'expected': [[9, 9, 4, 3, 1, 0], [0, 4, 1, 9, 9, 9]],
    'expected_arg': [[-1, -1, 3, 4, 2, 1], [0, 4, 2, -1, -1, -1]],
}, {
    'name': 'min',
    'src': [[5, 2], [2, 5], [4, 3], [1, 3]],
    'index': [0, 1, 1, 0],
    'dim': 0,
    'fill_value': 9,
    'expected': [[1, 2], [2, 3]],
    'expected_arg': [[3, 0], [1, 2]],
}]
@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
def test_forward(test, dtype, device):
    """Run one scatter op on a fixture and compare against its expectation."""
    inp = tensor(test['src'], dtype, device)
    idx = tensor(test['index'], torch.long, device)
    want = tensor(test['expected'], dtype, device)

    scatter_op = getattr(torch_scatter, 'scatter_{}'.format(test['name']))
    result = scatter_op(inp, idx, test['dim'], fill_value=test['fill_value'])

    if not isinstance(result, tuple):
        assert result.tolist() == want.tolist()
        return
    # Fixtures with 'expected_arg' (max/min) yield a (values, args) pair.
    values, args = result
    assert values.tolist() == want.tolist()
    assert args.tolist() == test['expected_arg']
...@@ -3,7 +3,7 @@ from itertools import product ...@@ -3,7 +3,7 @@ from itertools import product
import pytest import pytest
import torch import torch
from torch.autograd import gradcheck from torch.autograd import gradcheck
from torch_scatter import gather_coo, gather_csr from torch_scatter import gather_csr, gather_coo
from .utils import tensor, dtypes, devices from .utils import tensor, dtypes, devices
...@@ -54,10 +54,10 @@ def test_forward(test, dtype, device): ...@@ -54,10 +54,10 @@ def test_forward(test, dtype, device):
indptr = tensor(test['indptr'], torch.long, device) indptr = tensor(test['indptr'], torch.long, device)
expected = tensor(test['expected'], dtype, device) expected = tensor(test['expected'], dtype, device)
out = gather_coo(src, index) out = gather_csr(src, indptr)
assert torch.all(out == expected) assert torch.all(out == expected)
out = gather_csr(src, indptr) out = gather_coo(src, index)
assert torch.all(out == expected) assert torch.all(out == expected)
...@@ -68,12 +68,12 @@ def test_backward(test, device): ...@@ -68,12 +68,12 @@ def test_backward(test, device):
index = tensor(test['index'], torch.long, device) index = tensor(test['index'], torch.long, device)
indptr = tensor(test['indptr'], torch.long, device) indptr = tensor(test['indptr'], torch.long, device)
assert gradcheck(gather_coo, (src, index, None)) is True
assert gradcheck(gather_csr, (src, indptr, None)) is True assert gradcheck(gather_csr, (src, indptr, None)) is True
assert gradcheck(gather_coo, (src, index, None)) is True
@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices)) @pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
def test_gather_out(test, dtype, device): def test_out(test, dtype, device):
src = tensor(test['src'], dtype, device) src = tensor(test['src'], dtype, device)
index = tensor(test['index'], torch.long, device) index = tensor(test['index'], torch.long, device)
indptr = tensor(test['indptr'], torch.long, device) indptr = tensor(test['indptr'], torch.long, device)
...@@ -83,17 +83,17 @@ def test_gather_out(test, dtype, device): ...@@ -83,17 +83,17 @@ def test_gather_out(test, dtype, device):
size[index.dim() - 1] = index.size(-1) size[index.dim() - 1] = index.size(-1)
out = src.new_full(size, -2) out = src.new_full(size, -2)
gather_coo(src, index, out) gather_csr(src, indptr, out)
assert torch.all(out == expected) assert torch.all(out == expected)
out.fill_(-2) out.fill_(-2)
gather_csr(src, indptr, out) gather_coo(src, index, out)
assert torch.all(out == expected) assert torch.all(out == expected)
@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices)) @pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
def test_non_contiguous_segment(test, dtype, device): def test_non_contiguous(test, dtype, device):
src = tensor(test['src'], dtype, device) src = tensor(test['src'], dtype, device)
index = tensor(test['index'], torch.long, device) index = tensor(test['index'], torch.long, device)
indptr = tensor(test['indptr'], torch.long, device) indptr = tensor(test['indptr'], torch.long, device)
...@@ -106,8 +106,8 @@ def test_non_contiguous_segment(test, dtype, device): ...@@ -106,8 +106,8 @@ def test_non_contiguous_segment(test, dtype, device):
if indptr.dim() > 1: if indptr.dim() > 1:
indptr = indptr.transpose(0, 1).contiguous().transpose(0, 1) indptr = indptr.transpose(0, 1).contiguous().transpose(0, 1)
out = gather_coo(src, index) out = gather_csr(src, indptr)
assert torch.all(out == expected) assert torch.all(out == expected)
out = gather_csr(src, indptr) out = gather_coo(src, index)
assert torch.all(out == expected) assert torch.all(out == expected)
from itertools import product
import torch
import pytest
from torch_scatter import scatter_logsumexp
from .utils import devices, tensor, grad_dtypes
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_logsumexp(dtype, device):
    """scatter_logsumexp matches per-group torch.logsumexp references."""
    # Entries fall into groups 0, 1, 2 and 4; group 3 stays empty.
    inputs = tensor([0.5, 0, 0.5, -2.1, 3.2, 7, -1, float('-inf')], dtype,
                    device)
    groups = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)

    result = scatter_logsumexp(inputs, groups)

    refs = [
        torch.logsumexp(torch.tensor([0.5, 0.5], dtype=dtype), dim=-1),
        torch.logsumexp(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1),
        torch.logsumexp(torch.tensor(7, dtype=dtype), dim=-1),
        # Empty group 3: the fixture expects the dtype's minimum value.
        torch.tensor(torch.finfo(dtype).min, dtype=dtype),
        # Group 4: -inf contributes nothing, leaving logsumexp at -1.
        torch.tensor(-1, dtype=dtype),
    ]
    assert torch.allclose(result, torch.stack(refs, dim=0).to(device))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment