Commit de528831 authored by rusty1s's avatar rusty1s
Browse files

pytorch 1.1.0 update

parent 9732a518
......@@ -17,9 +17,9 @@ before_install:
- export CC="gcc-4.9"
- export CXX="g++-4.9"
install:
- if [[ $TRAVIS_PYTHON_VERSION == 2.7 ]]; then pip install https://download.pytorch.org/whl/cpu/torch-1.0.0-cp27-cp27mu-linux_x86_64.whl; fi
- if [[ $TRAVIS_PYTHON_VERSION == 3.5 ]]; then pip install https://download.pytorch.org/whl/cpu/torch-1.0.0-cp35-cp35m-linux_x86_64.whl; fi
- if [[ $TRAVIS_PYTHON_VERSION == 3.6 ]]; then pip install https://download.pytorch.org/whl/cpu/torch-1.0.0-cp36-cp36m-linux_x86_64.whl; fi
- if [[ $TRAVIS_PYTHON_VERSION == 2.7 ]]; then pip install https://download.pytorch.org/whl/cpu/torch-1.1.0-cp27-cp27mu-linux_x86_64.whl; fi
- if [[ $TRAVIS_PYTHON_VERSION == 3.5 ]]; then pip install https://download.pytorch.org/whl/cpu/torch-1.1.0-cp35-cp35m-linux_x86_64.whl; fi
- if [[ $TRAVIS_PYTHON_VERSION == 3.6 ]]; then pip install https://download.pytorch.org/whl/cpu/torch-1.1.0-cp36-cp36m-linux_x86_64.whl; fi
- pip install pycodestyle
- pip install flake8
- pip install codecov
......
......@@ -28,7 +28,7 @@ Note that only `value` comes with autograd support, as `index` is discrete and t
## Installation
Ensure that at least PyTorch 1.0.0 is installed and verify that `cuda/bin` and `cuda/include` are in your `$PATH` and `$CPATH` respectively, *e.g.*:
Ensure that at least PyTorch 1.1.0 is installed and verify that `cuda/bin` and `cuda/include` are in your `$PATH` and `$CPATH` respectively, *e.g.*:
```
$ python -c "import torch; print(torch.__version__)"
......
......@@ -31,7 +31,7 @@ at::Tensor spspmm_bw(at::Tensor index, at::Tensor indexA, at::Tensor valueA,
int64_t *rowB_data = rowB.data<int64_t>();
int64_t *colB_data = colB.data<int64_t>();
AT_DISPATCH_FLOATING_TYPES(valueA.type(), "spspmm_bw", [&] {
AT_DISPATCH_FLOATING_TYPES(valueA.scalar_type(), "spspmm_bw", [&] {
scalar_t *value_data = value.data<scalar_t>();
scalar_t *valueA_data = valueA.data<scalar_t>();
scalar_t *valueB_data = valueB.data<scalar_t>();
......
......@@ -7,8 +7,10 @@
#define CSRGEMM(TYPE, ...) \
[&] { \
const at::Type &the_type = TYPE; \
switch (the_type.scalarType()) { \
const auto &the_type = TYPE; \
(void)the_type; \
at::ScalarType _st = ::detail::scalar_type(TYPE); \
switch (_st) { \
case at::ScalarType::Float: { \
using scalar_t = float; \
return cusparseScsrgemm(__VA_ARGS__); \
......@@ -18,7 +20,7 @@
return cusparseDcsrgemm(__VA_ARGS__); \
} \
default: \
AT_ERROR("Not implemented for '%s'", the_type.toString()); \
AT_ERROR("Not implemented for '", toString(_st), "'"); \
} \
}()
......@@ -48,7 +50,7 @@ spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB,
indexB = indexB.toType(at::kInt);
// Convert A to CSR format.
auto row_ptrA = at::empty(m + 1, indexA.type());
auto row_ptrA = at::empty(m + 1, indexA.options());
cusparseXcoo2csr(cusparse_handle, indexA[0].data<int>(), nnzA, k,
row_ptrA.data<int>(), CUSPARSE_INDEX_BASE_ZERO);
auto colA = indexA[1];
......@@ -56,7 +58,7 @@ spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB,
cudaMemcpyHostToDevice);
// Convert B to CSR format.
auto row_ptrB = at::empty(k + 1, indexB.type());
auto row_ptrB = at::empty(k + 1, indexB.options());
cusparseXcoo2csr(cusparse_handle, indexB[0].data<int>(), nnzB, k,
row_ptrB.data<int>(), CUSPARSE_INDEX_BASE_ZERO);
auto colB = indexB[1];
......@@ -69,23 +71,23 @@ spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB,
cusparseSetMatIndexBase(descr, CUSPARSE_INDEX_BASE_ZERO);
int nnzC;
auto row_ptrC = at::empty(m + 1, indexB.type());
auto row_ptrC = at::empty(m + 1, indexB.options());
cusparseXcsrgemmNnz(cusparse_handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
CUSPARSE_OPERATION_NON_TRANSPOSE, m, n, k, descr, nnzA,
row_ptrA.data<int>(), colA.data<int>(), descr, nnzB,
row_ptrB.data<int>(), colB.data<int>(), descr,
row_ptrC.data<int>(), &nnzC);
auto colC = at::empty(nnzC, indexA.type());
auto valueC = at::empty(nnzC, valueA.type());
auto colC = at::empty(nnzC, indexA.options());
auto valueC = at::empty(nnzC, valueA.options());
CSRGEMM(valueC.type(), cusparse_handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
CUSPARSE_OPERATION_NON_TRANSPOSE, m, n, k, descr, nnzA,
valueA.data<scalar_t>(), row_ptrA.data<int>(), colA.data<int>(),
descr, nnzB, valueB.data<scalar_t>(), row_ptrB.data<int>(),
colB.data<int>(), descr, valueC.data<scalar_t>(),
row_ptrC.data<int>(), colC.data<int>());
CSRGEMM(valueC.scalar_type(), cusparse_handle,
CUSPARSE_OPERATION_NON_TRANSPOSE, CUSPARSE_OPERATION_NON_TRANSPOSE, m,
n, k, descr, nnzA, valueA.data<scalar_t>(), row_ptrA.data<int>(),
colA.data<int>(), descr, nnzB, valueB.data<scalar_t>(),
row_ptrB.data<int>(), colB.data<int>(), descr,
valueC.data<scalar_t>(), row_ptrC.data<int>(), colC.data<int>());
auto rowC = at::empty(nnzC, indexA.type());
auto rowC = at::empty(nnzC, indexA.options());
cusparseXcsr2coo(cusparse_handle, row_ptrC.data<int>(), nnzC, m,
rowC.data<int>(), CUSPARSE_INDEX_BASE_ZERO);
......@@ -150,7 +152,7 @@ at::Tensor spspmm_bw_cuda(at::Tensor index, at::Tensor indexA,
at::Tensor rowB, colB;
std::tie(rowB, colB) = to_csr(indexB[0], indexB[1], rowB_max);
AT_DISPATCH_FLOATING_TYPES(valueA.type(), "spspmm_bw", [&] {
AT_DISPATCH_FLOATING_TYPES(valueA.scalar_type(), "spspmm_bw", [&] {
spspmm_bw_kernel<scalar_t><<<BLOCKS(value.numel()), THREADS>>>(
index.data<int64_t>(), value.data<scalar_t>(), rowA.data<int64_t>(),
colA.data<int64_t>(), valueA.data<scalar_t>(), rowB.data<int64_t>(),
......
......@@ -20,8 +20,8 @@ std::tuple<at::Tensor, at::Tensor> unique_cuda(at::Tensor src) {
at::Tensor perm;
std::tie(src, perm) = src.sort();
auto mask = at::zeros(src.numel(), src.type().toScalarType(at::kByte));
AT_DISPATCH_ALL_TYPES(src.type(), "grid_cuda_kernel", [&] {
auto mask = at::zeros(src.numel(), src.options().dtype(at::kByte));
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "grid_cuda_kernel", [&] {
unique_cuda_kernel<scalar_t><<<BLOCKS(src.numel()), THREADS>>>(
src.data<scalar_t>(), mask.data<uint8_t>(), src.numel());
});
......
......@@ -21,7 +21,7 @@ if CUDA_HOME is not None:
['cuda/unique.cpp', 'cuda/unique_kernel.cu']),
]
__version__ = '0.3.0'
__version__ = '0.4.0'
url = 'https://github.com/rusty1s/pytorch_sparse'
install_requires = ['scipy']
......
......@@ -5,7 +5,7 @@ from .eye import eye
from .spmm import spmm
from .spspmm import spspmm
__version__ = '0.3.0'
__version__ = '0.4.0'
__all__ = [
'__version__',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment