OpenDAS / torch-sparse · Commit de528831

pytorch 1.1.0 update

Authored May 01, 2019 by rusty1s
Parent: 9732a518
Showing 7 changed files with 27 additions and 25 deletions.

.travis.yml               +3  -3
README.md                 +1  -1
cpu/spspmm.cpp            +1  -1
cuda/spspmm_kernel.cu     +18 -16
cuda/unique_kernel.cu     +2  -2
setup.py                  +1  -1
torch_sparse/__init__.py  +1  -1
.travis.yml

@@ -17,9 +17,9 @@ before_install:
   - export CC="gcc-4.9"
   - export CXX="g++-4.9"
 install:
-  - if [[ $TRAVIS_PYTHON_VERSION == 2.7 ]]; then pip install https://download.pytorch.org/whl/cpu/torch-1.0.0-cp27-cp27mu-linux_x86_64.whl; fi
-  - if [[ $TRAVIS_PYTHON_VERSION == 3.5 ]]; then pip install https://download.pytorch.org/whl/cpu/torch-1.0.0-cp35-cp35m-linux_x86_64.whl; fi
-  - if [[ $TRAVIS_PYTHON_VERSION == 3.6 ]]; then pip install https://download.pytorch.org/whl/cpu/torch-1.0.0-cp36-cp36m-linux_x86_64.whl; fi
+  - if [[ $TRAVIS_PYTHON_VERSION == 2.7 ]]; then pip install https://download.pytorch.org/whl/cpu/torch-1.1.0-cp27-cp27mu-linux_x86_64.whl; fi
+  - if [[ $TRAVIS_PYTHON_VERSION == 3.5 ]]; then pip install https://download.pytorch.org/whl/cpu/torch-1.1.0-cp35-cp35m-linux_x86_64.whl; fi
+  - if [[ $TRAVIS_PYTHON_VERSION == 3.6 ]]; then pip install https://download.pytorch.org/whl/cpu/torch-1.1.0-cp36-cp36m-linux_x86_64.whl; fi
   - pip install pycodestyle
   - pip install flake8
   - pip install codecov
README.md

@@ -28,7 +28,7 @@ Note that only `value` comes with autograd support, as `index` is discrete and t
 ## Installation
 
-Ensure that at least PyTorch 1.0.0 is installed and verify that `cuda/bin` and `cuda/include` are in your `$PATH` and `$CPATH` respectively, *e.g.*:
+Ensure that at least PyTorch 1.1.0 is installed and verify that `cuda/bin` and `cuda/include` are in your `$PATH` and `$CPATH` respectively, *e.g.*:
 
 ```
 $ python -c "import torch; print(torch.__version__)"
cpu/spspmm.cpp

@@ -31,7 +31,7 @@ at::Tensor spspmm_bw(at::Tensor index, at::Tensor indexA, at::Tensor valueA,
   int64_t *rowB_data = rowB.data<int64_t>();
   int64_t *colB_data = colB.data<int64_t>();
 
-  AT_DISPATCH_FLOATING_TYPES(valueA.type(), "spspmm_bw", [&] {
+  AT_DISPATCH_FLOATING_TYPES(valueA.scalar_type(), "spspmm_bw", [&] {
     scalar_t *value_data = value.data<scalar_t>();
     scalar_t *valueA_data = valueA.data<scalar_t>();
     scalar_t *valueB_data = valueB.data<scalar_t>();
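The cpu/spspmm.cpp hunk above is the core dispatch change of this update: AT_DISPATCH_FLOATING_TYPES now receives `valueA.scalar_type()` rather than the deprecated `valueA.type()`. A minimal sketch of the same pattern, assuming a PyTorch 1.1-era ATen build (`scale_by_two` is an illustrative helper, not part of this commit):

```cpp
#include <ATen/ATen.h>

// Illustrative helper (not from this repository): double every element of a
// contiguous CPU float/double tensor, dispatching on scalar_type() exactly as
// the commit now does for spspmm_bw.
at::Tensor scale_by_two(at::Tensor value) {
  value = value.contiguous();
  auto out = at::empty_like(value);
  AT_DISPATCH_FLOATING_TYPES(value.scalar_type(), "scale_by_two", [&] {
    const scalar_t *src = value.data<scalar_t>();  // data<T>() was still current in 1.1
    scalar_t *dst = out.data<scalar_t>();
    for (int64_t i = 0; i < value.numel(); i++) {
      dst[i] = scalar_t(2) * src[i];
    }
  });
  return out;
}
```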
cuda/spspmm_kernel.cu

@@ -7,8 +7,10 @@
 #define CSRGEMM(TYPE, ...) \
   [&] { \
-    const at::Type &the_type = TYPE; \
-    switch (the_type.scalarType()) { \
+    const auto &the_type = TYPE; \
+    (void)the_type; \
+    at::ScalarType _st = ::detail::scalar_type(TYPE); \
+    switch (_st) { \
     case at::ScalarType::Float: { \
       using scalar_t = float; \
       return cusparseScsrgemm(__VA_ARGS__); \

@@ -18,7 +20,7 @@
       return cusparseDcsrgemm(__VA_ARGS__); \
     } \
     default: \
-      AT_ERROR("Not implemented for '%s'", the_type.toString()); \
+      AT_ERROR("Not implemented for '", toString(_st), "'"); \
     } \
   }()
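The rewritten CSRGEMM macro above mirrors the internals of PyTorch 1.1's AT_DISPATCH_* macros: `::detail::scalar_type(TYPE)` accepts either the old `at::Type` handle or a plain `at::ScalarType`, so existing call sites keep working while the deprecated path is phased out. A minimal sketch of a macro built the same way, again assuming a 1.1-era ATen build (`DISPATCH_FLOAT_DOUBLE` and `bytes_per_element` are illustrative names, not part of this commit):

```cpp
#include <cstdio>
#include <ATen/ATen.h>

// Illustrative float/double dispatch macro in the style the commit adopts for
// CSRGEMM. The lambda passed as __VA_ARGS__ is expanded textually inside each
// case, so its body can refer to the case-local `scalar_t` alias.
#define DISPATCH_FLOAT_DOUBLE(TYPE, ...)                      \
  [&] {                                                       \
    const auto &the_type = TYPE;                              \
    (void)the_type; /* silence unused-variable warnings */    \
    at::ScalarType _st = ::detail::scalar_type(TYPE);         \
    switch (_st) {                                            \
    case at::ScalarType::Float: {                             \
      using scalar_t = float;                                 \
      return __VA_ARGS__();                                   \
    }                                                         \
    case at::ScalarType::Double: {                            \
      using scalar_t = double;                                \
      return __VA_ARGS__();                                   \
    }                                                         \
    default:                                                  \
      AT_ERROR("Not implemented for '", toString(_st), "'");  \
    }                                                         \
  }()

// Usage sketch: dispatch on the tensor's ScalarType, not its (deprecated) Type.
size_t bytes_per_element(at::Tensor t) {
  return DISPATCH_FLOAT_DOUBLE(t.scalar_type(),
                               [&] { return sizeof(scalar_t); });
}
```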
@@ -48,7 +50,7 @@ spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB,
   indexB = indexB.toType(at::kInt);
 
   // Convert A to CSR format.
-  auto row_ptrA = at::empty(m + 1, indexA.type());
+  auto row_ptrA = at::empty(m + 1, indexA.options());
   cusparseXcoo2csr(cusparse_handle, indexA[0].data<int>(), nnzA, k,
                    row_ptrA.data<int>(), CUSPARSE_INDEX_BASE_ZERO);
   auto colA = indexA[1];

@@ -56,7 +58,7 @@ spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB,
              cudaMemcpyHostToDevice);
 
   // Convert B to CSR format.
-  auto row_ptrB = at::empty(k + 1, indexB.type());
+  auto row_ptrB = at::empty(k + 1, indexB.options());
   cusparseXcoo2csr(cusparse_handle, indexB[0].data<int>(), nnzB, k,
                    row_ptrB.data<int>(), CUSPARSE_INDEX_BASE_ZERO);
   auto colB = indexB[1];

@@ -69,23 +71,23 @@ spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB,
   cusparseSetMatIndexBase(descr, CUSPARSE_INDEX_BASE_ZERO);
 
   int nnzC;
-  auto row_ptrC = at::empty(m + 1, indexB.type());
+  auto row_ptrC = at::empty(m + 1, indexB.options());
   cusparseXcsrgemmNnz(cusparse_handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
                       CUSPARSE_OPERATION_NON_TRANSPOSE, m, n, k, descr, nnzA,
                       row_ptrA.data<int>(), colA.data<int>(), descr, nnzB,
                       row_ptrB.data<int>(), colB.data<int>(), descr,
                       row_ptrC.data<int>(), &nnzC);
 
-  auto colC = at::empty(nnzC, indexA.type());
-  auto valueC = at::empty(nnzC, valueA.type());
+  auto colC = at::empty(nnzC, indexA.options());
+  auto valueC = at::empty(nnzC, valueA.options());
 
-  CSRGEMM(valueC.type(), cusparse_handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
-          CUSPARSE_OPERATION_NON_TRANSPOSE, m, n, k, descr, nnzA,
-          valueA.data<scalar_t>(), row_ptrA.data<int>(), colA.data<int>(), descr,
-          nnzB, valueB.data<scalar_t>(), row_ptrB.data<int>(), colB.data<int>(),
-          descr, valueC.data<scalar_t>(), row_ptrC.data<int>(), colC.data<int>());
+  CSRGEMM(valueC.scalar_type(), cusparse_handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
+          CUSPARSE_OPERATION_NON_TRANSPOSE, m, n, k, descr, nnzA,
+          valueA.data<scalar_t>(), row_ptrA.data<int>(), colA.data<int>(), descr,
+          nnzB, valueB.data<scalar_t>(), row_ptrB.data<int>(), colB.data<int>(),
+          descr, valueC.data<scalar_t>(), row_ptrC.data<int>(), colC.data<int>());
 
-  auto rowC = at::empty(nnzC, indexA.type());
+  auto rowC = at::empty(nnzC, indexA.options());
   cusparseXcsr2coo(cusparse_handle, row_ptrC.data<int>(), nnzC, m,
                    rowC.data<int>(), CUSPARSE_INDEX_BASE_ZERO);
@@ -150,7 +152,7 @@ at::Tensor spspmm_bw_cuda(at::Tensor index, at::Tensor indexA,
   at::Tensor rowB, colB;
   std::tie(rowB, colB) = to_csr(indexB[0], indexB[1], rowB_max);
 
-  AT_DISPATCH_FLOATING_TYPES(valueA.type(), "spspmm_bw", [&] {
+  AT_DISPATCH_FLOATING_TYPES(valueA.scalar_type(), "spspmm_bw", [&] {
     spspmm_bw_kernel<scalar_t><<<BLOCKS(value.numel()), THREADS>>>(
         index.data<int64_t>(), value.data<scalar_t>(), rowA.data<int64_t>(),
         colA.data<int64_t>(), valueA.data<scalar_t>(), rowB.data<int64_t>(),
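Apart from the macro, the remaining hunks in this file all make the same substitution: `at::empty(..., tensor.type())` becomes `at::empty(..., tensor.options())`, since PyTorch 1.1 factory functions take a TensorOptions bundle (dtype, device, layout) instead of the deprecated `at::Type`. A minimal sketch, assuming a 1.1-era ATen build (`make_row_ptr` is an illustrative helper, not part of this commit):

```cpp
#include <ATen/ATen.h>

// Illustrative helper (not from this repository): allocate a CSR row-pointer
// buffer with the same dtype and device as the given index tensor.
at::Tensor make_row_ptr(const at::Tensor &index, int64_t num_rows) {
  // Pre-1.1 (deprecated): at::empty(num_rows + 1, index.type());
  return at::empty(num_rows + 1, index.options());
}
```

The same `options()` handle works for the other factory calls in the diff (`at::zeros`, `at::empty_like`, and so on), which is why each allocation only needs a one-word change.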
cuda/unique_kernel.cu

@@ -20,8 +20,8 @@ std::tuple<at::Tensor, at::Tensor> unique_cuda(at::Tensor src) {
   at::Tensor perm;
   std::tie(src, perm) = src.sort();
 
-  auto mask = at::zeros(src.numel(), src.type().toScalarType(at::kByte));
-  AT_DISPATCH_ALL_TYPES(src.type(), "grid_cuda_kernel", [&] {
+  auto mask = at::zeros(src.numel(), src.options().dtype(at::kByte));
+  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "grid_cuda_kernel", [&] {
     unique_cuda_kernel<scalar_t><<<BLOCKS(src.numel()), THREADS>>>(
         src.data<scalar_t>(), mask.data<uint8_t>(), src.numel());
   });
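The unique_cuda hunk combines both migrations in one place: the byte mask keeps `src`'s device but overrides the dtype via `options().dtype(at::kByte)` (replacing `type().toScalarType(at::kByte)`), and AT_DISPATCH_ALL_TYPES now receives `src.scalar_type()`. A minimal sketch of the allocation half, assuming a 1.1-era ATen build (`make_byte_mask` is an illustrative helper, not part of this commit):

```cpp
#include <ATen/ATen.h>

// Illustrative helper (not from this repository): a zero-initialized uint8
// mask with one entry per element of `src`, placed on the same device as `src`.
at::Tensor make_byte_mask(const at::Tensor &src) {
  // Pre-1.1 (deprecated): at::zeros(src.numel(), src.type().toScalarType(at::kByte));
  return at::zeros(src.numel(), src.options().dtype(at::kByte));
}
```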
setup.py

@@ -21,7 +21,7 @@ if CUDA_HOME is not None:
             ['cuda/unique.cpp', 'cuda/unique_kernel.cu']),
     ]
 
-__version__ = '0.3.0'
+__version__ = '0.4.0'
 url = 'https://github.com/rusty1s/pytorch_sparse'
 
 install_requires = ['scipy']
torch_sparse/__init__.py

@@ -5,7 +5,7 @@ from .eye import eye
 from .spmm import spmm
 from .spspmm import spspmm
 
-__version__ = '0.3.0'
+__version__ = '0.4.0'
 
 __all__ = [
     '__version__',