Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
85940068
Commit
85940068
authored
Jan 30, 2020
by
rusty1s
Browse files
build docs
parent
920136ee
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
76 additions
and
7 deletions
+76
-7
setup.py
setup.py
+3
-1
torch_scatter/scatter.py
torch_scatter/scatter.py
+21
-2
torch_scatter/segment_coo.py
torch_scatter/segment_coo.py
+27
-2
torch_scatter/segment_csr.py
torch_scatter/segment_csr.py
+25
-2
No files found.
setup.py
View file @
85940068
...
@@ -14,6 +14,8 @@ if os.getenv('FORCE_CUDA', '0') == '1':
...
@@ -14,6 +14,8 @@ if os.getenv('FORCE_CUDA', '0') == '1':
if
os
.
getenv
(
'FORCE_NON_CUDA'
,
'0'
)
==
'1'
:
if
os
.
getenv
(
'FORCE_NON_CUDA'
,
'0'
)
==
'1'
:
WITH_CUDA
=
False
WITH_CUDA
=
False
BUILD_DOCS
=
os
.
getenv
(
'BUILD_DOCS'
,
'0'
)
==
'1'
def
get_extensions
():
def
get_extensions
():
Extension
=
CppExtension
Extension
=
CppExtension
...
@@ -74,7 +76,7 @@ setup(
...
@@ -74,7 +76,7 @@ setup(
install_requires
=
install_requires
,
install_requires
=
install_requires
,
setup_requires
=
setup_requires
,
setup_requires
=
setup_requires
,
tests_require
=
tests_require
,
tests_require
=
tests_require
,
ext_modules
=
get_extensions
(),
ext_modules
=
get_extensions
()
if
not
BUILD_DOCS
else
[]
,
cmdclass
=
{
cmdclass
=
{
'build_ext'
:
BuildExtension
.
with_options
(
no_python_abi_suffix
=
True
)
'build_ext'
:
BuildExtension
.
with_options
(
no_python_abi_suffix
=
True
)
},
},
...
...
torch_scatter/scatter.py
View file @
85940068
import
warnings
import
os.path
as
osp
import
os.path
as
osp
from
typing
import
Optional
,
Tuple
from
typing
import
Optional
,
Tuple
import
torch
import
torch
torch
.
ops
.
load_library
(
try
:
osp
.
join
(
osp
.
dirname
(
osp
.
abspath
(
__file__
)),
'_scatter.so'
))
torch
.
ops
.
load_library
(
osp
.
join
(
osp
.
dirname
(
osp
.
abspath
(
__file__
)),
'_scatter.so'
))
except
OSError
:
warnings
.
warn
(
'Failed to load `scatter` binaries.'
)
def
placeholder
(
src
:
torch
.
Tensor
,
index
:
torch
.
Tensor
,
dim
:
int
,
out
:
Optional
[
torch
.
Tensor
],
dim_size
:
Optional
[
int
])
->
torch
.
Tensor
:
raise
ImportError
def
arg_placeholder
(
src
:
torch
.
Tensor
,
index
:
torch
.
Tensor
,
dim
:
int
,
out
:
Optional
[
torch
.
Tensor
],
dim_size
:
Optional
[
int
]
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
raise
ImportError
torch
.
ops
.
torch_scatter
.
scatter_sum
=
placeholder
torch
.
ops
.
torch_scatter
.
scatter_mean
=
placeholder
torch
.
ops
.
torch_scatter
.
scatter_min
=
arg_placeholder
torch
.
ops
.
torch_scatter
.
scatter_max
=
arg_placeholder
@
torch
.
jit
.
script
@
torch
.
jit
.
script
...
...
torch_scatter/segment_coo.py
View file @
85940068
import
warnings
import
os.path
as
osp
import
os.path
as
osp
from
typing
import
Optional
,
Tuple
from
typing
import
Optional
,
Tuple
import
torch
import
torch
torch
.
ops
.
load_library
(
try
:
osp
.
join
(
osp
.
dirname
(
osp
.
abspath
(
__file__
)),
'_segment_coo.so'
))
torch
.
ops
.
load_library
(
osp
.
join
(
osp
.
dirname
(
osp
.
abspath
(
__file__
)),
'_segment_coo.so'
))
except
OSError
:
warnings
.
warn
(
'Failed to load `segment_coo` binaries.'
)
def
segment_coo_placeholder
(
src
:
torch
.
Tensor
,
index
:
torch
.
Tensor
,
out
:
Optional
[
torch
.
Tensor
],
dim_size
:
Optional
[
int
])
->
torch
.
Tensor
:
raise
ImportError
def
segment_coo_with_arg_placeholder
(
src
:
torch
.
Tensor
,
index
:
torch
.
Tensor
,
out
:
Optional
[
torch
.
Tensor
],
dim_size
:
Optional
[
int
])
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
raise
ImportError
def
gather_coo_placeholder
(
src
:
torch
.
Tensor
,
index
:
torch
.
Tensor
,
out
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
raise
ImportError
torch
.
ops
.
torch_scatter
.
segment_sum_coo
=
segment_coo_placeholder
torch
.
ops
.
torch_scatter
.
segment_mean_coo
=
segment_coo_placeholder
torch
.
ops
.
torch_scatter
.
segment_min_coo
=
segment_coo_with_arg_placeholder
torch
.
ops
.
torch_scatter
.
segment_max_coo
=
segment_coo_with_arg_placeholder
torch
.
ops
.
torch_scatter
.
gather_coo
=
gather_coo_placeholder
@
torch
.
jit
.
script
@
torch
.
jit
.
script
...
...
torch_scatter/segment_csr.py
View file @
85940068
import
warnings
import
os.path
as
osp
import
os.path
as
osp
from
typing
import
Optional
,
Tuple
from
typing
import
Optional
,
Tuple
import
torch
import
torch
torch
.
ops
.
load_library
(
try
:
osp
.
join
(
osp
.
dirname
(
osp
.
abspath
(
__file__
)),
'_segment_csr.so'
))
torch
.
ops
.
load_library
(
osp
.
join
(
osp
.
dirname
(
osp
.
abspath
(
__file__
)),
'_segment_csr.so'
))
except
OSError
:
warnings
.
warn
(
'Failed to load `segment_csr` binaries.'
)
def
segment_csr_placeholder
(
src
:
torch
.
Tensor
,
indptr
:
torch
.
Tensor
,
out
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
raise
ImportError
def
segment_csr_with_arg_placeholder
(
src
:
torch
.
Tensor
,
indptr
:
torch
.
Tensor
,
out
:
Optional
[
torch
.
Tensor
])
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
raise
ImportError
def
gather_csr_placeholder
(
src
:
torch
.
Tensor
,
indptr
:
torch
.
Tensor
,
out
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
raise
ImportError
torch
.
ops
.
torch_scatter
.
segment_sum_csr
=
segment_csr_placeholder
torch
.
ops
.
torch_scatter
.
segment_mean_csr
=
segment_csr_placeholder
torch
.
ops
.
torch_scatter
.
segment_min_csr
=
segment_csr_with_arg_placeholder
torch
.
ops
.
torch_scatter
.
segment_max_csr
=
segment_csr_with_arg_placeholder
torch
.
ops
.
torch_scatter
.
gather_csr
=
gather_csr_placeholder
@
torch
.
jit
.
script
@
torch
.
jit
.
script
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment