OpenDAS / torch-scatter · Commits

Commit 94667636
authored Feb 11, 2020 by rusty1s

    added 3.5 dependency

parent b907ef2e
Showing 3 changed files with 20 additions and 20 deletions.
setup.py               +5  -5
test/test_scatter.py   +5  -5
test/test_segment.py   +10 -10
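Lowering python_requires from '>=3.6' to '>=3.5' means the package can no longer contain f-string literals, which only exist from Python 3.6 (PEP 498) onward, so every f-string touched by this commit is rewritten as plain string concatenation. A minimal sketch of the equivalence the rewrite relies on (the variable and file name below are illustrative, not taken verbatim from the repository):

    name = 'scatter'

    # Python 3.6+ only: f-string formatting
    path_fstring = f'{name}_cpu.cpp'

    # Python 3.5-compatible: concatenation builds the same string
    path_concat = name + '_cpu.cpp'

    assert path_fstring == path_concat  # both are 'scatter_cpu.cpp'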
setup.py

@@ -35,7 +35,7 @@ def get_extensions():
     if WITH_CUDA:
         Extension = CUDAExtension
         define_macros += [('WITH_CUDA', None)]
-        extra_compile_args['cxx'] += ['-O0']
+        # extra_compile_args['cxx'] += ['-O0']
         nvcc_flags = os.getenv('NVCC_FLAGS', '')
         nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ')
         nvcc_flags += ['-arch=sm_35', '--expt-relaxed-constexpr']
@@ -52,12 +52,12 @@ def get_extensions():
     for main in main_files:
         name = main.split(os.sep)[-1][:-4]

-        sources = [main, osp.join(extensions_dir, 'cpu', f'{name}_cpu.cpp')]
+        sources = [main, osp.join(extensions_dir, 'cpu', name + '_cpu.cpp')]
         if WITH_CUDA:
-            sources += [osp.join(extensions_dir, 'cuda', f'{name}_cuda.cu')]
+            sources += [osp.join(extensions_dir, 'cuda', name + '_cuda.cu')]

         extension = Extension(
-            f'torch_scatter._{name}',
+            'torch_scatter._' + name,
             sources,
             include_dirs=[extensions_dir],
             define_macros=define_macros,
@@ -82,7 +82,7 @@ setup(
     description='PyTorch Extension Library of Optimized Scatter Operations',
     keywords=['pytorch', 'scatter', 'segment', 'gather'],
     license='MIT',
-    python_requires='>=3.6',
+    python_requires='>=3.5',
     install_requires=install_requires,
     setup_requires=setup_requires,
     tests_require=tests_require,
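For reference, the second hunk's loop derives each extension module name and CPU source path from the corresponding main C++ file. A rough sketch of that derivation under the same logic, using hypothetical paths and a hypothetical extensions_dir (the real values are defined earlier in setup.py and are not part of this diff):

    import os
    import os.path as osp

    extensions_dir = 'csrc'  # hypothetical stand-in
    main_files = [osp.join(extensions_dir, 'scatter.cpp'),
                  osp.join(extensions_dir, 'segment_csr.cpp')]

    for main in main_files:
        name = main.split(os.sep)[-1][:-4]    # file name without the trailing '.cpp'
        ext_name = 'torch_scatter._' + name   # Python 3.5-safe form used in the diff
        cpu_source = osp.join(extensions_dir, 'cpu', name + '_cpu.cpp')
        print(ext_name, cpu_source)
    # torch_scatter._scatter csrc/cpu/scatter_cpu.cpp
    # torch_scatter._segment_csr csrc/cpu/segment_csr_cpu.cpp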
test/test_scatter.py

@@ -91,10 +91,10 @@ def test_forward(test, reduce, dtype, device):
     dim = test['dim']
     expected = tensor(test[reduce], dtype, device)

-    out = getattr(torch_scatter, f'scatter_{reduce}')(src, index, dim)
+    out = getattr(torch_scatter, 'scatter_' + reduce)(src, index, dim)
     if isinstance(out, tuple):
         out, arg_out = out
-        arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device)
+        arg_expected = tensor(test['arg_' + reduce], torch.long, device)
         assert torch.all(arg_out == arg_expected)

     assert torch.all(out == expected)
@@ -121,7 +121,7 @@ def test_out(test, reduce, dtype, device):
     out = torch.full_like(expected, -2)

-    getattr(torch_scatter, f'scatter_{reduce}')(src, index, dim, out)
+    getattr(torch_scatter, 'scatter_' + reduce)(src, index, dim, out)

     if reduce == 'sum' or reduce == 'add':
         expected = expected - 2
@@ -150,9 +150,9 @@ def test_non_contiguous(test, reduce, dtype, device):
     if index.dim() > 1:
         index = index.transpose(0, 1).contiguous().transpose(0, 1)

-    out = getattr(torch_scatter, f'scatter_{reduce}')(src, index, dim)
+    out = getattr(torch_scatter, 'scatter_' + reduce)(src, index, dim)
     if isinstance(out, tuple):
         out, arg_out = out
-        arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device)
+        arg_expected = tensor(test['arg_' + reduce], torch.long, device)
         assert torch.all(arg_out == arg_expected)
     assert torch.all(out == expected)
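The tests build the function name at runtime, so the only change here is how that name string is assembled. A small usage sketch of the same getattr dispatch, assuming torch_scatter is importable and exposes scatter_add (which the test's 'scatter_' + reduce scheme implies for reduce = 'add'):

    import torch
    import torch_scatter

    reduce = 'add'
    src = torch.tensor([1.0, 2.0, 3.0, 4.0])
    index = torch.tensor([0, 0, 1, 1])

    # 'scatter_' + reduce resolves to torch_scatter.scatter_add;
    # on Python 3.5 this replaces the former f'scatter_{reduce}'.
    out = getattr(torch_scatter, 'scatter_' + reduce)(src, index, 0)
    print(out)  # per-index sums along dim 0: tensor([3., 7.])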
test/test_segment.py

@@ -91,17 +91,17 @@ def test_forward(test, reduce, dtype, device):
     indptr = tensor(test['indptr'], torch.long, device)
     expected = tensor(test[reduce], dtype, device)

-    out = getattr(torch_scatter, f'segment_{reduce}_csr')(src, indptr)
+    out = getattr(torch_scatter, 'segment_' + reduce + '_csr')(src, indptr)
     if isinstance(out, tuple):
         out, arg_out = out
-        arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device)
+        arg_expected = tensor(test['arg_' + reduce], torch.long, device)
         assert torch.all(arg_out == arg_expected)
     assert torch.all(out == expected)

-    out = getattr(torch_scatter, f'segment_{reduce}_coo')(src, index)
+    out = getattr(torch_scatter, 'segment_' + reduce + '_coo')(src, index)
     if isinstance(out, tuple):
         out, arg_out = out
-        arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device)
+        arg_expected = tensor(test['arg_' + reduce], torch.long, device)
         assert torch.all(arg_out == arg_expected)
     assert torch.all(out == expected)
@@ -129,12 +129,12 @@ def test_out(test, reduce, dtype, device):
     out = torch.full_like(expected, -2)

-    getattr(torch_scatter, f'segment_{reduce}_csr')(src, indptr, out)
+    getattr(torch_scatter, 'segment_' + reduce + '_csr')(src, indptr, out)
     assert torch.all(out == expected)

     out.fill_(-2)
-    getattr(torch_scatter, f'segment_{reduce}_coo')(src, index, out)
+    getattr(torch_scatter, 'segment_' + reduce + '_coo')(src, index, out)

     if reduce == 'sum' or reduce == 'add':
         expected = expected - 2
@@ -165,16 +165,16 @@ def test_non_contiguous(test, reduce, dtype, device):
     if indptr.dim() > 1:
         indptr = indptr.transpose(0, 1).contiguous().transpose(0, 1)

-    out = getattr(torch_scatter, f'segment_{reduce}_csr')(src, indptr)
+    out = getattr(torch_scatter, 'segment_' + reduce + '_csr')(src, indptr)
     if isinstance(out, tuple):
         out, arg_out = out
-        arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device)
+        arg_expected = tensor(test['arg_' + reduce], torch.long, device)
         assert torch.all(arg_out == arg_expected)
     assert torch.all(out == expected)

-    out = getattr(torch_scatter, f'segment_{reduce}_coo')(src, index)
+    out = getattr(torch_scatter, 'segment_' + reduce + '_coo')(src, index)
     if isinstance(out, tuple):
         out, arg_out = out
-        arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device)
+        arg_expected = tensor(test['arg_' + reduce], torch.long, device)
         assert torch.all(arg_out == arg_expected)
     assert torch.all(out == expected)
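The same renaming applies twice per call site here because the segment tests exercise both indexing conventions. A rough sketch of the difference, assuming the segment_add_csr and segment_add_coo functions implied by the test's 'segment_' + reduce + '_csr' / '_coo' naming scheme exist for reduce = 'add':

    import torch
    import torch_scatter

    src = torch.tensor([1.0, 2.0, 3.0, 4.0])

    # COO-style: one segment id per element, as in segment_*_coo(src, index)
    index = torch.tensor([0, 0, 1, 1])
    out_coo = torch_scatter.segment_add_coo(src, index)

    # CSR-style: boundary pointers between segments, as in segment_*_csr(src, indptr)
    indptr = torch.tensor([0, 2, 4])
    out_csr = torch_scatter.segment_add_csr(src, indptr)

    print(out_coo, out_csr)  # both reduce to tensor([3., 7.])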