Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
fe79c969
Commit
fe79c969
authored
Feb 11, 2020
by
rusty1s
Browse files
typo
parent
b4a9e5d5
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
23 additions
and
77 deletions
+23
-77
setup.py
setup.py
+9
-3
torch_scatter/__init__.py
torch_scatter/__init__.py
+8
-0
torch_scatter/scatter.py
torch_scatter/scatter.py
+2
-16
torch_scatter/segment_coo.py
torch_scatter/segment_coo.py
+2
-30
torch_scatter/segment_csr.py
torch_scatter/segment_csr.py
+2
-28
No files found.
setup.py
View file @
fe79c969
...
...
@@ -35,9 +35,15 @@ def get_extensions():
for
main
in
main_files
:
name
=
main
.
split
(
os
.
sep
)[
-
1
][:
-
4
]
sources
=
[
main
,
osp
.
join
(
extensions_dir
,
'cpu'
,
name
+
'_cpu.cpp'
)]
if
WITH_CUDA
:
sources
+=
[
osp
.
join
(
extensions_dir
,
'cuda'
,
name
+
'_cuda.cu'
)]
sources
=
[
main
]
path
=
osp
.
join
(
extensions_dir
,
'cpu'
,
name
+
'_cpu.cpp'
)
if
osp
.
exists
(
path
):
sources
+=
[
path
]
path
=
osp
.
join
(
extensions_dir
,
'cuda'
,
name
+
'_cuda.cpp'
)
if
WITH_CUDA
and
osp
.
exists
(
path
):
sources
+=
[
path
]
extension
=
Extension
(
'torch_scatter._'
+
name
,
...
...
torch_scatter/__init__.py
View file @
fe79c969
import
os.path
as
osp
import
torch
from
.scatter
import
(
scatter_sum
,
scatter_add
,
scatter_mean
,
scatter_min
,
scatter_max
,
scatter
)
from
.segment_csr
import
(
segment_sum_csr
,
segment_add_csr
,
segment_mean_csr
,
...
...
@@ -9,6 +13,10 @@ from .segment_coo import (segment_sum_coo, segment_add_coo, segment_mean_coo,
from
.composite
import
(
scatter_std
,
scatter_logsumexp
,
scatter_softmax
,
scatter_log_softmax
)
torch
.
ops
.
load_library
(
osp
.
join
(
osp
.
dirname
(
osp
.
abspath
(
__file__
)),
'_version.so'
))
_version
=
torch
.
ops
.
torch_scatter
.
cuda_version
()
__version__
=
'2.0.3'
__all__
=
[
...
...
torch_scatter/scatter.py
View file @
fe79c969
import
warnings
import
os.path
as
osp
from
typing
import
Optional
,
Tuple
...
...
@@ -6,21 +5,8 @@ import torch
from
.utils
import
broadcast
try
:
torch
.
ops
.
load_library
(
torch
.
ops
.
load_library
(
osp
.
join
(
osp
.
dirname
(
osp
.
abspath
(
__file__
)),
'_scatter.so'
))
except
OSError
:
warnings
.
warn
(
'Failed to load `scatter` binaries.'
)
def
scatter_with_arg_placeholder
(
src
:
torch
.
Tensor
,
index
:
torch
.
Tensor
,
dim
:
int
,
out
:
Optional
[
torch
.
Tensor
],
dim_size
:
Optional
[
int
]
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
raise
ImportError
return
src
,
index
torch
.
ops
.
torch_scatter
.
scatter_min
=
scatter_with_arg_placeholder
torch
.
ops
.
torch_scatter
.
scatter_max
=
scatter_with_arg_placeholder
@
torch
.
jit
.
script
...
...
torch_scatter/segment_coo.py
View file @
fe79c969
import
warnings
import
os.path
as
osp
from
typing
import
Optional
,
Tuple
import
torch
try
:
torch
.
ops
.
load_library
(
torch
.
ops
.
load_library
(
osp
.
join
(
osp
.
dirname
(
osp
.
abspath
(
__file__
)),
'_segment_coo.so'
))
except
OSError
:
warnings
.
warn
(
'Failed to load `segment_coo` binaries.'
)
def
segment_coo_placeholder
(
src
:
torch
.
Tensor
,
index
:
torch
.
Tensor
,
out
:
Optional
[
torch
.
Tensor
],
dim_size
:
Optional
[
int
])
->
torch
.
Tensor
:
raise
ImportError
return
src
def
segment_coo_with_arg_placeholder
(
src
:
torch
.
Tensor
,
index
:
torch
.
Tensor
,
out
:
Optional
[
torch
.
Tensor
],
dim_size
:
Optional
[
int
])
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
raise
ImportError
return
src
,
index
def
gather_coo_placeholder
(
src
:
torch
.
Tensor
,
index
:
torch
.
Tensor
,
out
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
raise
ImportError
return
src
torch
.
ops
.
torch_scatter
.
segment_sum_coo
=
segment_coo_placeholder
torch
.
ops
.
torch_scatter
.
segment_mean_coo
=
segment_coo_placeholder
torch
.
ops
.
torch_scatter
.
segment_min_coo
=
segment_coo_with_arg_placeholder
torch
.
ops
.
torch_scatter
.
segment_max_coo
=
segment_coo_with_arg_placeholder
torch
.
ops
.
torch_scatter
.
gather_coo
=
gather_coo_placeholder
@
torch
.
jit
.
script
...
...
torch_scatter/segment_csr.py
View file @
fe79c969
import
warnings
import
os.path
as
osp
from
typing
import
Optional
,
Tuple
import
torch
try
:
torch
.
ops
.
load_library
(
torch
.
ops
.
load_library
(
osp
.
join
(
osp
.
dirname
(
osp
.
abspath
(
__file__
)),
'_segment_csr.so'
))
except
OSError
:
warnings
.
warn
(
'Failed to load `segment_csr` binaries.'
)
def
segment_csr_placeholder
(
src
:
torch
.
Tensor
,
indptr
:
torch
.
Tensor
,
out
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
raise
ImportError
return
src
def
segment_csr_with_arg_placeholder
(
src
:
torch
.
Tensor
,
indptr
:
torch
.
Tensor
,
out
:
Optional
[
torch
.
Tensor
])
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
raise
ImportError
return
src
,
indptr
def
gather_csr_placeholder
(
src
:
torch
.
Tensor
,
indptr
:
torch
.
Tensor
,
out
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
raise
ImportError
return
src
torch
.
ops
.
torch_scatter
.
segment_sum_csr
=
segment_csr_placeholder
torch
.
ops
.
torch_scatter
.
segment_mean_csr
=
segment_csr_placeholder
torch
.
ops
.
torch_scatter
.
segment_min_csr
=
segment_csr_with_arg_placeholder
torch
.
ops
.
torch_scatter
.
segment_max_csr
=
segment_csr_with_arg_placeholder
torch
.
ops
.
torch_scatter
.
gather_csr
=
gather_csr_placeholder
@
torch
.
jit
.
script
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment