Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
1c4fdfe2
Commit
1c4fdfe2
authored
Oct 14, 2019
by
rusty1s
Browse files
pytorch 1.3 support
parent
573ad113
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
65 additions
and
37 deletions
+65
-37
.travis.yml
.travis.yml
+1
-1
cpu/compat.h
cpu/compat.h
+5
-0
cpu/spspmm.cpp
cpu/spspmm.cpp
+10
-8
cuda/compat.cuh
cuda/compat.cuh
+5
-0
cuda/spspmm_kernel.cu
cuda/spspmm_kernel.cu
+22
-19
cuda/unique_kernel.cu
cuda/unique_kernel.cu
+3
-1
setup.py
setup.py
+18
-7
torch_sparse/__init__.py
torch_sparse/__init__.py
+1
-1
No files found.
.travis.yml
View file @
1c4fdfe2
...
...
@@ -17,7 +17,7 @@ before_install:
-
export CXX="g++-4.9"
install
:
-
pip install numpy
-
pip install -
q torch
-f https://download.pytorch.org/whl/nightly/cpu/torch.html
-
pip install -
-pre torch torchvision
-f https://download.pytorch.org/whl/nightly/cpu/torch
_nightly
.html
-
pip install pycodestyle
-
pip install flake8
-
pip install codecov
...
...
cpu/compat.h
0 → 100644
View file @
1c4fdfe2
// Compatibility shim that picks the name of the typed raw-pointer accessor on
// at::Tensor. Call sites write `tensor.DATA_PTR<T>()`, which expands to
// `tensor.data_ptr<T>()` or `tensor.data<T>()` depending on the build flag.
// setup.py adds -DVERSION_GE_1_3 when the installed torch version is newer
// than 1.2.
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
// Flag not set (torch 1.2 or older): fall back to the data<T>() accessor.
#define DATA_PTR data
#endif
cpu/spspmm.cpp
View file @
1c4fdfe2
#include <torch/extension.h>
#include "compat.h"
at
::
Tensor
degree
(
at
::
Tensor
row
,
int64_t
num_nodes
)
{
auto
zero
=
at
::
zeros
(
num_nodes
,
row
.
options
());
auto
one
=
at
::
ones
(
row
.
size
(
0
),
row
.
options
());
...
...
@@ -18,23 +20,23 @@ at::Tensor spspmm_bw(at::Tensor index, at::Tensor indexA, at::Tensor valueA,
at
::
Tensor
indexB
,
at
::
Tensor
valueB
,
size_t
rowA_max
,
size_t
rowB_max
)
{
int64_t
*
index_data
=
index
.
data
<
int64_t
>
();
int64_t
*
index_data
=
index
.
DATA_PTR
<
int64_t
>
();
auto
value
=
at
::
zeros
(
index
.
size
(
1
),
valueA
.
options
());
at
::
Tensor
rowA
,
colA
;
std
::
tie
(
rowA
,
colA
)
=
to_csr
(
indexA
[
0
],
indexA
[
1
],
rowA_max
);
int64_t
*
rowA_data
=
rowA
.
data
<
int64_t
>
();
int64_t
*
colA_data
=
colA
.
data
<
int64_t
>
();
int64_t
*
rowA_data
=
rowA
.
DATA_PTR
<
int64_t
>
();
int64_t
*
colA_data
=
colA
.
DATA_PTR
<
int64_t
>
();
at
::
Tensor
rowB
,
colB
;
std
::
tie
(
rowB
,
colB
)
=
to_csr
(
indexB
[
0
],
indexB
[
1
],
rowB_max
);
int64_t
*
rowB_data
=
rowB
.
data
<
int64_t
>
();
int64_t
*
colB_data
=
colB
.
data
<
int64_t
>
();
int64_t
*
rowB_data
=
rowB
.
DATA_PTR
<
int64_t
>
();
int64_t
*
colB_data
=
colB
.
DATA_PTR
<
int64_t
>
();
AT_DISPATCH_FLOATING_TYPES
(
valueA
.
scalar_type
(),
"spspmm_bw"
,
[
&
]
{
scalar_t
*
value_data
=
value
.
data
<
scalar_t
>
();
scalar_t
*
valueA_data
=
valueA
.
data
<
scalar_t
>
();
scalar_t
*
valueB_data
=
valueB
.
data
<
scalar_t
>
();
scalar_t
*
value_data
=
value
.
DATA_PTR
<
scalar_t
>
();
scalar_t
*
valueA_data
=
valueA
.
DATA_PTR
<
scalar_t
>
();
scalar_t
*
valueB_data
=
valueB
.
DATA_PTR
<
scalar_t
>
();
for
(
int64_t
e
=
0
;
e
<
value
.
size
(
0
);
e
++
)
{
int64_t
i
=
index_data
[
e
],
j
=
index_data
[
value
.
size
(
0
)
+
e
];
...
...
cuda/compat.cuh
0 → 100644
View file @
1c4fdfe2
// CUDA-side compatibility shim that picks the name of the typed raw-pointer
// accessor on at::Tensor. Kernel launch code writes `tensor.DATA_PTR<T>()`,
// which expands to `tensor.data_ptr<T>()` or `tensor.data<T>()` depending on
// the build flag. setup.py adds -DVERSION_GE_1_3 to the CUDA extensions'
// compile args when the installed torch version is newer than 1.2.
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
// Flag not set (torch 1.2 or older): fall back to the data<T>() accessor.
#define DATA_PTR data
#endif
cuda/spspmm_kernel.cu
View file @
1c4fdfe2
#include <ATen/ATen.h>
#include <cusparse.h>
#include "compat.cuh"
// Threads per block used for every kernel launch in this file.
#define THREADS 1024
// Number of blocks needed to cover N elements with THREADS threads each
// (ceiling division). Both the argument and the whole expansion are
// parenthesised so the macro stays correct inside larger expressions and with
// arguments containing low-precedence operators (e.g. `a ? b : c`).
#define BLOCKS(N) (((N) + THREADS - 1) / THREADS)
...
...
@@ -51,18 +52,18 @@ spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB,
// Convert A to CSR format.
auto
row_ptrA
=
at
::
empty
(
m
+
1
,
indexA
.
options
());
cusparseXcoo2csr
(
cusparse_handle
,
indexA
[
0
].
data
<
int
>
(),
nnzA
,
k
,
row_ptrA
.
data
<
int
>
(),
CUSPARSE_INDEX_BASE_ZERO
);
cusparseXcoo2csr
(
cusparse_handle
,
indexA
[
0
].
DATA_PTR
<
int
>
(),
nnzA
,
k
,
row_ptrA
.
DATA_PTR
<
int
>
(),
CUSPARSE_INDEX_BASE_ZERO
);
auto
colA
=
indexA
[
1
];
cudaMemcpy
(
row_ptrA
.
data
<
int
>
()
+
m
,
&
nnzA
,
sizeof
(
int
),
cudaMemcpy
(
row_ptrA
.
DATA_PTR
<
int
>
()
+
m
,
&
nnzA
,
sizeof
(
int
),
cudaMemcpyHostToDevice
);
// Convert B to CSR format.
auto
row_ptrB
=
at
::
empty
(
k
+
1
,
indexB
.
options
());
cusparseXcoo2csr
(
cusparse_handle
,
indexB
[
0
].
data
<
int
>
(),
nnzB
,
k
,
row_ptrB
.
data
<
int
>
(),
CUSPARSE_INDEX_BASE_ZERO
);
cusparseXcoo2csr
(
cusparse_handle
,
indexB
[
0
].
DATA_PTR
<
int
>
(),
nnzB
,
k
,
row_ptrB
.
DATA_PTR
<
int
>
(),
CUSPARSE_INDEX_BASE_ZERO
);
auto
colB
=
indexB
[
1
];
cudaMemcpy
(
row_ptrB
.
data
<
int
>
()
+
k
,
&
nnzB
,
sizeof
(
int
),
cudaMemcpy
(
row_ptrB
.
DATA_PTR
<
int
>
()
+
k
,
&
nnzB
,
sizeof
(
int
),
cudaMemcpyHostToDevice
);
cusparseMatDescr_t
descr
=
0
;
...
...
@@ -74,22 +75,23 @@ spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB,
auto
row_ptrC
=
at
::
empty
(
m
+
1
,
indexB
.
options
());
cusparseXcsrgemmNnz
(
cusparse_handle
,
CUSPARSE_OPERATION_NON_TRANSPOSE
,
CUSPARSE_OPERATION_NON_TRANSPOSE
,
m
,
n
,
k
,
descr
,
nnzA
,
row_ptrA
.
data
<
int
>
(),
colA
.
data
<
int
>
(),
descr
,
nnzB
,
row_ptrB
.
data
<
int
>
(),
colB
.
data
<
int
>
(),
descr
,
row_ptrC
.
data
<
int
>
(),
&
nnzC
);
row_ptrA
.
DATA_PTR
<
int
>
(),
colA
.
DATA_PTR
<
int
>
(),
descr
,
nnzB
,
row_ptrB
.
DATA_PTR
<
int
>
(),
colB
.
DATA_PTR
<
int
>
(),
descr
,
row_ptrC
.
DATA_PTR
<
int
>
(),
&
nnzC
);
auto
colC
=
at
::
empty
(
nnzC
,
indexA
.
options
());
auto
valueC
=
at
::
empty
(
nnzC
,
valueA
.
options
());
CSRGEMM
(
valueC
.
scalar_type
(),
cusparse_handle
,
CUSPARSE_OPERATION_NON_TRANSPOSE
,
CUSPARSE_OPERATION_NON_TRANSPOSE
,
m
,
n
,
k
,
descr
,
nnzA
,
valueA
.
data
<
scalar_t
>
(),
row_ptrA
.
data
<
int
>
(),
colA
.
data
<
int
>
(),
descr
,
nnzB
,
valueB
.
data
<
scalar_t
>
(),
row_ptrB
.
data
<
int
>
(),
colB
.
data
<
int
>
(),
descr
,
valueC
.
data
<
scalar_t
>
(),
row_ptrC
.
data
<
int
>
(),
colC
.
data
<
int
>
());
n
,
k
,
descr
,
nnzA
,
valueA
.
DATA_PTR
<
scalar_t
>
(),
row_ptrA
.
DATA_PTR
<
int
>
(),
colA
.
DATA_PTR
<
int
>
(),
descr
,
nnzB
,
valueB
.
DATA_PTR
<
scalar_t
>
(),
row_ptrB
.
DATA_PTR
<
int
>
(),
colB
.
DATA_PTR
<
int
>
(),
descr
,
valueC
.
DATA_PTR
<
scalar_t
>
(),
row_ptrC
.
DATA_PTR
<
int
>
(),
colC
.
DATA_PTR
<
int
>
());
auto
rowC
=
at
::
empty
(
nnzC
,
indexA
.
options
());
cusparseXcsr2coo
(
cusparse_handle
,
row_ptrC
.
data
<
int
>
(),
nnzC
,
m
,
rowC
.
data
<
int
>
(),
CUSPARSE_INDEX_BASE_ZERO
);
cusparseXcsr2coo
(
cusparse_handle
,
row_ptrC
.
DATA_PTR
<
int
>
(),
nnzC
,
m
,
rowC
.
DATA_PTR
<
int
>
(),
CUSPARSE_INDEX_BASE_ZERO
);
auto
indexC
=
at
::
stack
({
rowC
,
colC
},
0
).
toType
(
at
::
kLong
);
...
...
@@ -154,9 +156,10 @@ at::Tensor spspmm_bw_cuda(at::Tensor index, at::Tensor indexA,
AT_DISPATCH_FLOATING_TYPES
(
valueA
.
scalar_type
(),
"spspmm_bw"
,
[
&
]
{
spspmm_bw_kernel
<
scalar_t
><<<
BLOCKS
(
value
.
numel
()),
THREADS
>>>
(
index
.
data
<
int64_t
>
(),
value
.
data
<
scalar_t
>
(),
rowA
.
data
<
int64_t
>
(),
colA
.
data
<
int64_t
>
(),
valueA
.
data
<
scalar_t
>
(),
rowB
.
data
<
int64_t
>
(),
colB
.
data
<
int64_t
>
(),
valueB
.
data
<
scalar_t
>
(),
value
.
numel
());
index
.
DATA_PTR
<
int64_t
>
(),
value
.
DATA_PTR
<
scalar_t
>
(),
rowA
.
DATA_PTR
<
int64_t
>
(),
colA
.
DATA_PTR
<
int64_t
>
(),
valueA
.
DATA_PTR
<
scalar_t
>
(),
rowB
.
DATA_PTR
<
int64_t
>
(),
colB
.
DATA_PTR
<
int64_t
>
(),
valueB
.
DATA_PTR
<
scalar_t
>
(),
value
.
numel
());
});
return
value
;
...
...
cuda/unique_kernel.cu
View file @
1c4fdfe2
#include <ATen/ATen.h>
#include "compat.cuh"
// Threads per block used for every kernel launch in this file.
#define THREADS 1024
// Number of blocks needed to cover N elements with THREADS threads each
// (ceiling division). Both the argument and the whole expansion are
// parenthesised so the macro stays correct inside larger expressions and with
// arguments containing low-precedence operators (e.g. `a ? b : c`).
#define BLOCKS(N) (((N) + THREADS - 1) / THREADS)
...
...
@@ -23,7 +25,7 @@ std::tuple<at::Tensor, at::Tensor> unique_cuda(at::Tensor src) {
auto
mask
=
at
::
zeros
(
src
.
numel
(),
src
.
options
().
dtype
(
at
::
kByte
));
AT_DISPATCH_ALL_TYPES
(
src
.
scalar_type
(),
"grid_cuda_kernel"
,
[
&
]
{
unique_cuda_kernel
<
scalar_t
><<<
BLOCKS
(
src
.
numel
()),
THREADS
>>>
(
src
.
data
<
scalar_t
>
(),
mask
.
data
<
uint8_t
>
(),
src
.
numel
());
src
.
DATA_PTR
<
scalar_t
>
(),
mask
.
DATA_PTR
<
uint8_t
>
(),
src
.
numel
());
});
src
=
src
.
masked_select
(
mask
);
...
...
setup.py
View file @
1c4fdfe2
...
...
@@ -3,7 +3,17 @@ from setuptools import setup, find_packages
import
torch
from
torch.utils.cpp_extension
import
CppExtension
,
CUDAExtension
,
CUDA_HOME
ext_modules
=
[
CppExtension
(
'torch_sparse.spspmm_cpu'
,
[
'cpu/spspmm.cpp'
])]
TORCH_MAJOR
=
int
(
torch
.
__version__
.
split
(
'.'
)[
0
])
TORCH_MINOR
=
int
(
torch
.
__version__
.
split
(
'.'
)[
1
])
extra_compile_args
=
[]
if
(
TORCH_MAJOR
>
1
)
or
(
TORCH_MAJOR
==
1
and
TORCH_MINOR
>
2
):
extra_compile_args
+=
[
'-DVERSION_GE_1_3'
]
ext_modules
=
[
CppExtension
(
'torch_sparse.spspmm_cpu'
,
[
'cpu/spspmm.cpp'
],
extra_compile_args
=
extra_compile_args
)
]
cmdclass
=
{
'build_ext'
:
torch
.
utils
.
cpp_extension
.
BuildExtension
}
if
CUDA_HOME
is
not
None
:
...
...
@@ -13,15 +23,16 @@ if CUDA_HOME is not None:
extra_link_args
=
[
'-lcusparse'
,
'-l'
,
'cusparse'
]
ext_modules
+=
[
CUDAExtension
(
'torch_sparse.spspmm_cuda'
,
CUDAExtension
(
'torch_sparse.spspmm_cuda'
,
[
'cuda/spspmm.cpp'
,
'cuda/spspmm_kernel.cu'
],
extra_link_args
=
extra_link_args
),
extra_link_args
=
extra_link_args
,
extra_compile_args
=
extra_compile_args
),
CUDAExtension
(
'torch_sparse.unique_cuda'
,
[
'cuda/unique.cpp'
,
'cuda/unique_kernel.cu'
]),
[
'cuda/unique.cpp'
,
'cuda/unique_kernel.cu'
],
extra_compile_args
=
extra_compile_args
),
]
__version__
=
'0.4.
0
'
__version__
=
'0.4.
1
'
url
=
'https://github.com/rusty1s/pytorch_sparse'
install_requires
=
[
'scipy'
]
...
...
torch_sparse/__init__.py
View file @
1c4fdfe2
...
...
@@ -5,7 +5,7 @@ from .eye import eye
from
.spmm
import
spmm
from
.spspmm
import
spspmm
__version__
=
'0.4.
0
'
__version__
=
'0.4.
1
'
__all__
=
[
'__version__'
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment