Commit 1c4fdfe2 authored by rusty1s's avatar rusty1s
Browse files

pytorch 1.3 support

parent 573ad113
...@@ -17,7 +17,7 @@ before_install: ...@@ -17,7 +17,7 @@ before_install:
- export CXX="g++-4.9" - export CXX="g++-4.9"
install: install:
- pip install numpy - pip install numpy
- pip install -q torch -f https://download.pytorch.org/whl/nightly/cpu/torch.html - pip install --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
- pip install pycodestyle - pip install pycodestyle
- pip install flake8 - pip install flake8
- pip install codecov - pip install codecov
......
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
#include <torch/extension.h> #include <torch/extension.h>
#include "compat.h"
at::Tensor degree(at::Tensor row, int64_t num_nodes) { at::Tensor degree(at::Tensor row, int64_t num_nodes) {
auto zero = at::zeros(num_nodes, row.options()); auto zero = at::zeros(num_nodes, row.options());
auto one = at::ones(row.size(0), row.options()); auto one = at::ones(row.size(0), row.options());
...@@ -18,23 +20,23 @@ at::Tensor spspmm_bw(at::Tensor index, at::Tensor indexA, at::Tensor valueA, ...@@ -18,23 +20,23 @@ at::Tensor spspmm_bw(at::Tensor index, at::Tensor indexA, at::Tensor valueA,
at::Tensor indexB, at::Tensor valueB, size_t rowA_max, at::Tensor indexB, at::Tensor valueB, size_t rowA_max,
size_t rowB_max) { size_t rowB_max) {
int64_t *index_data = index.data<int64_t>(); int64_t *index_data = index.DATA_PTR<int64_t>();
auto value = at::zeros(index.size(1), valueA.options()); auto value = at::zeros(index.size(1), valueA.options());
at::Tensor rowA, colA; at::Tensor rowA, colA;
std::tie(rowA, colA) = to_csr(indexA[0], indexA[1], rowA_max); std::tie(rowA, colA) = to_csr(indexA[0], indexA[1], rowA_max);
int64_t *rowA_data = rowA.data<int64_t>(); int64_t *rowA_data = rowA.DATA_PTR<int64_t>();
int64_t *colA_data = colA.data<int64_t>(); int64_t *colA_data = colA.DATA_PTR<int64_t>();
at::Tensor rowB, colB; at::Tensor rowB, colB;
std::tie(rowB, colB) = to_csr(indexB[0], indexB[1], rowB_max); std::tie(rowB, colB) = to_csr(indexB[0], indexB[1], rowB_max);
int64_t *rowB_data = rowB.data<int64_t>(); int64_t *rowB_data = rowB.DATA_PTR<int64_t>();
int64_t *colB_data = colB.data<int64_t>(); int64_t *colB_data = colB.DATA_PTR<int64_t>();
AT_DISPATCH_FLOATING_TYPES(valueA.scalar_type(), "spspmm_bw", [&] { AT_DISPATCH_FLOATING_TYPES(valueA.scalar_type(), "spspmm_bw", [&] {
scalar_t *value_data = value.data<scalar_t>(); scalar_t *value_data = value.DATA_PTR<scalar_t>();
scalar_t *valueA_data = valueA.data<scalar_t>(); scalar_t *valueA_data = valueA.DATA_PTR<scalar_t>();
scalar_t *valueB_data = valueB.data<scalar_t>(); scalar_t *valueB_data = valueB.DATA_PTR<scalar_t>();
for (int64_t e = 0; e < value.size(0); e++) { for (int64_t e = 0; e < value.size(0); e++) {
int64_t i = index_data[e], j = index_data[value.size(0) + e]; int64_t i = index_data[e], j = index_data[value.size(0) + e];
......
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <cusparse.h> #include <cusparse.h>
#include "compat.cuh"
#define THREADS 1024 #define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS #define BLOCKS(N) (N + THREADS - 1) / THREADS
...@@ -51,18 +52,18 @@ spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB, ...@@ -51,18 +52,18 @@ spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB,
// Convert A to CSR format. // Convert A to CSR format.
auto row_ptrA = at::empty(m + 1, indexA.options()); auto row_ptrA = at::empty(m + 1, indexA.options());
cusparseXcoo2csr(cusparse_handle, indexA[0].data<int>(), nnzA, k, cusparseXcoo2csr(cusparse_handle, indexA[0].DATA_PTR<int>(), nnzA, k,
row_ptrA.data<int>(), CUSPARSE_INDEX_BASE_ZERO); row_ptrA.DATA_PTR<int>(), CUSPARSE_INDEX_BASE_ZERO);
auto colA = indexA[1]; auto colA = indexA[1];
cudaMemcpy(row_ptrA.data<int>() + m, &nnzA, sizeof(int), cudaMemcpy(row_ptrA.DATA_PTR<int>() + m, &nnzA, sizeof(int),
cudaMemcpyHostToDevice); cudaMemcpyHostToDevice);
// Convert B to CSR format. // Convert B to CSR format.
auto row_ptrB = at::empty(k + 1, indexB.options()); auto row_ptrB = at::empty(k + 1, indexB.options());
cusparseXcoo2csr(cusparse_handle, indexB[0].data<int>(), nnzB, k, cusparseXcoo2csr(cusparse_handle, indexB[0].DATA_PTR<int>(), nnzB, k,
row_ptrB.data<int>(), CUSPARSE_INDEX_BASE_ZERO); row_ptrB.DATA_PTR<int>(), CUSPARSE_INDEX_BASE_ZERO);
auto colB = indexB[1]; auto colB = indexB[1];
cudaMemcpy(row_ptrB.data<int>() + k, &nnzB, sizeof(int), cudaMemcpy(row_ptrB.DATA_PTR<int>() + k, &nnzB, sizeof(int),
cudaMemcpyHostToDevice); cudaMemcpyHostToDevice);
cusparseMatDescr_t descr = 0; cusparseMatDescr_t descr = 0;
...@@ -74,22 +75,23 @@ spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB, ...@@ -74,22 +75,23 @@ spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB,
auto row_ptrC = at::empty(m + 1, indexB.options()); auto row_ptrC = at::empty(m + 1, indexB.options());
cusparseXcsrgemmNnz(cusparse_handle, CUSPARSE_OPERATION_NON_TRANSPOSE, cusparseXcsrgemmNnz(cusparse_handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
CUSPARSE_OPERATION_NON_TRANSPOSE, m, n, k, descr, nnzA, CUSPARSE_OPERATION_NON_TRANSPOSE, m, n, k, descr, nnzA,
row_ptrA.data<int>(), colA.data<int>(), descr, nnzB, row_ptrA.DATA_PTR<int>(), colA.DATA_PTR<int>(), descr,
row_ptrB.data<int>(), colB.data<int>(), descr, nnzB, row_ptrB.DATA_PTR<int>(), colB.DATA_PTR<int>(),
row_ptrC.data<int>(), &nnzC); descr, row_ptrC.DATA_PTR<int>(), &nnzC);
auto colC = at::empty(nnzC, indexA.options()); auto colC = at::empty(nnzC, indexA.options());
auto valueC = at::empty(nnzC, valueA.options()); auto valueC = at::empty(nnzC, valueA.options());
CSRGEMM(valueC.scalar_type(), cusparse_handle, CSRGEMM(valueC.scalar_type(), cusparse_handle,
CUSPARSE_OPERATION_NON_TRANSPOSE, CUSPARSE_OPERATION_NON_TRANSPOSE, m, CUSPARSE_OPERATION_NON_TRANSPOSE, CUSPARSE_OPERATION_NON_TRANSPOSE, m,
n, k, descr, nnzA, valueA.data<scalar_t>(), row_ptrA.data<int>(), n, k, descr, nnzA, valueA.DATA_PTR<scalar_t>(),
colA.data<int>(), descr, nnzB, valueB.data<scalar_t>(), row_ptrA.DATA_PTR<int>(), colA.DATA_PTR<int>(), descr, nnzB,
row_ptrB.data<int>(), colB.data<int>(), descr, valueB.DATA_PTR<scalar_t>(), row_ptrB.DATA_PTR<int>(),
valueC.data<scalar_t>(), row_ptrC.data<int>(), colC.data<int>()); colB.DATA_PTR<int>(), descr, valueC.DATA_PTR<scalar_t>(),
row_ptrC.DATA_PTR<int>(), colC.DATA_PTR<int>());
auto rowC = at::empty(nnzC, indexA.options()); auto rowC = at::empty(nnzC, indexA.options());
cusparseXcsr2coo(cusparse_handle, row_ptrC.data<int>(), nnzC, m, cusparseXcsr2coo(cusparse_handle, row_ptrC.DATA_PTR<int>(), nnzC, m,
rowC.data<int>(), CUSPARSE_INDEX_BASE_ZERO); rowC.DATA_PTR<int>(), CUSPARSE_INDEX_BASE_ZERO);
auto indexC = at::stack({rowC, colC}, 0).toType(at::kLong); auto indexC = at::stack({rowC, colC}, 0).toType(at::kLong);
...@@ -154,9 +156,10 @@ at::Tensor spspmm_bw_cuda(at::Tensor index, at::Tensor indexA, ...@@ -154,9 +156,10 @@ at::Tensor spspmm_bw_cuda(at::Tensor index, at::Tensor indexA,
AT_DISPATCH_FLOATING_TYPES(valueA.scalar_type(), "spspmm_bw", [&] { AT_DISPATCH_FLOATING_TYPES(valueA.scalar_type(), "spspmm_bw", [&] {
spspmm_bw_kernel<scalar_t><<<BLOCKS(value.numel()), THREADS>>>( spspmm_bw_kernel<scalar_t><<<BLOCKS(value.numel()), THREADS>>>(
index.data<int64_t>(), value.data<scalar_t>(), rowA.data<int64_t>(), index.DATA_PTR<int64_t>(), value.DATA_PTR<scalar_t>(),
colA.data<int64_t>(), valueA.data<scalar_t>(), rowB.data<int64_t>(), rowA.DATA_PTR<int64_t>(), colA.DATA_PTR<int64_t>(),
colB.data<int64_t>(), valueB.data<scalar_t>(), value.numel()); valueA.DATA_PTR<scalar_t>(), rowB.DATA_PTR<int64_t>(),
colB.DATA_PTR<int64_t>(), valueB.DATA_PTR<scalar_t>(), value.numel());
}); });
return value; return value;
......
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include "compat.cuh"
#define THREADS 1024 #define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS #define BLOCKS(N) (N + THREADS - 1) / THREADS
...@@ -23,7 +25,7 @@ std::tuple<at::Tensor, at::Tensor> unique_cuda(at::Tensor src) { ...@@ -23,7 +25,7 @@ std::tuple<at::Tensor, at::Tensor> unique_cuda(at::Tensor src) {
auto mask = at::zeros(src.numel(), src.options().dtype(at::kByte)); auto mask = at::zeros(src.numel(), src.options().dtype(at::kByte));
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "grid_cuda_kernel", [&] { AT_DISPATCH_ALL_TYPES(src.scalar_type(), "grid_cuda_kernel", [&] {
unique_cuda_kernel<scalar_t><<<BLOCKS(src.numel()), THREADS>>>( unique_cuda_kernel<scalar_t><<<BLOCKS(src.numel()), THREADS>>>(
src.data<scalar_t>(), mask.data<uint8_t>(), src.numel()); src.DATA_PTR<scalar_t>(), mask.DATA_PTR<uint8_t>(), src.numel());
}); });
src = src.masked_select(mask); src = src.masked_select(mask);
......
...@@ -3,7 +3,17 @@ from setuptools import setup, find_packages ...@@ -3,7 +3,17 @@ from setuptools import setup, find_packages
import torch import torch
from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME
ext_modules = [CppExtension('torch_sparse.spspmm_cpu', ['cpu/spspmm.cpp'])] TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
extra_compile_args = []
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
extra_compile_args += ['-DVERSION_GE_1_3']
ext_modules = [
CppExtension('torch_sparse.spspmm_cpu', ['cpu/spspmm.cpp'],
extra_compile_args=extra_compile_args)
]
cmdclass = {'build_ext': torch.utils.cpp_extension.BuildExtension} cmdclass = {'build_ext': torch.utils.cpp_extension.BuildExtension}
if CUDA_HOME is not None: if CUDA_HOME is not None:
...@@ -13,15 +23,16 @@ if CUDA_HOME is not None: ...@@ -13,15 +23,16 @@ if CUDA_HOME is not None:
extra_link_args = ['-lcusparse', '-l', 'cusparse'] extra_link_args = ['-lcusparse', '-l', 'cusparse']
ext_modules += [ ext_modules += [
CUDAExtension( CUDAExtension('torch_sparse.spspmm_cuda',
'torch_sparse.spspmm_cuda',
['cuda/spspmm.cpp', 'cuda/spspmm_kernel.cu'], ['cuda/spspmm.cpp', 'cuda/spspmm_kernel.cu'],
extra_link_args=extra_link_args), extra_link_args=extra_link_args,
extra_compile_args=extra_compile_args),
CUDAExtension('torch_sparse.unique_cuda', CUDAExtension('torch_sparse.unique_cuda',
['cuda/unique.cpp', 'cuda/unique_kernel.cu']), ['cuda/unique.cpp', 'cuda/unique_kernel.cu'],
extra_compile_args=extra_compile_args),
] ]
__version__ = '0.4.0' __version__ = '0.4.1'
url = 'https://github.com/rusty1s/pytorch_sparse' url = 'https://github.com/rusty1s/pytorch_sparse'
install_requires = ['scipy'] install_requires = ['scipy']
......
...@@ -5,7 +5,7 @@ from .eye import eye ...@@ -5,7 +5,7 @@ from .eye import eye
from .spmm import spmm from .spmm import spmm
from .spspmm import spspmm from .spspmm import spspmm
__version__ = '0.4.0' __version__ = '0.4.1'
__all__ = [ __all__ = [
'__version__', '__version__',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment