Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
cb53126c
Unverified
Commit
cb53126c
authored
Dec 22, 2022
by
Matthias Fey
Committed by
GitHub
Dec 22, 2022
Browse files
Drop `cusparse` (#302)
* update * update
parent
955b1cf3
Changes
14
Show whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
35 additions
and
370 deletions
+35
-370
.github/workflows/building-conda.yml
.github/workflows/building-conda.yml
+1
-3
CMakeLists.txt
CMakeLists.txt
+1
-6
conda/pytorch-sparse/meta.yaml
conda/pytorch-sparse/meta.yaml
+1
-1
csrc/cpu/spspmm_cpu.cpp
csrc/cpu/spspmm_cpu.cpp
+0
-106
csrc/cpu/spspmm_cpu.h
csrc/cpu/spspmm_cpu.h
+0
-10
csrc/cuda/spspmm_cuda.cu
csrc/cuda/spspmm_cuda.cu
+0
-147
csrc/cuda/spspmm_cuda.h
csrc/cuda/spspmm_cuda.h
+0
-10
csrc/sparse.h
csrc/sparse.h
+0
-7
csrc/spspmm.cpp
csrc/spspmm.cpp
+0
-41
setup.py
setup.py
+4
-15
test/test_matmul.py
test/test_matmul.py
+4
-2
test/test_spspmm.py
test/test_spspmm.py
+2
-2
torch_sparse/__init__.py
torch_sparse/__init__.py
+3
-4
torch_sparse/matmul.py
torch_sparse/matmul.py
+19
-16
No files found.
.github/workflows/building-conda.yml
View file @
cb53126c
...
@@ -13,7 +13,7 @@ jobs:
...
@@ -13,7 +13,7 @@ jobs:
# We have trouble building for Windows - drop for now.
# We have trouble building for Windows - drop for now.
os
:
[
ubuntu-18.04
,
macos-10.15
]
# windows-2019
os
:
[
ubuntu-18.04
,
macos-10.15
]
# windows-2019
python-version
:
[
'
3.7'
,
'
3.8'
,
'
3.9'
,
'
3.10'
]
python-version
:
[
'
3.7'
,
'
3.8'
,
'
3.9'
,
'
3.10'
]
torch-version
:
[
1.13.0
]
#
[1.12.0, 1.13.0]
torch-version
:
[
1.12.0
,
1.13.0
]
cuda-version
:
[
'
cpu'
,
'
cu102'
,
'
cu113'
,
'
cu116'
,
'
cu117'
]
cuda-version
:
[
'
cpu'
,
'
cu102'
,
'
cu113'
,
'
cu116'
,
'
cu117'
]
exclude
:
exclude
:
-
torch-version
:
1.12.0
-
torch-version
:
1.12.0
...
@@ -32,8 +32,6 @@ jobs:
...
@@ -32,8 +32,6 @@ jobs:
cuda-version
:
'
cu117'
cuda-version
:
'
cu117'
-
os
:
windows-2019
-
os
:
windows-2019
cuda-version
:
'
cu102'
cuda-version
:
'
cu102'
-
os
:
windows-2019
# Complains about CUDA mismatch.
python-version
:
'
3.7'
steps
:
steps
:
-
uses
:
actions/checkout@v2
-
uses
:
actions/checkout@v2
...
...
CMakeLists.txt
View file @
cb53126c
cmake_minimum_required
(
VERSION 3.10
)
cmake_minimum_required
(
VERSION 3.10
)
project
(
torchsparse
)
project
(
torchsparse
)
set
(
CMAKE_CXX_STANDARD 14
)
set
(
CMAKE_CXX_STANDARD 14
)
set
(
TORCHSPARSE_VERSION 0.6.1
5
)
set
(
TORCHSPARSE_VERSION 0.6.1
6
)
set
(
CMAKE_MODULE_PATH
${
CMAKE_MODULE_PATH
}
${
CMAKE_CURRENT_SOURCE_DIR
}
/cmake
)
set
(
CMAKE_MODULE_PATH
${
CMAKE_MODULE_PATH
}
${
CMAKE_CURRENT_SOURCE_DIR
}
/cmake
)
option
(
WITH_CUDA
"Enable CUDA support"
OFF
)
option
(
WITH_CUDA
"Enable CUDA support"
OFF
)
...
@@ -34,9 +34,6 @@ endif()
...
@@ -34,9 +34,6 @@ endif()
add_library
(
${
PROJECT_NAME
}
SHARED
${
OPERATOR_SOURCES
}
)
add_library
(
${
PROJECT_NAME
}
SHARED
${
OPERATOR_SOURCES
}
)
target_link_libraries
(
${
PROJECT_NAME
}
PRIVATE
${
TORCH_LIBRARIES
}
)
target_link_libraries
(
${
PROJECT_NAME
}
PRIVATE
${
TORCH_LIBRARIES
}
)
if
(
WITH_CUDA
)
target_link_libraries
(
${
PROJECT_NAME
}
PRIVATE
${
CUDA_cusparse_LIBRARY
}
)
endif
()
if
(
WITH_PYTHON
)
if
(
WITH_PYTHON
)
target_link_libraries
(
${
PROJECT_NAME
}
PRIVATE Python3::Python
)
target_link_libraries
(
${
PROJECT_NAME
}
PRIVATE Python3::Python
)
endif
()
endif
()
...
@@ -95,7 +92,6 @@ install(FILES
...
@@ -95,7 +92,6 @@ install(FILES
csrc/cpu/saint_cpu.h
csrc/cpu/saint_cpu.h
csrc/cpu/sample_cpu.h
csrc/cpu/sample_cpu.h
csrc/cpu/spmm_cpu.h
csrc/cpu/spmm_cpu.h
csrc/cpu/spspmm_cpu.h
DESTINATION
${
CMAKE_INSTALL_INCLUDEDIR
}
/
${
PROJECT_NAME
}
/cpu
)
DESTINATION
${
CMAKE_INSTALL_INCLUDEDIR
}
/
${
PROJECT_NAME
}
/cpu
)
if
(
WITH_CUDA
)
if
(
WITH_CUDA
)
install
(
FILES
install
(
FILES
...
@@ -103,7 +99,6 @@ if(WITH_CUDA)
...
@@ -103,7 +99,6 @@ if(WITH_CUDA)
csrc/cuda/diag_cuda.h
csrc/cuda/diag_cuda.h
csrc/cuda/rw_cuda.h
csrc/cuda/rw_cuda.h
csrc/cuda/spmm_cuda.h
csrc/cuda/spmm_cuda.h
csrc/cuda/spspmm_cuda.h
DESTINATION
${
CMAKE_INSTALL_INCLUDEDIR
}
/
${
PROJECT_NAME
}
/cuda
)
DESTINATION
${
CMAKE_INSTALL_INCLUDEDIR
}
/
${
PROJECT_NAME
}
/cuda
)
endif
()
endif
()
...
...
conda/pytorch-sparse/meta.yaml
View file @
cb53126c
package
:
package
:
name
:
pytorch-sparse
name
:
pytorch-sparse
version
:
0.6.1
5
version
:
0.6.1
6
source
:
source
:
path
:
../..
path
:
../..
...
...
csrc/cpu/spspmm_cpu.cpp
deleted
100644 → 0
View file @
955b1cf3
#include "spspmm_cpu.h"
#include "utils.h"
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
optional
<
torch
::
Tensor
>>
spspmm_cpu
(
torch
::
Tensor
rowptrA
,
torch
::
Tensor
colA
,
torch
::
optional
<
torch
::
Tensor
>
optional_valueA
,
torch
::
Tensor
rowptrB
,
torch
::
Tensor
colB
,
torch
::
optional
<
torch
::
Tensor
>
optional_valueB
,
int64_t
K
,
std
::
string
reduce
)
{
CHECK_CPU
(
rowptrA
);
CHECK_CPU
(
colA
);
if
(
optional_valueA
.
has_value
())
CHECK_CPU
(
optional_valueA
.
value
());
CHECK_CPU
(
rowptrB
);
CHECK_CPU
(
colB
);
if
(
optional_valueB
.
has_value
())
CHECK_CPU
(
optional_valueB
.
value
());
CHECK_INPUT
(
rowptrA
.
dim
()
==
1
);
CHECK_INPUT
(
colA
.
dim
()
==
1
);
if
(
optional_valueA
.
has_value
())
{
CHECK_INPUT
(
optional_valueA
.
value
().
dim
()
==
1
);
CHECK_INPUT
(
optional_valueA
.
value
().
size
(
0
)
==
colA
.
size
(
0
));
}
CHECK_INPUT
(
rowptrB
.
dim
()
==
1
);
CHECK_INPUT
(
colB
.
dim
()
==
1
);
if
(
optional_valueB
.
has_value
())
{
CHECK_INPUT
(
optional_valueB
.
value
().
dim
()
==
1
);
CHECK_INPUT
(
optional_valueB
.
value
().
size
(
0
)
==
colB
.
size
(
0
));
}
if
(
!
optional_valueA
.
has_value
()
&&
optional_valueB
.
has_value
())
optional_valueA
=
torch
::
ones
({
colA
.
numel
()},
optional_valueB
.
value
().
options
());
if
(
!
optional_valueB
.
has_value
()
&&
optional_valueA
.
has_value
())
optional_valueB
=
torch
::
ones
({
colB
.
numel
()},
optional_valueA
.
value
().
options
());
auto
scalar_type
=
torch
::
ScalarType
::
Float
;
if
(
optional_valueA
.
has_value
())
scalar_type
=
optional_valueA
.
value
().
scalar_type
();
auto
rowptrA_data
=
rowptrA
.
data_ptr
<
int64_t
>
();
auto
colA_data
=
colA
.
data_ptr
<
int64_t
>
();
auto
rowptrB_data
=
rowptrB
.
data_ptr
<
int64_t
>
();
auto
colB_data
=
colB
.
data_ptr
<
int64_t
>
();
auto
rowptrC
=
torch
::
empty_like
(
rowptrA
);
auto
rowptrC_data
=
rowptrC
.
data_ptr
<
int64_t
>
();
rowptrC_data
[
0
]
=
0
;
torch
::
Tensor
colC
;
torch
::
optional
<
torch
::
Tensor
>
optional_valueC
=
torch
::
nullopt
;
AT_DISPATCH_ALL_TYPES_AND
(
at
::
ScalarType
::
BFloat16
,
scalar_type
,
"spspmm"
,
[
&
]
{
AT_DISPATCH_HAS_VALUE
(
optional_valueA
,
[
&
]
{
scalar_t
*
valA_data
=
nullptr
,
*
valB_data
=
nullptr
;
if
(
HAS_VALUE
)
{
valA_data
=
optional_valueA
.
value
().
data_ptr
<
scalar_t
>
();
valB_data
=
optional_valueB
.
value
().
data_ptr
<
scalar_t
>
();
}
int64_t
nnz
=
0
,
cA
,
cB
;
std
::
vector
<
scalar_t
>
tmp_vals
(
K
,
0
);
std
::
vector
<
int64_t
>
cols
;
std
::
vector
<
scalar_t
>
vals
;
for
(
auto
rA
=
0
;
rA
<
rowptrA
.
numel
()
-
1
;
rA
++
)
{
for
(
auto
eA
=
rowptrA_data
[
rA
];
eA
<
rowptrA_data
[
rA
+
1
];
eA
++
)
{
cA
=
colA_data
[
eA
];
for
(
auto
eB
=
rowptrB_data
[
cA
];
eB
<
rowptrB_data
[
cA
+
1
];
eB
++
)
{
cB
=
colB_data
[
eB
];
if
(
HAS_VALUE
)
tmp_vals
[
cB
]
+=
valA_data
[
eA
]
*
valB_data
[
eB
];
else
tmp_vals
[
cB
]
+=
1
;
}
}
for
(
auto
k
=
0
;
k
<
K
;
k
++
)
{
if
(
tmp_vals
[
k
]
!=
0
)
{
cols
.
push_back
(
k
);
if
(
HAS_VALUE
)
vals
.
push_back
(
tmp_vals
[
k
]);
nnz
++
;
}
tmp_vals
[
k
]
=
(
scalar_t
)
0
;
}
rowptrC_data
[
rA
+
1
]
=
nnz
;
}
colC
=
torch
::
from_blob
(
cols
.
data
(),
{
nnz
},
colA
.
options
()).
clone
();
if
(
HAS_VALUE
)
{
optional_valueC
=
torch
::
from_blob
(
vals
.
data
(),
{
nnz
},
optional_valueA
.
value
().
options
());
optional_valueC
=
optional_valueC
.
value
().
clone
();
}
});
});
return
std
::
make_tuple
(
rowptrC
,
colC
,
optional_valueC
);
}
csrc/cpu/spspmm_cpu.h
deleted
100644 → 0
View file @
955b1cf3
#pragma once
#include "../extensions.h"
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
optional
<
torch
::
Tensor
>>
spspmm_cpu
(
torch
::
Tensor
rowptrA
,
torch
::
Tensor
colA
,
torch
::
optional
<
torch
::
Tensor
>
optional_valueA
,
torch
::
Tensor
rowptrB
,
torch
::
Tensor
colB
,
torch
::
optional
<
torch
::
Tensor
>
optional_valueB
,
int64_t
K
,
std
::
string
reduce
);
csrc/cuda/spspmm_cuda.cu
deleted
100644 → 0
View file @
955b1cf3
#include "spspmm_cuda.h"
#include <ATen/cuda/CUDAContext.h>
#include <cusparse.h>
#include "utils.cuh"
#define AT_DISPATCH_CUSPARSE_TYPES(TYPE, ...) \
[&] { \
switch (TYPE) { \
case torch::ScalarType::Float: { \
using scalar_t = float; \
const auto &cusparsecsrgemm2_bufferSizeExt = \
cusparseScsrgemm2_bufferSizeExt; \
const auto &cusparsecsrgemm2 = cusparseScsrgemm2; \
return __VA_ARGS__(); \
} \
case torch::ScalarType::Double: { \
using scalar_t = double; \
const auto &cusparsecsrgemm2_bufferSizeExt = \
cusparseDcsrgemm2_bufferSizeExt; \
const auto &cusparsecsrgemm2 = cusparseDcsrgemm2; \
return __VA_ARGS__(); \
} \
default: \
AT_ERROR("Not implemented for '", toString(TYPE), "'"); \
} \
}()
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
optional
<
torch
::
Tensor
>>
spspmm_cuda
(
torch
::
Tensor
rowptrA
,
torch
::
Tensor
colA
,
torch
::
optional
<
torch
::
Tensor
>
optional_valueA
,
torch
::
Tensor
rowptrB
,
torch
::
Tensor
colB
,
torch
::
optional
<
torch
::
Tensor
>
optional_valueB
,
int64_t
K
,
std
::
string
reduce
)
{
CHECK_CUDA
(
rowptrA
);
CHECK_CUDA
(
colA
);
if
(
optional_valueA
.
has_value
())
CHECK_CUDA
(
optional_valueA
.
value
());
CHECK_CUDA
(
rowptrB
);
CHECK_CUDA
(
colB
);
if
(
optional_valueB
.
has_value
())
CHECK_CUDA
(
optional_valueB
.
value
());
cudaSetDevice
(
rowptrA
.
get_device
());
CHECK_INPUT
(
rowptrA
.
dim
()
==
1
);
CHECK_INPUT
(
colA
.
dim
()
==
1
);
if
(
optional_valueA
.
has_value
())
{
CHECK_INPUT
(
optional_valueA
.
value
().
dim
()
==
1
);
CHECK_INPUT
(
optional_valueA
.
value
().
size
(
0
)
==
colA
.
size
(
0
));
}
CHECK_INPUT
(
rowptrB
.
dim
()
==
1
);
CHECK_INPUT
(
colB
.
dim
()
==
1
);
if
(
optional_valueB
.
has_value
())
{
CHECK_INPUT
(
optional_valueB
.
value
().
dim
()
==
1
);
CHECK_INPUT
(
optional_valueB
.
value
().
size
(
0
)
==
colB
.
size
(
0
));
}
if
(
!
optional_valueA
.
has_value
()
&&
optional_valueB
.
has_value
())
optional_valueA
=
torch
::
ones
({
colA
.
numel
()},
optional_valueB
.
value
().
options
());
if
(
!
optional_valueB
.
has_value
()
&&
optional_valueA
.
has_value
())
optional_valueB
=
torch
::
ones
({
colB
.
numel
()},
optional_valueA
.
value
().
options
());
auto
scalar_type
=
torch
::
ScalarType
::
Float
;
if
(
optional_valueA
.
has_value
())
scalar_type
=
optional_valueA
.
value
().
scalar_type
();
auto
handle
=
at
::
cuda
::
getCurrentCUDASparseHandle
();
cusparseMatDescr_t
descr
;
cusparseCreateMatDescr
(
&
descr
);
rowptrA
=
rowptrA
.
toType
(
torch
::
kInt
);
colA
=
colA
.
toType
(
torch
::
kInt
);
rowptrB
=
rowptrB
.
toType
(
torch
::
kInt
);
colB
=
colB
.
toType
(
torch
::
kInt
);
int64_t
M
=
rowptrA
.
numel
()
-
1
,
N
=
rowptrB
.
numel
()
-
1
;
auto
rowptrA_data
=
rowptrA
.
data_ptr
<
int
>
();
auto
colA_data
=
colA
.
data_ptr
<
int
>
();
auto
rowptrB_data
=
rowptrB
.
data_ptr
<
int
>
();
auto
colB_data
=
colB
.
data_ptr
<
int
>
();
torch
::
Tensor
rowptrC
,
colC
;
torch
::
optional
<
torch
::
Tensor
>
optional_valueC
=
torch
::
nullopt
;
int
nnzC
;
int
*
nnzTotalDevHostPtr
=
&
nnzC
;
// Step 1: Create an opaque structure.
csrgemm2Info_t
info
=
NULL
;
cusparseCreateCsrgemm2Info
(
&
info
);
// Step 2: Allocate buffer for `csrgemm2Nnz` and `csrgemm2`.
size_t
bufferSize
;
AT_DISPATCH_CUSPARSE_TYPES
(
scalar_type
,
[
&
]
{
scalar_t
alpha
=
(
scalar_t
)
1.0
;
cusparsecsrgemm2_bufferSizeExt
(
handle
,
M
,
N
,
K
,
&
alpha
,
descr
,
colA
.
numel
(),
rowptrA_data
,
colA_data
,
descr
,
colB
.
numel
(),
rowptrB_data
,
colB_data
,
NULL
,
descr
,
0
,
NULL
,
NULL
,
info
,
&
bufferSize
);
void
*
buffer
=
NULL
;
cudaMalloc
(
&
buffer
,
bufferSize
);
// Step 3: Compute CSR row pointer.
rowptrC
=
torch
::
empty
({
M
+
1
},
rowptrA
.
options
());
auto
rowptrC_data
=
rowptrC
.
data_ptr
<
int
>
();
cusparseXcsrgemm2Nnz
(
handle
,
M
,
N
,
K
,
descr
,
colA
.
numel
(),
rowptrA_data
,
colA_data
,
descr
,
colB
.
numel
(),
rowptrB_data
,
colB_data
,
descr
,
0
,
NULL
,
NULL
,
descr
,
rowptrC_data
,
nnzTotalDevHostPtr
,
info
,
buffer
);
// Step 4: Compute CSR entries.
colC
=
torch
::
empty
({
nnzC
},
rowptrC
.
options
());
auto
colC_data
=
colC
.
data_ptr
<
int
>
();
if
(
optional_valueA
.
has_value
())
optional_valueC
=
torch
::
empty
({
nnzC
},
optional_valueA
.
value
().
options
());
scalar_t
*
valA_data
=
NULL
,
*
valB_data
=
NULL
,
*
valC_data
=
NULL
;
if
(
optional_valueA
.
has_value
())
{
valA_data
=
optional_valueA
.
value
().
data_ptr
<
scalar_t
>
();
valB_data
=
optional_valueB
.
value
().
data_ptr
<
scalar_t
>
();
valC_data
=
optional_valueC
.
value
().
data_ptr
<
scalar_t
>
();
}
cusparsecsrgemm2
(
handle
,
M
,
N
,
K
,
&
alpha
,
descr
,
colA
.
numel
(),
valA_data
,
rowptrA_data
,
colA_data
,
descr
,
colB
.
numel
(),
valB_data
,
rowptrB_data
,
colB_data
,
NULL
,
descr
,
0
,
NULL
,
NULL
,
NULL
,
descr
,
valC_data
,
rowptrC_data
,
colC_data
,
info
,
buffer
);
cudaFree
(
buffer
);
});
// Step 5: Destroy the opaque structure.
cusparseDestroyCsrgemm2Info
(
info
);
rowptrC
=
rowptrC
.
toType
(
torch
::
kLong
);
colC
=
colC
.
toType
(
torch
::
kLong
);
return
std
::
make_tuple
(
rowptrC
,
colC
,
optional_valueC
);
}
csrc/cuda/spspmm_cuda.h
deleted
100644 → 0
View file @
955b1cf3
#pragma once
#include "../extensions.h"
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
optional
<
torch
::
Tensor
>>
spspmm_cuda
(
torch
::
Tensor
rowptrA
,
torch
::
Tensor
colA
,
torch
::
optional
<
torch
::
Tensor
>
optional_valueA
,
torch
::
Tensor
rowptrB
,
torch
::
Tensor
colB
,
torch
::
optional
<
torch
::
Tensor
>
optional_valueB
,
int64_t
K
,
std
::
string
reduce
);
csrc/sparse.h
View file @
cb53126c
...
@@ -74,10 +74,3 @@ spmm_min(torch::Tensor rowptr, torch::Tensor col,
...
@@ -74,10 +74,3 @@ spmm_min(torch::Tensor rowptr, torch::Tensor col,
SPARSE_API
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
SPARSE_API
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
spmm_max
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
spmm_max
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
optional
<
torch
::
Tensor
>
opt_value
,
torch
::
Tensor
mat
);
torch
::
optional
<
torch
::
Tensor
>
opt_value
,
torch
::
Tensor
mat
);
SPARSE_API
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
optional
<
torch
::
Tensor
>>
spspmm_sum
(
torch
::
Tensor
rowptrA
,
torch
::
Tensor
colA
,
torch
::
optional
<
torch
::
Tensor
>
optional_valueA
,
torch
::
Tensor
rowptrB
,
torch
::
Tensor
colB
,
torch
::
optional
<
torch
::
Tensor
>
optional_valueB
,
int64_t
K
);
csrc/spspmm.cpp
deleted
100644 → 0
View file @
955b1cf3
#ifdef WITH_PYTHON
#include <Python.h>
#endif
#include <torch/script.h>
#include "cpu/spspmm_cpu.h"
#ifdef WITH_CUDA
#include "cuda/spspmm_cuda.h"
#endif
#ifdef _WIN32
#ifdef WITH_PYTHON
#ifdef WITH_CUDA
PyMODINIT_FUNC
PyInit__spspmm_cuda
(
void
)
{
return
NULL
;
}
#else
PyMODINIT_FUNC
PyInit__spspmm_cpu
(
void
)
{
return
NULL
;
}
#endif
#endif
#endif
SPARSE_API
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
optional
<
torch
::
Tensor
>>
spspmm_sum
(
torch
::
Tensor
rowptrA
,
torch
::
Tensor
colA
,
torch
::
optional
<
torch
::
Tensor
>
optional_valueA
,
torch
::
Tensor
rowptrB
,
torch
::
Tensor
colB
,
torch
::
optional
<
torch
::
Tensor
>
optional_valueB
,
int64_t
K
)
{
if
(
rowptrA
.
device
().
is_cuda
())
{
#ifdef WITH_CUDA
return
spspmm_cuda
(
rowptrA
,
colA
,
optional_valueA
,
rowptrB
,
colB
,
optional_valueB
,
K
,
"sum"
);
#else
AT_ERROR
(
"Not compiled with CUDA support"
);
#endif
}
else
{
return
spspmm_cpu
(
rowptrA
,
colA
,
optional_valueA
,
rowptrB
,
colB
,
optional_valueB
,
K
,
"sum"
);
}
}
static
auto
registry
=
torch
::
RegisterOperators
().
op
(
"torch_sparse::spspmm_sum"
,
&
spspmm_sum
);
setup.py
View file @
cb53126c
...
@@ -15,7 +15,7 @@ from torch.utils.cpp_extension import (
...
@@ -15,7 +15,7 @@ from torch.utils.cpp_extension import (
CUDAExtension
,
CUDAExtension
,
)
)
__version__
=
'0.6.1
5
'
__version__
=
'0.6.1
6
'
URL
=
'https://github.com/rusty1s/pytorch_sparse'
URL
=
'https://github.com/rusty1s/pytorch_sparse'
WITH_CUDA
=
False
WITH_CUDA
=
False
...
@@ -64,7 +64,7 @@ def get_extensions():
...
@@ -64,7 +64,7 @@ def get_extensions():
define_macros
+=
[(
'MTMETIS_64BIT_PARTITIONS'
,
None
)]
define_macros
+=
[(
'MTMETIS_64BIT_PARTITIONS'
,
None
)]
libraries
+=
[
'mtmetis'
,
'wildriver'
]
libraries
+=
[
'mtmetis'
,
'wildriver'
]
extra_compile_args
=
{
'cxx'
:
[
'-O
2
'
]}
extra_compile_args
=
{
'cxx'
:
[
'-O
3
'
]}
if
not
os
.
name
==
'nt'
:
# Not on Windows:
if
not
os
.
name
==
'nt'
:
# Not on Windows:
extra_compile_args
[
'cxx'
]
+=
[
'-Wno-sign-compare'
]
extra_compile_args
[
'cxx'
]
+=
[
'-Wno-sign-compare'
]
extra_link_args
=
[]
if
WITH_SYMBOLS
else
[
'-s'
]
extra_link_args
=
[]
if
WITH_SYMBOLS
else
[
'-s'
]
...
@@ -89,8 +89,7 @@ def get_extensions():
...
@@ -89,8 +89,7 @@ def get_extensions():
define_macros
+=
[(
'WITH_CUDA'
,
None
)]
define_macros
+=
[(
'WITH_CUDA'
,
None
)]
nvcc_flags
=
os
.
getenv
(
'NVCC_FLAGS'
,
''
)
nvcc_flags
=
os
.
getenv
(
'NVCC_FLAGS'
,
''
)
nvcc_flags
=
[]
if
nvcc_flags
==
''
else
nvcc_flags
.
split
(
' '
)
nvcc_flags
=
[]
if
nvcc_flags
==
''
else
nvcc_flags
.
split
(
' '
)
nvcc_flags
+=
[
'-O2'
]
nvcc_flags
+=
[
'-O3'
]
extra_compile_args
[
'nvcc'
]
=
nvcc_flags
if
torch
.
version
.
hip
:
if
torch
.
version
.
hip
:
# USE_ROCM was added to later versions of PyTorch
# USE_ROCM was added to later versions of PyTorch
# Define here to support older PyTorch versions as well:
# Define here to support older PyTorch versions as well:
...
@@ -98,17 +97,7 @@ def get_extensions():
...
@@ -98,17 +97,7 @@ def get_extensions():
undef_macros
+=
[
'__HIP_NO_HALF_CONVERSIONS__'
]
undef_macros
+=
[
'__HIP_NO_HALF_CONVERSIONS__'
]
else
:
else
:
nvcc_flags
+=
[
'--expt-relaxed-constexpr'
]
nvcc_flags
+=
[
'--expt-relaxed-constexpr'
]
extra_compile_args
[
'nvcc'
]
=
nvcc_flags
if
torch
.
version
.
hip
:
if
sys
.
platform
==
'win32'
:
extra_link_args
+=
[
'hipsparse.lib'
]
else
:
extra_link_args
+=
[
'-lhipsparse'
,
'-l'
,
'hipsparse'
]
else
:
if
sys
.
platform
==
'win32'
:
extra_link_args
+=
[
'cusparse.lib'
]
else
:
extra_link_args
+=
[
'-lcusparse'
,
'-l'
,
'cusparse'
]
name
=
main
.
split
(
os
.
sep
)[
-
1
][:
-
4
]
name
=
main
.
split
(
os
.
sep
)[
-
1
][:
-
4
]
sources
=
[
main
]
sources
=
[
main
]
...
...
test/test_matmul.py
View file @
cb53126c
...
@@ -4,7 +4,7 @@ import pytest
...
@@ -4,7 +4,7 @@ import pytest
import
torch
import
torch
import
torch_scatter
import
torch_scatter
from
torch_sparse.matmul
import
matmul
from
torch_sparse.matmul
import
matmul
,
spspmm
from
torch_sparse.tensor
import
SparseTensor
from
torch_sparse.tensor
import
SparseTensor
from
torch_sparse.testing
import
devices
,
grad_dtypes
,
reductions
from
torch_sparse.testing
import
devices
,
grad_dtypes
,
reductions
...
@@ -53,7 +53,7 @@ def test_spmm(dtype, device, reduce):
...
@@ -53,7 +53,7 @@ def test_spmm(dtype, device, reduce):
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
grad_dtypes
,
devices
))
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
grad_dtypes
,
devices
))
def
test_spspmm
(
dtype
,
device
):
def
test_spspmm
(
dtype
,
device
):
if
d
evice
==
torch
.
device
(
'cuda:0'
)
and
dtype
==
torch
.
bfloat16
:
if
d
type
in
{
torch
.
half
,
torch
.
bfloat16
}
:
return
# Not yet implemented.
return
# Not yet implemented.
src
=
torch
.
tensor
([[
1
,
0
,
0
],
[
0
,
1
,
0
],
[
0
,
0
,
1
]],
dtype
=
dtype
,
src
=
torch
.
tensor
([[
1
,
0
,
0
],
[
0
,
1
,
0
],
[
0
,
0
,
1
]],
dtype
=
dtype
,
...
@@ -75,3 +75,5 @@ def test_spspmm(dtype, device):
...
@@ -75,3 +75,5 @@ def test_spspmm(dtype, device):
rowptr
,
col
,
value
=
out
.
csr
()
rowptr
,
col
,
value
=
out
.
csr
()
assert
rowptr
.
tolist
()
==
[
0
,
1
,
2
,
3
]
assert
rowptr
.
tolist
()
==
[
0
,
1
,
2
,
3
]
assert
col
.
tolist
()
==
[
0
,
1
,
2
]
assert
col
.
tolist
()
==
[
0
,
1
,
2
]
torch
.
jit
.
script
(
spspmm
)
test/test_spspmm.py
View file @
cb53126c
...
@@ -9,7 +9,7 @@ from torch_sparse.testing import devices, grad_dtypes, tensor
...
@@ -9,7 +9,7 @@ from torch_sparse.testing import devices, grad_dtypes, tensor
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
grad_dtypes
,
devices
))
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
grad_dtypes
,
devices
))
def
test_spspmm
(
dtype
,
device
):
def
test_spspmm
(
dtype
,
device
):
if
d
evice
==
torch
.
device
(
'cuda:0'
)
and
dtype
==
torch
.
bfloat16
:
if
d
type
in
{
torch
.
half
,
torch
.
bfloat16
}
:
return
# Not yet implemented.
return
# Not yet implemented.
indexA
=
torch
.
tensor
([[
0
,
0
,
1
,
2
,
2
],
[
1
,
2
,
0
,
0
,
1
]],
device
=
device
)
indexA
=
torch
.
tensor
([[
0
,
0
,
1
,
2
,
2
],
[
1
,
2
,
0
,
0
,
1
]],
device
=
device
)
...
@@ -24,7 +24,7 @@ def test_spspmm(dtype, device):
...
@@ -24,7 +24,7 @@ def test_spspmm(dtype, device):
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
grad_dtypes
,
devices
))
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
grad_dtypes
,
devices
))
def
test_sparse_tensor_spspmm
(
dtype
,
device
):
def
test_sparse_tensor_spspmm
(
dtype
,
device
):
if
d
evice
==
torch
.
device
(
'cuda:0'
)
and
dtype
==
torch
.
bfloat16
:
if
d
type
in
{
torch
.
half
,
torch
.
bfloat16
}
:
return
# Not yet implemented.
return
# Not yet implemented.
x
=
SparseTensor
(
x
=
SparseTensor
(
...
...
torch_sparse/__init__.py
View file @
cb53126c
...
@@ -3,12 +3,11 @@ import os.path as osp
...
@@ -3,12 +3,11 @@ import os.path as osp
import
torch
import
torch
__version__
=
'0.6.1
5
'
__version__
=
'0.6.1
6
'
for
library
in
[
for
library
in
[
'_version'
,
'_convert'
,
'_diag'
,
'_spmm'
,
'_spspmm'
,
'_metis'
,
'_rw'
,
'_version'
,
'_convert'
,
'_diag'
,
'_spmm'
,
'_metis'
,
'_rw'
,
'_saint'
,
'_saint'
,
'_sample'
,
'_ego_sample'
,
'_hgt_sample'
,
'_neighbor_sample'
,
'_sample'
,
'_ego_sample'
,
'_hgt_sample'
,
'_neighbor_sample'
,
'_relabel'
'_relabel'
]:
]:
cuda_spec
=
importlib
.
machinery
.
PathFinder
().
find_spec
(
cuda_spec
=
importlib
.
machinery
.
PathFinder
().
find_spec
(
f
'
{
library
}
_cuda'
,
[
osp
.
dirname
(
__file__
)])
f
'
{
library
}
_cuda'
,
[
osp
.
dirname
(
__file__
)])
...
...
torch_sparse/matmul.py
View file @
cb53126c
from
typing
import
Tuple
from
typing
import
Optional
,
Tuple
import
torch
import
torch
from
torch
import
Tensor
from
torch_sparse.tensor
import
SparseTensor
from
torch_sparse.tensor
import
SparseTensor
...
@@ -90,21 +91,23 @@ def spmm(src: SparseTensor, other: torch.Tensor,
...
@@ -90,21 +91,23 @@ def spmm(src: SparseTensor, other: torch.Tensor,
def
spspmm_sum
(
src
:
SparseTensor
,
other
:
SparseTensor
)
->
SparseTensor
:
def
spspmm_sum
(
src
:
SparseTensor
,
other
:
SparseTensor
)
->
SparseTensor
:
assert
src
.
sparse_size
(
1
)
==
other
.
sparse_size
(
0
)
A
=
src
.
to_torch_sparse_coo_tensor
()
rowptrA
,
colA
,
valueA
=
src
.
csr
()
B
=
other
.
to_torch_sparse_coo_tensor
()
rowptrB
,
colB
,
valueB
=
other
.
csr
()
C
=
torch
.
sparse
.
mm
(
A
,
B
)
value
=
valueA
if
valueA
is
not
None
else
valueB
edge_index
=
C
.
_indices
()
if
valueA
is
not
None
and
valueA
.
dtype
==
torch
.
half
:
row
,
col
=
edge_index
[
0
],
edge_index
[
1
]
valueA
=
valueA
.
to
(
torch
.
float
)
value
:
Optional
[
Tensor
]
=
None
if
valueB
is
not
None
and
valueB
.
dtype
==
torch
.
half
:
if
src
.
has_value
()
and
other
.
has_value
():
valueB
=
valueB
.
to
(
torch
.
float
)
value
=
C
.
_values
()
M
,
K
=
src
.
sparse_size
(
0
),
other
.
sparse_size
(
1
)
rowptrC
,
colC
,
valueC
=
torch
.
ops
.
torch_sparse
.
spspmm_sum
(
return
SparseTensor
(
rowptrA
,
colA
,
valueA
,
rowptrB
,
colB
,
valueB
,
K
)
row
=
row
,
if
valueC
is
not
None
and
value
is
not
None
:
col
=
col
,
valueC
=
valueC
.
to
(
value
.
dtype
)
value
=
value
,
return
SparseTensor
(
row
=
None
,
rowptr
=
rowptrC
,
col
=
colC
,
value
=
valueC
,
sparse_sizes
=
(
C
.
size
(
0
),
C
.
size
(
1
)),
sparse_sizes
=
(
M
,
K
),
is_sorted
=
True
)
is_sorted
=
True
,
trust_data
=
True
,
)
def
spspmm_add
(
src
:
SparseTensor
,
other
:
SparseTensor
)
->
SparseTensor
:
def
spspmm_add
(
src
:
SparseTensor
,
other
:
SparseTensor
)
->
SparseTensor
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment