Commit 94667636 authored by rusty1s's avatar rusty1s
Browse files

added 3.5 dependency

parent b907ef2e
...@@ -35,7 +35,7 @@ def get_extensions(): ...@@ -35,7 +35,7 @@ def get_extensions():
if WITH_CUDA: if WITH_CUDA:
Extension = CUDAExtension Extension = CUDAExtension
define_macros += [('WITH_CUDA', None)] define_macros += [('WITH_CUDA', None)]
extra_compile_args['cxx'] += ['-O0'] # extra_compile_args['cxx'] += ['-O0']
nvcc_flags = os.getenv('NVCC_FLAGS', '') nvcc_flags = os.getenv('NVCC_FLAGS', '')
nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ') nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ')
nvcc_flags += ['-arch=sm_35', '--expt-relaxed-constexpr'] nvcc_flags += ['-arch=sm_35', '--expt-relaxed-constexpr']
...@@ -52,12 +52,12 @@ def get_extensions(): ...@@ -52,12 +52,12 @@ def get_extensions():
for main in main_files: for main in main_files:
name = main.split(os.sep)[-1][:-4] name = main.split(os.sep)[-1][:-4]
sources = [main, osp.join(extensions_dir, 'cpu', f'{name}_cpu.cpp')] sources = [main, osp.join(extensions_dir, 'cpu', name + '_cpu.cpp')]
if WITH_CUDA: if WITH_CUDA:
sources += [osp.join(extensions_dir, 'cuda', f'{name}_cuda.cu')] sources += [osp.join(extensions_dir, 'cuda', name + '_cuda.cu')]
extension = Extension( extension = Extension(
f'torch_scatter._{name}', 'torch_scatter._' + name,
sources, sources,
include_dirs=[extensions_dir], include_dirs=[extensions_dir],
define_macros=define_macros, define_macros=define_macros,
...@@ -82,7 +82,7 @@ setup( ...@@ -82,7 +82,7 @@ setup(
description='PyTorch Extension Library of Optimized Scatter Operations', description='PyTorch Extension Library of Optimized Scatter Operations',
keywords=['pytorch', 'scatter', 'segment', 'gather'], keywords=['pytorch', 'scatter', 'segment', 'gather'],
license='MIT', license='MIT',
python_requires='>=3.6', python_requires='>=3.5',
install_requires=install_requires, install_requires=install_requires,
setup_requires=setup_requires, setup_requires=setup_requires,
tests_require=tests_require, tests_require=tests_require,
......
...@@ -91,10 +91,10 @@ def test_forward(test, reduce, dtype, device): ...@@ -91,10 +91,10 @@ def test_forward(test, reduce, dtype, device):
dim = test['dim'] dim = test['dim']
expected = tensor(test[reduce], dtype, device) expected = tensor(test[reduce], dtype, device)
out = getattr(torch_scatter, f'scatter_{reduce}')(src, index, dim) out = getattr(torch_scatter, 'scatter_' + reduce)(src, index, dim)
if isinstance(out, tuple): if isinstance(out, tuple):
out, arg_out = out out, arg_out = out
arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device) arg_expected = tensor(test['arg_' + reduce], torch.long, device)
assert torch.all(arg_out == arg_expected) assert torch.all(arg_out == arg_expected)
assert torch.all(out == expected) assert torch.all(out == expected)
...@@ -121,7 +121,7 @@ def test_out(test, reduce, dtype, device): ...@@ -121,7 +121,7 @@ def test_out(test, reduce, dtype, device):
out = torch.full_like(expected, -2) out = torch.full_like(expected, -2)
getattr(torch_scatter, f'scatter_{reduce}')(src, index, dim, out) getattr(torch_scatter, 'scatter_' + reduce)(src, index, dim, out)
if reduce == 'sum' or reduce == 'add': if reduce == 'sum' or reduce == 'add':
expected = expected - 2 expected = expected - 2
...@@ -150,9 +150,9 @@ def test_non_contiguous(test, reduce, dtype, device): ...@@ -150,9 +150,9 @@ def test_non_contiguous(test, reduce, dtype, device):
if index.dim() > 1: if index.dim() > 1:
index = index.transpose(0, 1).contiguous().transpose(0, 1) index = index.transpose(0, 1).contiguous().transpose(0, 1)
out = getattr(torch_scatter, f'scatter_{reduce}')(src, index, dim) out = getattr(torch_scatter, 'scatter_' + reduce)(src, index, dim)
if isinstance(out, tuple): if isinstance(out, tuple):
out, arg_out = out out, arg_out = out
arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device) arg_expected = tensor(test['arg_' + reduce], torch.long, device)
assert torch.all(arg_out == arg_expected) assert torch.all(arg_out == arg_expected)
assert torch.all(out == expected) assert torch.all(out == expected)
...@@ -91,17 +91,17 @@ def test_forward(test, reduce, dtype, device): ...@@ -91,17 +91,17 @@ def test_forward(test, reduce, dtype, device):
indptr = tensor(test['indptr'], torch.long, device) indptr = tensor(test['indptr'], torch.long, device)
expected = tensor(test[reduce], dtype, device) expected = tensor(test[reduce], dtype, device)
out = getattr(torch_scatter, f'segment_{reduce}_csr')(src, indptr) out = getattr(torch_scatter, 'segment_' + reduce + '_csr')(src, indptr)
if isinstance(out, tuple): if isinstance(out, tuple):
out, arg_out = out out, arg_out = out
arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device) arg_expected = tensor(test['arg_' + reduce], torch.long, device)
assert torch.all(arg_out == arg_expected) assert torch.all(arg_out == arg_expected)
assert torch.all(out == expected) assert torch.all(out == expected)
out = getattr(torch_scatter, f'segment_{reduce}_coo')(src, index) out = getattr(torch_scatter, 'segment_' + reduce + '_coo')(src, index)
if isinstance(out, tuple): if isinstance(out, tuple):
out, arg_out = out out, arg_out = out
arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device) arg_expected = tensor(test['arg_' + reduce], torch.long, device)
assert torch.all(arg_out == arg_expected) assert torch.all(arg_out == arg_expected)
assert torch.all(out == expected) assert torch.all(out == expected)
...@@ -129,12 +129,12 @@ def test_out(test, reduce, dtype, device): ...@@ -129,12 +129,12 @@ def test_out(test, reduce, dtype, device):
out = torch.full_like(expected, -2) out = torch.full_like(expected, -2)
getattr(torch_scatter, f'segment_{reduce}_csr')(src, indptr, out) getattr(torch_scatter, 'segment_' + reduce + '_csr')(src, indptr, out)
assert torch.all(out == expected) assert torch.all(out == expected)
out.fill_(-2) out.fill_(-2)
getattr(torch_scatter, f'segment_{reduce}_coo')(src, index, out) getattr(torch_scatter, 'segment_' + reduce + '_coo')(src, index, out)
if reduce == 'sum' or reduce == 'add': if reduce == 'sum' or reduce == 'add':
expected = expected - 2 expected = expected - 2
...@@ -165,16 +165,16 @@ def test_non_contiguous(test, reduce, dtype, device): ...@@ -165,16 +165,16 @@ def test_non_contiguous(test, reduce, dtype, device):
if indptr.dim() > 1: if indptr.dim() > 1:
indptr = indptr.transpose(0, 1).contiguous().transpose(0, 1) indptr = indptr.transpose(0, 1).contiguous().transpose(0, 1)
out = getattr(torch_scatter, f'segment_{reduce}_csr')(src, indptr) out = getattr(torch_scatter, 'segment_' + reduce + '_csr')(src, indptr)
if isinstance(out, tuple): if isinstance(out, tuple):
out, arg_out = out out, arg_out = out
arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device) arg_expected = tensor(test['arg_' + reduce], torch.long, device)
assert torch.all(arg_out == arg_expected) assert torch.all(arg_out == arg_expected)
assert torch.all(out == expected) assert torch.all(out == expected)
out = getattr(torch_scatter, f'segment_{reduce}_coo')(src, index) out = getattr(torch_scatter, 'segment_' + reduce + '_coo')(src, index)
if isinstance(out, tuple): if isinstance(out, tuple):
out, arg_out = out out, arg_out = out
arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device) arg_expected = tensor(test['arg_' + reduce], torch.long, device)
assert torch.all(arg_out == arg_expected) assert torch.all(arg_out == arg_expected)
assert torch.all(out == expected) assert torch.all(out == expected)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment