OpenDAS / torch-sparse

Commit c5f5be51, authored Jan 31, 2020 by rusty1s (parent 3c6dbfa1)

    implementing convert

Showing 8 changed files with 99 additions and 94 deletions.
Changed files:

  csrc/cuda/spspmm_kernel.cu    +0   -0
  csrc/cuda/unique.cpp          +0   -0
  csrc/cuda/unique_kernel.cu    +0   -0
  csrc/cuda/utils.cuh           +7   -0
  cuda/convert.cpp              +0  -21
  setup.py                     +70  -51
  torch_sparse/reduce.py        +0   -4
  torch_sparse/storage.py      +22  -18
cuda/spspmm_kernel.cu → csrc/cuda/spspmm_kernel.cu (file moved)
cuda/unique.cpp → csrc/cuda/unique.cpp (file moved)
cuda/unique_kernel.cu → csrc/cuda/unique_kernel.cu (file moved)
csrc/cuda/utils.cuh (new file, mode 100644)

#pragma once

#include <torch/extension.h>

#define CHECK_CUDA(x) \
  AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
cuda/convert.cpp (deleted, mode 100644 → 0)

#include <torch/script.h>

#define CHECK_CUDA(x) \
  AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")

torch::Tensor ind2ptr_cuda(torch::Tensor ind, int64_t M);
torch::Tensor ptr2ind_cuda(torch::Tensor ptr, int64_t E);

torch::Tensor ind2ptr(torch::Tensor ind, int64_t M) {
  CHECK_CUDA(ind);
  return ind2ptr_cuda(ind, M);
}

torch::Tensor ptr2ind(torch::Tensor ptr, int64_t E) {
  CHECK_CUDA(ptr);
  return ptr2ind_cuda(ptr, E);
}

static auto registry =
    torch::RegisterOperators("torch_sparse_cuda::ind2ptr", &ind2ptr)
        .op("torch_sparse_cuda::ptr2ind", &ptr2ind);
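For orientation: ind2ptr turns a sorted COO row-index vector of length E into a CSR row pointer of length M + 1, and ptr2ind expands such a pointer back into E row indices; those are the two kernels this file forwarded to. A minimal pure-PyTorch sketch of the same conversions (the *_reference helpers are hypothetical names used only for illustration, not the CUDA kernels being moved here):

import torch


def ind2ptr_reference(ind: torch.Tensor, M: int) -> torch.Tensor:
    # Count entries per row, then prefix-sum into a pointer of length M + 1
    # (ptr[0] == 0, ptr[-1] == E).
    ptr = torch.zeros(M + 1, dtype=torch.long)
    ptr[1:] = torch.bincount(ind, minlength=M).cumsum(0)
    return ptr


def ptr2ind_reference(ptr: torch.Tensor, E: int) -> torch.Tensor:
    # The pointer's last entry equals the number of nonzeros.
    assert int(ptr[-1]) == E
    # Repeat each row index by its row length to recover E COO row indices.
    counts = ptr[1:] - ptr[:-1]
    return torch.repeat_interleave(torch.arange(ptr.numel() - 1), counts)


row = torch.tensor([0, 0, 1, 3])      # sorted COO row indices, E = 4
rowptr = ind2ptr_reference(row, M=4)  # tensor([0, 2, 3, 3, 4])
assert torch.equal(ptr2ind_reference(rowptr, E=4), row)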
setup.py

-import platform
+import os
 import os.path as osp
-from glob import glob
+import sys
+import glob
 from setuptools import setup, find_packages
-from sys import argv

 import torch
 from torch.utils.cpp_extension import BuildExtension
 from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME

-cxx_extra_compile_args = []
-nvcc_extra_compile_args = ['-arch=sm_35', '--expt-relaxed-constexpr']
+WITH_CUDA = torch.cuda.is_available() and CUDA_HOME is not None

-TORCH_MAJOR = int(torch.__version__.split('.')[0])
-TORCH_MINOR = int(torch.__version__.split('.')[1])
-extra_compile_args = []
+if os.getenv('FORCE_CUDA', '0') == '1':
+    WITH_CUDA = True
+if os.getenv('FORCE_NON_CUDA', '0') == '1':
+    WITH_CUDA = False

-if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
-    cxx_extra_compile_args += ['-DVERSION_GE_1_3']
-    nvcc_extra_compile_args += ['-DVERSION_GE_1_3']
+BUILD_DOCS = os.getenv('BUILD_DOCS', '0') == '1'

-cmdclass = {
-    'build_ext': BuildExtension.with_options(no_python_abi_suffix=True)
-}
-
-ext_modules = []
-
-exts = [e.split(osp.sep)[-1][:-4] for e in glob(osp.join('cpu', '*.cpp'))]
-ext_modules += [
-    CppExtension(f'torch_sparse.{ext}_cpu', [f'cpu/{ext}.cpp'],
-                 extra_compile_args=cxx_extra_compile_args)
-    for ext in exts
-]
+def get_extensions():
+    Extension = CppExtension
+    define_macros = []
+    extra_compile_args = {'cxx': [], 'nvcc': []}
+    extra_link_args = []
+
+    # Windows users: Edit both of these to contain your VS include path, i.e.:
+    # extra_compile_args['cxx'] += ['-I{VISUAL_STUDIO_DIR}\\include']
+    # extra_compile_args['nvcc'] += ['-I{VISUAL_STUDIO_DIR}\\include']

-if CUDA_HOME is not None and '--cpu' not in argv:
-    if platform.system() == 'Windows':
-        extra_link_args = ['cusparse.lib']
-    else:
-        extra_link_args = ['-lcusparse', '-l', 'cusparse']
-
-    exts = [e.split(osp.sep)[-1][:-4] for e in glob(osp.join('cuda', '*.cpp'))]
-    ext_modules += [
-        CUDAExtension(
-            f'torch_sparse.{ext}_cuda',
-            [f'cuda/{ext}.cpp', f'cuda/{ext}_kernel.cu'],
-            extra_compile_args={
-                'cxx': cxx_extra_compile_args,
-                'nvcc': nvcc_extra_compile_args,
-            },
-            extra_link_args=extra_link_args,
-        ) for ext in exts
-    ]
+    if WITH_CUDA:
+        Extension = CUDAExtension
+        define_macros += [('WITH_CUDA', None)]
+        nvcc_flags = os.getenv('NVCC_FLAGS', '')
+        nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ')
+        nvcc_flags += ['-arch=sm_35', '--expt-relaxed-constexpr']
+        extra_compile_args['cxx'] += ['-O0']
+        extra_compile_args['nvcc'] += nvcc_flags
+        if sys.platform == 'win32':
+            extra_link_args = ['cusparse.lib']
+        else:
+            extra_link_args = ['-lcusparse', '-l', 'cusparse']

-if '--cpu' in argv:
-    argv.remove('--cpu')
+    if sys.platform == 'win32':
+        extra_compile_args['cxx'] += ['/MP']
+
+    extensions_dir = osp.join(osp.dirname(osp.abspath(__file__)), 'csrc')
+    main_files = glob.glob(osp.join(extensions_dir, '*.cpp'))
+
+    extensions = []
+    for main in main_files:
+        name = main.split(os.sep)[-1][:-4]
+
+        sources = [main, osp.join(extensions_dir, 'cpu', f'{name}_cpu.cpp')]
+        if WITH_CUDA:
+            sources += [osp.join(extensions_dir, 'cuda', f'{name}_cuda.cu')]
+
+        extension = Extension(
+            f'torch_sparse._{name}',
+            sources,
+            include_dirs=[extensions_dir],
+            define_macros=define_macros,
+            extra_compile_args=extra_compile_args,
+            extra_link_args=extra_link_args,
+        )
+        extensions += [extension]
+
+    return extensions

-__version__ = '0.4.3'
-url = 'https://github.com/rusty1s/pytorch_sparse'
+__version__ = '1.0.0'

 install_requires = ['scipy']
 setup_requires = ['pytest-runner']
...
@@ -58,18 +75,20 @@ tests_require = ['pytest', 'pytest-cov']
 setup(
     name='torch_sparse',
-    version=__version__,
-    description=('PyTorch Extension Library of Optimized Autograd Sparse '
-                 'Matrix Operations'),
+    version='1.0.0',
     author='Matthias Fey',
     author_email='matthias.fey@tu-dortmund.de',
-    url=url,
-    download_url='{}/archive/{}.tar.gz'.format(url, __version__),
+    url='https://github.com/rusty1s/pytorch_sparse',
+    description=('PyTorch Extension Library of Optimized Autograd Sparse '
+                 'Matrix Operations'),
     keywords=['pytorch', 'sparse', 'sparse-matrices', 'autograd'],
+    license='MIT',
     install_requires=install_requires,
     setup_requires=setup_requires,
     tests_require=tests_require,
-    ext_modules=ext_modules,
-    cmdclass=cmdclass,
+    ext_modules=get_extensions() if not BUILD_DOCS else [],
+    cmdclass={
+        'build_ext':
+        BuildExtension.with_options(no_python_abi_suffix=True)
+    },
     packages=find_packages(),
 )
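The practical effect of get_extensions() is that every top-level csrc/*.cpp file becomes one torch_sparse._<name> extension whose sources combine that file with a csrc/cpu/<name>_cpu.cpp counterpart and, when building with CUDA, a csrc/cuda/<name>_cuda.cu kernel. A small standalone sketch of the source collection for the convert extension; the csrc/convert.cpp and csrc/cpu/convert_cpu.cpp paths are assumptions here, since they are not part of this commit's diff:

import os.path as osp

WITH_CUDA = True  # assumption: CUDA build (CUDA_HOME found or FORCE_CUDA=1)
main = osp.join('csrc', 'convert.cpp')  # hypothetical dispatcher file
name = main.split(osp.sep)[-1][:-4]     # -> 'convert'

sources = [main, osp.join('csrc', 'cpu', f'{name}_cpu.cpp')]
if WITH_CUDA:
    sources += [osp.join('csrc', 'cuda', f'{name}_cuda.cu')]

# The resulting extension would be named 'torch_sparse._convert' and built
# from csrc/convert.cpp, csrc/cpu/convert_cpu.cpp and csrc/cuda/convert_cuda.cu.
print(f'torch_sparse._{name}', sources)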
torch_sparse/reduce.py

 from typing import Optional

 import torch
-import torch_scatter
-
 from torch_scatter import scatter, segment_csr

 from torch_sparse.tensor import SparseTensor
...
@@ -32,7 +30,6 @@ def reduction(src: SparseTensor, dim: Optional[int] = None,
             return torch.tensor(1, dtype=src.dtype(), device=src.device())
         else:
             raise ValueError
     else:
         if dim < 0:
             dim = src.dim() + dim
...
@@ -67,7 +64,6 @@ def reduction(src: SparseTensor, dim: Optional[int] = None,
             return value.max(dim=dim - 1)[0]
         else:
             raise ValueError
     else:
         raise ValueError
...
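The reduction logic itself still delegates to torch_scatter (scatter for COO-style reductions, segment_csr for CSR-style ones); the change above only drops a now-unused top-level import. As a reminder of what segment_csr computes over a row pointer, a small sketch independent of this diff:

import torch
from torch_scatter import segment_csr

# Three rows encoded by a CSR row pointer; `value` holds one entry per nonzero.
rowptr = torch.tensor([0, 2, 3, 5])
value = torch.tensor([1., 2., 3., 4., 5.])

out = segment_csr(value, rowptr, reduce='sum')  # per-row sums
print(out)  # tensor([3., 3., 9.])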
torch_sparse/storage.py

 import warnings
+import os.path as osp
 from typing import Optional, List

 import torch
 from torch_scatter import segment_csr, scatter_add

 from torch_sparse.utils import Final

+try:
+    torch.ops.load_library(
+        osp.join(osp.dirname(osp.abspath(__file__)), '_convert.so'))
+except OSError:
+    warnings.warn('Failed to load `convert` binaries.')
+
+    def ind2ptr_placeholder(ind: torch.Tensor, M: int) -> torch.Tensor:
+        raise ImportError
+        return ind
+
+    def ptr2ind_placeholder(ptr: torch.Tensor, E: int) -> torch.Tensor:
+        raise ImportError
+        return ptr
+
+    torch.ops.torch_sparse.ind2ptr = ind2ptr_placeholder
+    torch.ops.torch_sparse.ptr2ind = ptr2ind_placeholder
+
 layouts: Final[List[str]] = ['coo', 'csr', 'csc']
...
@@ -147,16 +165,7 @@ class SparseStorage(object):
         rowptr = self._rowptr
         if rowptr is not None:
-            if rowptr.is_cuda:
-                row = torch.ops.torch_sparse_cuda.ptr2ind(
-                    rowptr, self._col.numel())
-            else:
-                if rowptr.is_cuda:
-                    row = torch.ops.torch_sparse_cuda.ptr2ind(
-                        rowptr, self._col.numel())
-                else:
-                    row = torch.ops.torch_sparse_cpu.ptr2ind(
-                        rowptr, self._col.numel())
+            row = torch.ops.torch_sparse.ptr2ind(rowptr, self._col.numel())
             self._row = row
             return row
...
@@ -172,12 +181,7 @@ class SparseStorage(object):
         row = self._row
         if row is not None:
-            if row.is_cuda:
-                rowptr = torch.ops.torch_sparse_cuda.ind2ptr(
-                    row, self._sparse_sizes[0])
-            else:
-                rowptr = torch.ops.torch_sparse_cpu.ind2ptr(
-                    row, self._sparse_sizes[0])
+            rowptr = torch.ops.torch_sparse.ind2ptr(row, self._sparse_sizes[0])
             self._rowptr = rowptr
             return rowptr
...
@@ -284,8 +288,8 @@ class SparseStorage(object):
         csr2csc = self._csr2csc
         if csr2csc is not None:
-            colptr = torch.ops.torch_sparse_cpu.ind2ptr(
+            colptr = torch.ops.torch_sparse.ind2ptr(
                 self._col[csr2csc],
                 self._sparse_sizes[1])
         else:
             colptr = self._col.new_zeros(self._sparse_sizes[1] + 1)
             torch.cumsum(self.colcount(), dim=0, out=colptr[1:])
...
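The net effect in storage.py is that conversions no longer branch on the tensor's device in Python: the same torch.ops.torch_sparse.ind2ptr / ptr2ind ops are called for CPU and CUDA inputs, and the placeholders merely keep those attributes resolvable when the compiled binaries are missing. A hedged usage sketch, assuming torch_sparse has been built and installed with the new csrc layout so the ops are registered:

import torch
import torch_sparse  # noqa: F401 -- importing the package loads the compiled ops

row = torch.tensor([0, 0, 1, 3])                 # sorted COO row indices (E = 4)
rowptr = torch.ops.torch_sparse.ind2ptr(row, 4)  # expected: tensor([0, 2, 3, 3, 4])
back = torch.ops.torch_sparse.ptr2ind(rowptr, row.numel())
assert torch.equal(back, row)

# The same call sites serve CUDA tensors; no torch_sparse_cuda/_cpu split needed.
if torch.cuda.is_available():
    rowptr_cuda = torch.ops.torch_sparse.ind2ptr(row.cuda(), 4)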