Commit de528831 authored by rusty1s
Browse files

pytorch 1.1.0 update

parent 9732a518
...@@ -17,9 +17,9 @@ before_install: ...@@ -17,9 +17,9 @@ before_install:
- export CC="gcc-4.9" - export CC="gcc-4.9"
- export CXX="g++-4.9" - export CXX="g++-4.9"
install: install:
- if [[ $TRAVIS_PYTHON_VERSION == 2.7 ]]; then pip install https://download.pytorch.org/whl/cpu/torch-1.0.0-cp27-cp27mu-linux_x86_64.whl; fi - if [[ $TRAVIS_PYTHON_VERSION == 2.7 ]]; then pip install https://download.pytorch.org/whl/cpu/torch-1.1.0-cp27-cp27mu-linux_x86_64.whl; fi
- if [[ $TRAVIS_PYTHON_VERSION == 3.5 ]]; then pip install https://download.pytorch.org/whl/cpu/torch-1.0.0-cp35-cp35m-linux_x86_64.whl; fi - if [[ $TRAVIS_PYTHON_VERSION == 3.5 ]]; then pip install https://download.pytorch.org/whl/cpu/torch-1.1.0-cp35-cp35m-linux_x86_64.whl; fi
- if [[ $TRAVIS_PYTHON_VERSION == 3.6 ]]; then pip install https://download.pytorch.org/whl/cpu/torch-1.0.0-cp36-cp36m-linux_x86_64.whl; fi - if [[ $TRAVIS_PYTHON_VERSION == 3.6 ]]; then pip install https://download.pytorch.org/whl/cpu/torch-1.1.0-cp36-cp36m-linux_x86_64.whl; fi
- pip install pycodestyle - pip install pycodestyle
- pip install flake8 - pip install flake8
- pip install codecov - pip install codecov
......
...@@ -28,7 +28,7 @@ Note that only `value` comes with autograd support, as `index` is discrete and t ...@@ -28,7 +28,7 @@ Note that only `value` comes with autograd support, as `index` is discrete and t
## Installation ## Installation
Ensure that at least PyTorch 1.0.0 is installed and verify that `cuda/bin` and `cuda/include` are in your `$PATH` and `$CPATH` respectively, *e.g.*: Ensure that at least PyTorch 1.1.0 is installed and verify that `cuda/bin` and `cuda/include` are in your `$PATH` and `$CPATH` respectively, *e.g.*:
``` ```
$ python -c "import torch; print(torch.__version__)" $ python -c "import torch; print(torch.__version__)"
......
...@@ -31,7 +31,7 @@ at::Tensor spspmm_bw(at::Tensor index, at::Tensor indexA, at::Tensor valueA, ...@@ -31,7 +31,7 @@ at::Tensor spspmm_bw(at::Tensor index, at::Tensor indexA, at::Tensor valueA,
int64_t *rowB_data = rowB.data<int64_t>(); int64_t *rowB_data = rowB.data<int64_t>();
int64_t *colB_data = colB.data<int64_t>(); int64_t *colB_data = colB.data<int64_t>();
AT_DISPATCH_FLOATING_TYPES(valueA.type(), "spspmm_bw", [&] { AT_DISPATCH_FLOATING_TYPES(valueA.scalar_type(), "spspmm_bw", [&] {
scalar_t *value_data = value.data<scalar_t>(); scalar_t *value_data = value.data<scalar_t>();
scalar_t *valueA_data = valueA.data<scalar_t>(); scalar_t *valueA_data = valueA.data<scalar_t>();
scalar_t *valueB_data = valueB.data<scalar_t>(); scalar_t *valueB_data = valueB.data<scalar_t>();
......
...@@ -7,8 +7,10 @@ ...@@ -7,8 +7,10 @@
#define CSRGEMM(TYPE, ...) \ #define CSRGEMM(TYPE, ...) \
[&] { \ [&] { \
const at::Type &the_type = TYPE; \ const auto &the_type = TYPE; \
switch (the_type.scalarType()) { \ (void)the_type; \
at::ScalarType _st = ::detail::scalar_type(TYPE); \
switch (_st) { \
case at::ScalarType::Float: { \ case at::ScalarType::Float: { \
using scalar_t = float; \ using scalar_t = float; \
return cusparseScsrgemm(__VA_ARGS__); \ return cusparseScsrgemm(__VA_ARGS__); \
...@@ -18,7 +20,7 @@ ...@@ -18,7 +20,7 @@
return cusparseDcsrgemm(__VA_ARGS__); \ return cusparseDcsrgemm(__VA_ARGS__); \
} \ } \
default: \ default: \
AT_ERROR("Not implemented for '%s'", the_type.toString()); \ AT_ERROR("Not implemented for '", toString(_st), "'"); \
} \ } \
}() }()
...@@ -48,7 +50,7 @@ spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB, ...@@ -48,7 +50,7 @@ spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB,
indexB = indexB.toType(at::kInt); indexB = indexB.toType(at::kInt);
// Convert A to CSR format. // Convert A to CSR format.
auto row_ptrA = at::empty(m + 1, indexA.type()); auto row_ptrA = at::empty(m + 1, indexA.options());
cusparseXcoo2csr(cusparse_handle, indexA[0].data<int>(), nnzA, k, cusparseXcoo2csr(cusparse_handle, indexA[0].data<int>(), nnzA, k,
row_ptrA.data<int>(), CUSPARSE_INDEX_BASE_ZERO); row_ptrA.data<int>(), CUSPARSE_INDEX_BASE_ZERO);
auto colA = indexA[1]; auto colA = indexA[1];
...@@ -56,7 +58,7 @@ spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB, ...@@ -56,7 +58,7 @@ spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB,
cudaMemcpyHostToDevice); cudaMemcpyHostToDevice);
// Convert B to CSR format. // Convert B to CSR format.
auto row_ptrB = at::empty(k + 1, indexB.type()); auto row_ptrB = at::empty(k + 1, indexB.options());
cusparseXcoo2csr(cusparse_handle, indexB[0].data<int>(), nnzB, k, cusparseXcoo2csr(cusparse_handle, indexB[0].data<int>(), nnzB, k,
row_ptrB.data<int>(), CUSPARSE_INDEX_BASE_ZERO); row_ptrB.data<int>(), CUSPARSE_INDEX_BASE_ZERO);
auto colB = indexB[1]; auto colB = indexB[1];
...@@ -69,23 +71,23 @@ spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB, ...@@ -69,23 +71,23 @@ spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB,
cusparseSetMatIndexBase(descr, CUSPARSE_INDEX_BASE_ZERO); cusparseSetMatIndexBase(descr, CUSPARSE_INDEX_BASE_ZERO);
int nnzC; int nnzC;
auto row_ptrC = at::empty(m + 1, indexB.type()); auto row_ptrC = at::empty(m + 1, indexB.options());
cusparseXcsrgemmNnz(cusparse_handle, CUSPARSE_OPERATION_NON_TRANSPOSE, cusparseXcsrgemmNnz(cusparse_handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
CUSPARSE_OPERATION_NON_TRANSPOSE, m, n, k, descr, nnzA, CUSPARSE_OPERATION_NON_TRANSPOSE, m, n, k, descr, nnzA,
row_ptrA.data<int>(), colA.data<int>(), descr, nnzB, row_ptrA.data<int>(), colA.data<int>(), descr, nnzB,
row_ptrB.data<int>(), colB.data<int>(), descr, row_ptrB.data<int>(), colB.data<int>(), descr,
row_ptrC.data<int>(), &nnzC); row_ptrC.data<int>(), &nnzC);
auto colC = at::empty(nnzC, indexA.type()); auto colC = at::empty(nnzC, indexA.options());
auto valueC = at::empty(nnzC, valueA.type()); auto valueC = at::empty(nnzC, valueA.options());
CSRGEMM(valueC.type(), cusparse_handle, CUSPARSE_OPERATION_NON_TRANSPOSE, CSRGEMM(valueC.scalar_type(), cusparse_handle,
CUSPARSE_OPERATION_NON_TRANSPOSE, m, n, k, descr, nnzA, CUSPARSE_OPERATION_NON_TRANSPOSE, CUSPARSE_OPERATION_NON_TRANSPOSE, m,
valueA.data<scalar_t>(), row_ptrA.data<int>(), colA.data<int>(), n, k, descr, nnzA, valueA.data<scalar_t>(), row_ptrA.data<int>(),
descr, nnzB, valueB.data<scalar_t>(), row_ptrB.data<int>(), colA.data<int>(), descr, nnzB, valueB.data<scalar_t>(),
colB.data<int>(), descr, valueC.data<scalar_t>(), row_ptrB.data<int>(), colB.data<int>(), descr,
row_ptrC.data<int>(), colC.data<int>()); valueC.data<scalar_t>(), row_ptrC.data<int>(), colC.data<int>());
auto rowC = at::empty(nnzC, indexA.type()); auto rowC = at::empty(nnzC, indexA.options());
cusparseXcsr2coo(cusparse_handle, row_ptrC.data<int>(), nnzC, m, cusparseXcsr2coo(cusparse_handle, row_ptrC.data<int>(), nnzC, m,
rowC.data<int>(), CUSPARSE_INDEX_BASE_ZERO); rowC.data<int>(), CUSPARSE_INDEX_BASE_ZERO);
...@@ -150,7 +152,7 @@ at::Tensor spspmm_bw_cuda(at::Tensor index, at::Tensor indexA, ...@@ -150,7 +152,7 @@ at::Tensor spspmm_bw_cuda(at::Tensor index, at::Tensor indexA,
at::Tensor rowB, colB; at::Tensor rowB, colB;
std::tie(rowB, colB) = to_csr(indexB[0], indexB[1], rowB_max); std::tie(rowB, colB) = to_csr(indexB[0], indexB[1], rowB_max);
AT_DISPATCH_FLOATING_TYPES(valueA.type(), "spspmm_bw", [&] { AT_DISPATCH_FLOATING_TYPES(valueA.scalar_type(), "spspmm_bw", [&] {
spspmm_bw_kernel<scalar_t><<<BLOCKS(value.numel()), THREADS>>>( spspmm_bw_kernel<scalar_t><<<BLOCKS(value.numel()), THREADS>>>(
index.data<int64_t>(), value.data<scalar_t>(), rowA.data<int64_t>(), index.data<int64_t>(), value.data<scalar_t>(), rowA.data<int64_t>(),
colA.data<int64_t>(), valueA.data<scalar_t>(), rowB.data<int64_t>(), colA.data<int64_t>(), valueA.data<scalar_t>(), rowB.data<int64_t>(),
......
...@@ -20,8 +20,8 @@ std::tuple<at::Tensor, at::Tensor> unique_cuda(at::Tensor src) { ...@@ -20,8 +20,8 @@ std::tuple<at::Tensor, at::Tensor> unique_cuda(at::Tensor src) {
at::Tensor perm; at::Tensor perm;
std::tie(src, perm) = src.sort(); std::tie(src, perm) = src.sort();
auto mask = at::zeros(src.numel(), src.type().toScalarType(at::kByte)); auto mask = at::zeros(src.numel(), src.options().dtype(at::kByte));
AT_DISPATCH_ALL_TYPES(src.type(), "grid_cuda_kernel", [&] { AT_DISPATCH_ALL_TYPES(src.scalar_type(), "grid_cuda_kernel", [&] {
unique_cuda_kernel<scalar_t><<<BLOCKS(src.numel()), THREADS>>>( unique_cuda_kernel<scalar_t><<<BLOCKS(src.numel()), THREADS>>>(
src.data<scalar_t>(), mask.data<uint8_t>(), src.numel()); src.data<scalar_t>(), mask.data<uint8_t>(), src.numel());
}); });
......
...@@ -21,7 +21,7 @@ if CUDA_HOME is not None: ...@@ -21,7 +21,7 @@ if CUDA_HOME is not None:
['cuda/unique.cpp', 'cuda/unique_kernel.cu']), ['cuda/unique.cpp', 'cuda/unique_kernel.cu']),
] ]
__version__ = '0.3.0' __version__ = '0.4.0'
url = 'https://github.com/rusty1s/pytorch_sparse' url = 'https://github.com/rusty1s/pytorch_sparse'
install_requires = ['scipy'] install_requires = ['scipy']
......
...@@ -5,7 +5,7 @@ from .eye import eye ...@@ -5,7 +5,7 @@ from .eye import eye
from .spmm import spmm from .spmm import spmm
from .spspmm import spspmm from .spspmm import spspmm
__version__ = '0.3.0' __version__ = '0.4.0'
__all__ = [ __all__ = [
'__version__', '__version__',
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment