Unverified Commit cb53126c authored by Matthias Fey's avatar Matthias Fey Committed by GitHub
Browse files

Drop `cusparse` (#302)

* update

* update
parent 955b1cf3
...@@ -13,7 +13,7 @@ jobs: ...@@ -13,7 +13,7 @@ jobs:
# We have trouble building for Windows - drop for now. # We have trouble building for Windows - drop for now.
os: [ubuntu-18.04, macos-10.15] # windows-2019 os: [ubuntu-18.04, macos-10.15] # windows-2019
python-version: ['3.7', '3.8', '3.9', '3.10'] python-version: ['3.7', '3.8', '3.9', '3.10']
torch-version: [1.13.0] # [1.12.0, 1.13.0] torch-version: [1.12.0, 1.13.0]
cuda-version: ['cpu', 'cu102', 'cu113', 'cu116', 'cu117'] cuda-version: ['cpu', 'cu102', 'cu113', 'cu116', 'cu117']
exclude: exclude:
- torch-version: 1.12.0 - torch-version: 1.12.0
...@@ -32,8 +32,6 @@ jobs: ...@@ -32,8 +32,6 @@ jobs:
cuda-version: 'cu117' cuda-version: 'cu117'
- os: windows-2019 - os: windows-2019
cuda-version: 'cu102' cuda-version: 'cu102'
- os: windows-2019 # Complains about CUDA mismatch.
python-version: '3.7'
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v2
......
cmake_minimum_required(VERSION 3.10) cmake_minimum_required(VERSION 3.10)
project(torchsparse) project(torchsparse)
set(CMAKE_CXX_STANDARD 14) set(CMAKE_CXX_STANDARD 14)
set(TORCHSPARSE_VERSION 0.6.15) set(TORCHSPARSE_VERSION 0.6.16)
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${CMAKE_CURRENT_SOURCE_DIR}/cmake) set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
option(WITH_CUDA "Enable CUDA support" OFF) option(WITH_CUDA "Enable CUDA support" OFF)
...@@ -34,9 +34,6 @@ endif() ...@@ -34,9 +34,6 @@ endif()
add_library(${PROJECT_NAME} SHARED ${OPERATOR_SOURCES}) add_library(${PROJECT_NAME} SHARED ${OPERATOR_SOURCES})
target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIBRARIES}) target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIBRARIES})
if (WITH_CUDA)
target_link_libraries(${PROJECT_NAME} PRIVATE ${CUDA_cusparse_LIBRARY})
endif()
if (WITH_PYTHON) if (WITH_PYTHON)
target_link_libraries(${PROJECT_NAME} PRIVATE Python3::Python) target_link_libraries(${PROJECT_NAME} PRIVATE Python3::Python)
endif() endif()
...@@ -95,7 +92,6 @@ install(FILES ...@@ -95,7 +92,6 @@ install(FILES
csrc/cpu/saint_cpu.h csrc/cpu/saint_cpu.h
csrc/cpu/sample_cpu.h csrc/cpu/sample_cpu.h
csrc/cpu/spmm_cpu.h csrc/cpu/spmm_cpu.h
csrc/cpu/spspmm_cpu.h
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/${PROJECT_NAME}/cpu) DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/${PROJECT_NAME}/cpu)
if(WITH_CUDA) if(WITH_CUDA)
install(FILES install(FILES
...@@ -103,7 +99,6 @@ if(WITH_CUDA) ...@@ -103,7 +99,6 @@ if(WITH_CUDA)
csrc/cuda/diag_cuda.h csrc/cuda/diag_cuda.h
csrc/cuda/rw_cuda.h csrc/cuda/rw_cuda.h
csrc/cuda/spmm_cuda.h csrc/cuda/spmm_cuda.h
csrc/cuda/spspmm_cuda.h
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/${PROJECT_NAME}/cuda) DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/${PROJECT_NAME}/cuda)
endif() endif()
......
package: package:
name: pytorch-sparse name: pytorch-sparse
version: 0.6.15 version: 0.6.16
source: source:
path: ../.. path: ../..
......
#include "spspmm_cpu.h"
#include "utils.h"

// Sparse-sparse matrix multiplication C = A @ B on CPU, with inputs and
// output in CSR form (rowptr, col, optional value).
//
// Args:
//   rowptrA, colA, optional_valueA: CSR representation of A (int64 indices;
//       rowptrA has rows(A) + 1 entries). Values are optional.
//   rowptrB, colB, optional_valueB: CSR representation of B.
//   K: number of columns of B (and therefore of C).
//   reduce: NOTE(review): accepted but ignored — this implementation always
//       sums duplicate contributions ("sum" semantics).
// Returns: (rowptrC, colC, optional_valueC). optional_valueC is absent iff
//   neither input carried values.
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
spspmm_cpu(torch::Tensor rowptrA, torch::Tensor colA,
           torch::optional<torch::Tensor> optional_valueA,
           torch::Tensor rowptrB, torch::Tensor colB,
           torch::optional<torch::Tensor> optional_valueB, int64_t K,
           std::string reduce) {
  // Device and shape validation: all tensors on CPU, 1-D, and value tensors
  // (when given) aligned one-to-one with their column-index tensors.
  CHECK_CPU(rowptrA);
  CHECK_CPU(colA);
  if (optional_valueA.has_value())
    CHECK_CPU(optional_valueA.value());
  CHECK_CPU(rowptrB);
  CHECK_CPU(colB);
  if (optional_valueB.has_value())
    CHECK_CPU(optional_valueB.value());

  CHECK_INPUT(rowptrA.dim() == 1);
  CHECK_INPUT(colA.dim() == 1);
  if (optional_valueA.has_value()) {
    CHECK_INPUT(optional_valueA.value().dim() == 1);
    CHECK_INPUT(optional_valueA.value().size(0) == colA.size(0));
  }
  CHECK_INPUT(rowptrB.dim() == 1);
  CHECK_INPUT(colB.dim() == 1);
  if (optional_valueB.has_value()) {
    CHECK_INPUT(optional_valueB.value().dim() == 1);
    CHECK_INPUT(optional_valueB.value().size(0) == colB.size(0));
  }

  // If exactly one operand carries values, materialize an all-ones value
  // tensor for the other so the multiply below is well-defined.
  if (!optional_valueA.has_value() && optional_valueB.has_value())
    optional_valueA =
        torch::ones({colA.numel()}, optional_valueB.value().options());
  if (!optional_valueB.has_value() && optional_valueA.has_value())
    optional_valueB =
        torch::ones({colB.numel()}, optional_valueA.value().options());

  // Dispatch dtype: the values' dtype when present; Float is only a
  // placeholder to pick a scalar_t in the value-less case.
  auto scalar_type = torch::ScalarType::Float;
  if (optional_valueA.has_value())
    scalar_type = optional_valueA.value().scalar_type();

  auto rowptrA_data = rowptrA.data_ptr<int64_t>();
  auto colA_data = colA.data_ptr<int64_t>();
  auto rowptrB_data = rowptrB.data_ptr<int64_t>();
  auto colB_data = colB.data_ptr<int64_t>();

  auto rowptrC = torch::empty_like(rowptrA);
  auto rowptrC_data = rowptrC.data_ptr<int64_t>();
  rowptrC_data[0] = 0;

  torch::Tensor colC;
  torch::optional<torch::Tensor> optional_valueC = torch::nullopt;

  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::BFloat16, scalar_type, "spspmm", [&] {
    // AT_DISPATCH_HAS_VALUE is a project macro: it binds the compile-time
    // flag HAS_VALUE depending on whether optional_valueA holds a tensor.
    AT_DISPATCH_HAS_VALUE(optional_valueA, [&] {
      scalar_t *valA_data = nullptr, *valB_data = nullptr;
      if (HAS_VALUE) {
        valA_data = optional_valueA.value().data_ptr<scalar_t>();
        valB_data = optional_valueB.value().data_ptr<scalar_t>();
      }

      // Row-wise product with a dense length-K accumulator (tmp_vals):
      // for each row of A, scatter-add contributions from the matching
      // rows of B, then compact the nonzeros into cols/vals.
      int64_t nnz = 0, cA, cB;
      std::vector<scalar_t> tmp_vals(K, 0);
      std::vector<int64_t> cols;
      std::vector<scalar_t> vals;
      for (auto rA = 0; rA < rowptrA.numel() - 1; rA++) {
        for (auto eA = rowptrA_data[rA]; eA < rowptrA_data[rA + 1]; eA++) {
          cA = colA_data[eA];
          for (auto eB = rowptrB_data[cA]; eB < rowptrB_data[cA + 1]; eB++) {
            cB = colB_data[eB];
            if (HAS_VALUE)
              tmp_vals[cB] += valA_data[eA] * valB_data[eB];
            else
              tmp_vals[cB] += 1; // value-less inputs: count contributions
          }
        }
        // Compact the accumulator: entries that cancelled to exactly 0 are
        // dropped from the output pattern. The accumulator is reset as we go.
        for (auto k = 0; k < K; k++) {
          if (tmp_vals[k] != 0) {
            cols.push_back(k);
            if (HAS_VALUE)
              vals.push_back(tmp_vals[k]);
            nnz++;
          }
          tmp_vals[k] = (scalar_t)0;
        }
        // nnz is cumulative across rows, so this directly yields rowptrC.
        rowptrC_data[rA + 1] = nnz;
      }

      // Copy the std::vector-backed buffers into owned tensors. The .clone()
      // is required because from_blob does not take ownership.
      // NOTE(review): when nnz == 0, cols.data() may be null — confirm
      // from_blob tolerates a null pointer with a zero-sized shape.
      colC = torch::from_blob(cols.data(), {nnz}, colA.options()).clone();
      if (HAS_VALUE) {
        optional_valueC = torch::from_blob(vals.data(), {nnz},
                                           optional_valueA.value().options());
        optional_valueC = optional_valueC.value().clone();
      }
    });
  });

  return std::make_tuple(rowptrC, colC, optional_valueC);
}
#pragma once

#include "../extensions.h"

// CPU sparse-sparse matrix multiplication (CSR x CSR -> CSR).
// Inputs are the CSR components of A and B with optional value tensors;
// K is the number of columns of B (and of the result). `reduce` names the
// reduction; see spspmm_cpu.cpp for the exact contract.
// Returns (rowptrC, colC, optional_valueC).
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
spspmm_cpu(torch::Tensor rowptrA, torch::Tensor colA,
           torch::optional<torch::Tensor> optional_valueA,
           torch::Tensor rowptrB, torch::Tensor colB,
           torch::optional<torch::Tensor> optional_valueB, int64_t K,
           std::string reduce);
#include "spspmm_cuda.h"
#include <ATen/cuda/CUDAContext.h>
#include <cusparse.h>
#include "utils.cuh"
// Dtype dispatch helper for the legacy cuSPARSE csrgemm2 API: binds
// `scalar_t` and references to the matching cuSPARSE entry points
// (cusparseS* for float, cusparseD* for double) before invoking the given
// lambda. Any other scalar type raises via AT_ERROR.
#define AT_DISPATCH_CUSPARSE_TYPES(TYPE, ...)                                 \
  [&] {                                                                       \
    switch (TYPE) {                                                           \
    case torch::ScalarType::Float: {                                          \
      using scalar_t = float;                                                 \
      const auto &cusparsecsrgemm2_bufferSizeExt =                            \
          cusparseScsrgemm2_bufferSizeExt;                                    \
      const auto &cusparsecsrgemm2 = cusparseScsrgemm2;                       \
      return __VA_ARGS__();                                                   \
    }                                                                         \
    case torch::ScalarType::Double: {                                         \
      using scalar_t = double;                                                \
      const auto &cusparsecsrgemm2_bufferSizeExt =                            \
          cusparseDcsrgemm2_bufferSizeExt;                                    \
      const auto &cusparsecsrgemm2 = cusparseDcsrgemm2;                       \
      return __VA_ARGS__();                                                   \
    }                                                                         \
    default:                                                                  \
      AT_ERROR("Not implemented for '", toString(TYPE), "'");                 \
    }                                                                         \
  }()
// Sparse-sparse matrix multiplication C = A @ B on CUDA via cuSPARSE's
// legacy csrgemm2 API, with inputs and output in CSR form.
//
// Args mirror spspmm_cpu: CSR components (rowptr, col, optional value) of A
// and B, K = number of columns of B/C. `reduce` is accepted for interface
// parity but unused — csrgemm2 always sums duplicate contributions. Only
// float/double values are supported (see AT_DISPATCH_CUSPARSE_TYPES).
// Returns (rowptrC, colC, optional_valueC) with int64 indices.
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
spspmm_cuda(torch::Tensor rowptrA, torch::Tensor colA,
            torch::optional<torch::Tensor> optional_valueA,
            torch::Tensor rowptrB, torch::Tensor colB,
            torch::optional<torch::Tensor> optional_valueB, int64_t K,
            std::string reduce) {
  // Device and shape validation, as in the CPU variant.
  CHECK_CUDA(rowptrA);
  CHECK_CUDA(colA);
  if (optional_valueA.has_value())
    CHECK_CUDA(optional_valueA.value());
  CHECK_CUDA(rowptrB);
  CHECK_CUDA(colB);
  if (optional_valueB.has_value())
    CHECK_CUDA(optional_valueB.value());
  cudaSetDevice(rowptrA.get_device());

  CHECK_INPUT(rowptrA.dim() == 1);
  CHECK_INPUT(colA.dim() == 1);
  if (optional_valueA.has_value()) {
    CHECK_INPUT(optional_valueA.value().dim() == 1);
    CHECK_INPUT(optional_valueA.value().size(0) == colA.size(0));
  }
  CHECK_INPUT(rowptrB.dim() == 1);
  CHECK_INPUT(colB.dim() == 1);
  if (optional_valueB.has_value()) {
    CHECK_INPUT(optional_valueB.value().dim() == 1);
    CHECK_INPUT(optional_valueB.value().size(0) == colB.size(0));
  }

  // If exactly one operand carries values, materialize all-ones values for
  // the other so the numeric multiply is well-defined.
  if (!optional_valueA.has_value() && optional_valueB.has_value())
    optional_valueA =
        torch::ones({colA.numel()}, optional_valueB.value().options());
  if (!optional_valueB.has_value() && optional_valueA.has_value())
    optional_valueB =
        torch::ones({colB.numel()}, optional_valueA.value().options());

  auto scalar_type = torch::ScalarType::Float;
  if (optional_valueA.has_value())
    scalar_type = optional_valueA.value().scalar_type();

  auto handle = at::cuda::getCurrentCUDASparseHandle();
  cusparseMatDescr_t descr;
  cusparseCreateMatDescr(&descr);

  // The legacy csrgemm2 API requires 32-bit indices.
  rowptrA = rowptrA.toType(torch::kInt);
  colA = colA.toType(torch::kInt);
  rowptrB = rowptrB.toType(torch::kInt);
  colB = colB.toType(torch::kInt);

  int64_t M = rowptrA.numel() - 1, N = rowptrB.numel() - 1;

  auto rowptrA_data = rowptrA.data_ptr<int>();
  auto colA_data = colA.data_ptr<int>();
  auto rowptrB_data = rowptrB.data_ptr<int>();
  auto colB_data = colB.data_ptr<int>();

  torch::Tensor rowptrC, colC;
  torch::optional<torch::Tensor> optional_valueC = torch::nullopt;

  // NOTE(review): assumes host pointer mode, so nnzC is readable after
  // cusparseXcsrgemm2Nnz returns — confirm against the handle's settings.
  int nnzC;
  int *nnzTotalDevHostPtr = &nnzC;

  // Step 1: Create an opaque structure.
  csrgemm2Info_t info = NULL;
  cusparseCreateCsrgemm2Info(&info);

  // Step 2: Allocate buffer for `csrgemm2Nnz` and `csrgemm2`.
  size_t bufferSize;
  AT_DISPATCH_CUSPARSE_TYPES(scalar_type, [&] {
    scalar_t alpha = (scalar_t)1.0;
    cusparsecsrgemm2_bufferSizeExt(handle, M, N, K, &alpha, descr, colA.numel(),
                                   rowptrA_data, colA_data, descr, colB.numel(),
                                   rowptrB_data, colB_data, NULL, descr, 0,
                                   NULL, NULL, info, &bufferSize);
    // NOTE(review): cudaMalloc's return status is not checked; an allocation
    // failure would surface later as a cuSPARSE error.
    void *buffer = NULL;
    cudaMalloc(&buffer, bufferSize);

    // Step 3: Compute CSR row pointer.
    rowptrC = torch::empty({M + 1}, rowptrA.options());
    auto rowptrC_data = rowptrC.data_ptr<int>();
    cusparseXcsrgemm2Nnz(handle, M, N, K, descr, colA.numel(), rowptrA_data,
                         colA_data, descr, colB.numel(), rowptrB_data,
                         colB_data, descr, 0, NULL, NULL, descr, rowptrC_data,
                         nnzTotalDevHostPtr, info, buffer);

    // Step 4: Compute CSR entries (beta term is NULL, i.e. C = alpha * A * B).
    colC = torch::empty({nnzC}, rowptrC.options());
    auto colC_data = colC.data_ptr<int>();

    if (optional_valueA.has_value())
      optional_valueC = torch::empty({nnzC}, optional_valueA.value().options());

    scalar_t *valA_data = NULL, *valB_data = NULL, *valC_data = NULL;
    if (optional_valueA.has_value()) {
      valA_data = optional_valueA.value().data_ptr<scalar_t>();
      valB_data = optional_valueB.value().data_ptr<scalar_t>();
      valC_data = optional_valueC.value().data_ptr<scalar_t>();
    }

    cusparsecsrgemm2(handle, M, N, K, &alpha, descr, colA.numel(), valA_data,
                     rowptrA_data, colA_data, descr, colB.numel(), valB_data,
                     rowptrB_data, colB_data, NULL, descr, 0, NULL, NULL, NULL,
                     descr, valC_data, rowptrC_data, colC_data, info, buffer);
    cudaFree(buffer);
  });

  // Step 5: Destroy the opaque structure and the matrix descriptor.
  // FIX: the descriptor created above was previously never destroyed,
  // leaking a cusparseMatDescr_t on every call.
  cusparseDestroyCsrgemm2Info(info);
  cusparseDestroyMatDescr(descr);

  // Convert indices back to the int64 convention used by the rest of the
  // library.
  rowptrC = rowptrC.toType(torch::kLong);
  colC = colC.toType(torch::kLong);

  return std::make_tuple(rowptrC, colC, optional_valueC);
}
#pragma once

#include "../extensions.h"

// CUDA sparse-sparse matrix multiplication (CSR x CSR -> CSR), backed by
// cuSPARSE. Same signature and contract as spspmm_cpu; only float/double
// value tensors are supported by the CUDA implementation.
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
spspmm_cuda(torch::Tensor rowptrA, torch::Tensor colA,
            torch::optional<torch::Tensor> optional_valueA,
            torch::Tensor rowptrB, torch::Tensor colB,
            torch::optional<torch::Tensor> optional_valueB, int64_t K,
            std::string reduce);
...@@ -74,10 +74,3 @@ spmm_min(torch::Tensor rowptr, torch::Tensor col, ...@@ -74,10 +74,3 @@ spmm_min(torch::Tensor rowptr, torch::Tensor col,
SPARSE_API std::tuple<torch::Tensor, torch::Tensor> SPARSE_API std::tuple<torch::Tensor, torch::Tensor>
spmm_max(torch::Tensor rowptr, torch::Tensor col, spmm_max(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> opt_value, torch::Tensor mat); torch::optional<torch::Tensor> opt_value, torch::Tensor mat);
SPARSE_API
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
spspmm_sum(torch::Tensor rowptrA, torch::Tensor colA,
torch::optional<torch::Tensor> optional_valueA,
torch::Tensor rowptrB, torch::Tensor colB,
torch::optional<torch::Tensor> optional_valueB, int64_t K);
#ifdef WITH_PYTHON
#include <Python.h>
#endif
#include <torch/script.h>
#include "cpu/spspmm_cpu.h"
#ifdef WITH_CUDA
#include "cuda/spspmm_cuda.h"
#endif
// On Windows, Python extension DLLs must export a PyInit_<name> symbol even
// though this library is loaded via torch.ops rather than `import`. These
// no-op stubs satisfy the loader; only one is compiled depending on whether
// the build targets the CUDA or CPU artifact.
#ifdef _WIN32
#ifdef WITH_PYTHON
#ifdef WITH_CUDA
PyMODINIT_FUNC PyInit__spspmm_cuda(void) { return NULL; }
#else
PyMODINIT_FUNC PyInit__spspmm_cpu(void) { return NULL; }
#endif
#endif
#endif
// Computes C = A @ B for two CSR sparse matrices with "sum" reduction,
// dispatching to the CPU or CUDA kernel based on the device of `rowptrA`.
// Raises if CUDA tensors are passed to a build without CUDA support.
SPARSE_API std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
spspmm_sum(torch::Tensor rowptrA, torch::Tensor colA,
           torch::optional<torch::Tensor> optional_valueA,
           torch::Tensor rowptrB, torch::Tensor colB,
           torch::optional<torch::Tensor> optional_valueB, int64_t K) {
  // Host tensors take the CPU reference path.
  if (!rowptrA.device().is_cuda())
    return spspmm_cpu(rowptrA, colA, optional_valueA, rowptrB, colB,
                      optional_valueB, K, "sum");
#ifdef WITH_CUDA
  return spspmm_cuda(rowptrA, colA, optional_valueA, rowptrB, colB,
                     optional_valueB, K, "sum");
#else
  AT_ERROR("Not compiled with CUDA support");
#endif
}
// Register the operator with the PyTorch dispatcher so it is callable as
// torch.ops.torch_sparse.spspmm_sum (and scriptable from TorchScript).
static auto registry =
    torch::RegisterOperators().op("torch_sparse::spspmm_sum", &spspmm_sum);
...@@ -15,7 +15,7 @@ from torch.utils.cpp_extension import ( ...@@ -15,7 +15,7 @@ from torch.utils.cpp_extension import (
CUDAExtension, CUDAExtension,
) )
__version__ = '0.6.15' __version__ = '0.6.16'
URL = 'https://github.com/rusty1s/pytorch_sparse' URL = 'https://github.com/rusty1s/pytorch_sparse'
WITH_CUDA = False WITH_CUDA = False
...@@ -64,7 +64,7 @@ def get_extensions(): ...@@ -64,7 +64,7 @@ def get_extensions():
define_macros += [('MTMETIS_64BIT_PARTITIONS', None)] define_macros += [('MTMETIS_64BIT_PARTITIONS', None)]
libraries += ['mtmetis', 'wildriver'] libraries += ['mtmetis', 'wildriver']
extra_compile_args = {'cxx': ['-O2']} extra_compile_args = {'cxx': ['-O3']}
if not os.name == 'nt': # Not on Windows: if not os.name == 'nt': # Not on Windows:
extra_compile_args['cxx'] += ['-Wno-sign-compare'] extra_compile_args['cxx'] += ['-Wno-sign-compare']
extra_link_args = [] if WITH_SYMBOLS else ['-s'] extra_link_args = [] if WITH_SYMBOLS else ['-s']
...@@ -89,8 +89,7 @@ def get_extensions(): ...@@ -89,8 +89,7 @@ def get_extensions():
define_macros += [('WITH_CUDA', None)] define_macros += [('WITH_CUDA', None)]
nvcc_flags = os.getenv('NVCC_FLAGS', '') nvcc_flags = os.getenv('NVCC_FLAGS', '')
nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ') nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ')
nvcc_flags += ['-O2'] nvcc_flags += ['-O3']
extra_compile_args['nvcc'] = nvcc_flags
if torch.version.hip: if torch.version.hip:
# USE_ROCM was added to later versions of PyTorch # USE_ROCM was added to later versions of PyTorch
# Define here to support older PyTorch versions as well: # Define here to support older PyTorch versions as well:
...@@ -98,17 +97,7 @@ def get_extensions(): ...@@ -98,17 +97,7 @@ def get_extensions():
undef_macros += ['__HIP_NO_HALF_CONVERSIONS__'] undef_macros += ['__HIP_NO_HALF_CONVERSIONS__']
else: else:
nvcc_flags += ['--expt-relaxed-constexpr'] nvcc_flags += ['--expt-relaxed-constexpr']
extra_compile_args['nvcc'] = nvcc_flags
if torch.version.hip:
if sys.platform == 'win32':
extra_link_args += ['hipsparse.lib']
else:
extra_link_args += ['-lhipsparse', '-l', 'hipsparse']
else:
if sys.platform == 'win32':
extra_link_args += ['cusparse.lib']
else:
extra_link_args += ['-lcusparse', '-l', 'cusparse']
name = main.split(os.sep)[-1][:-4] name = main.split(os.sep)[-1][:-4]
sources = [main] sources = [main]
......
...@@ -4,7 +4,7 @@ import pytest ...@@ -4,7 +4,7 @@ import pytest
import torch import torch
import torch_scatter import torch_scatter
from torch_sparse.matmul import matmul from torch_sparse.matmul import matmul, spspmm
from torch_sparse.tensor import SparseTensor from torch_sparse.tensor import SparseTensor
from torch_sparse.testing import devices, grad_dtypes, reductions from torch_sparse.testing import devices, grad_dtypes, reductions
...@@ -53,7 +53,7 @@ def test_spmm(dtype, device, reduce): ...@@ -53,7 +53,7 @@ def test_spmm(dtype, device, reduce):
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices)) @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_spspmm(dtype, device): def test_spspmm(dtype, device):
if device == torch.device('cuda:0') and dtype == torch.bfloat16: if dtype in {torch.half, torch.bfloat16}:
return # Not yet implemented. return # Not yet implemented.
src = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=dtype, src = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=dtype,
...@@ -75,3 +75,5 @@ def test_spspmm(dtype, device): ...@@ -75,3 +75,5 @@ def test_spspmm(dtype, device):
rowptr, col, value = out.csr() rowptr, col, value = out.csr()
assert rowptr.tolist() == [0, 1, 2, 3] assert rowptr.tolist() == [0, 1, 2, 3]
assert col.tolist() == [0, 1, 2] assert col.tolist() == [0, 1, 2]
torch.jit.script(spspmm)
...@@ -9,7 +9,7 @@ from torch_sparse.testing import devices, grad_dtypes, tensor ...@@ -9,7 +9,7 @@ from torch_sparse.testing import devices, grad_dtypes, tensor
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices)) @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_spspmm(dtype, device): def test_spspmm(dtype, device):
if device == torch.device('cuda:0') and dtype == torch.bfloat16: if dtype in {torch.half, torch.bfloat16}:
return # Not yet implemented. return # Not yet implemented.
indexA = torch.tensor([[0, 0, 1, 2, 2], [1, 2, 0, 0, 1]], device=device) indexA = torch.tensor([[0, 0, 1, 2, 2], [1, 2, 0, 0, 1]], device=device)
...@@ -24,7 +24,7 @@ def test_spspmm(dtype, device): ...@@ -24,7 +24,7 @@ def test_spspmm(dtype, device):
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices)) @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_sparse_tensor_spspmm(dtype, device): def test_sparse_tensor_spspmm(dtype, device):
if device == torch.device('cuda:0') and dtype == torch.bfloat16: if dtype in {torch.half, torch.bfloat16}:
return # Not yet implemented. return # Not yet implemented.
x = SparseTensor( x = SparseTensor(
......
...@@ -3,12 +3,11 @@ import os.path as osp ...@@ -3,12 +3,11 @@ import os.path as osp
import torch import torch
__version__ = '0.6.15' __version__ = '0.6.16'
for library in [ for library in [
'_version', '_convert', '_diag', '_spmm', '_spspmm', '_metis', '_rw', '_version', '_convert', '_diag', '_spmm', '_metis', '_rw', '_saint',
'_saint', '_sample', '_ego_sample', '_hgt_sample', '_neighbor_sample', '_sample', '_ego_sample', '_hgt_sample', '_neighbor_sample', '_relabel'
'_relabel'
]: ]:
cuda_spec = importlib.machinery.PathFinder().find_spec( cuda_spec = importlib.machinery.PathFinder().find_spec(
f'{library}_cuda', [osp.dirname(__file__)]) f'{library}_cuda', [osp.dirname(__file__)])
......
from typing import Tuple from typing import Optional, Tuple
import torch import torch
from torch import Tensor
from torch_sparse.tensor import SparseTensor from torch_sparse.tensor import SparseTensor
...@@ -90,21 +91,23 @@ def spmm(src: SparseTensor, other: torch.Tensor, ...@@ -90,21 +91,23 @@ def spmm(src: SparseTensor, other: torch.Tensor,
def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor: def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor:
assert src.sparse_size(1) == other.sparse_size(0) A = src.to_torch_sparse_coo_tensor()
rowptrA, colA, valueA = src.csr() B = other.to_torch_sparse_coo_tensor()
rowptrB, colB, valueB = other.csr() C = torch.sparse.mm(A, B)
value = valueA if valueA is not None else valueB edge_index = C._indices()
if valueA is not None and valueA.dtype == torch.half: row, col = edge_index[0], edge_index[1]
valueA = valueA.to(torch.float) value: Optional[Tensor] = None
if valueB is not None and valueB.dtype == torch.half: if src.has_value() and other.has_value():
valueB = valueB.to(torch.float) value = C._values()
M, K = src.sparse_size(0), other.sparse_size(1)
rowptrC, colC, valueC = torch.ops.torch_sparse.spspmm_sum( return SparseTensor(
rowptrA, colA, valueA, rowptrB, colB, valueB, K) row=row,
if valueC is not None and value is not None: col=col,
valueC = valueC.to(value.dtype) value=value,
return SparseTensor(row=None, rowptr=rowptrC, col=colC, value=valueC, sparse_sizes=(C.size(0), C.size(1)),
sparse_sizes=(M, K), is_sorted=True) is_sorted=True,
trust_data=True,
)
def spspmm_add(src: SparseTensor, other: SparseTensor) -> SparseTensor: def spspmm_add(src: SparseTensor, other: SparseTensor) -> SparseTensor:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment