Commit a7671588 authored by rusty1s's avatar rusty1s
Browse files

update to 1.1.0

parent 32cce985
......@@ -36,7 +36,7 @@ All included operations work on varying data types, are implemented both for CPU
## Installation
Ensure that at least PyTorch 1.0.0 is installed and verify that `cuda/bin` and `cuda/include` are in your `$PATH` and `$CPATH` respectively, *e.g.*:
Ensure that at least PyTorch 1.1.0 is installed and verify that `cuda/bin` and `cuda/include` are in your `$PATH` and `$CPATH` respectively, *e.g.*:
```
$ python -c "import torch; print(torch.__version__)"
......
......@@ -5,7 +5,7 @@
void scatter_mul(at::Tensor src, at::Tensor index, at::Tensor out,
int64_t dim) {
int64_t elems_per_row = index.size(dim), i, idx;
AT_DISPATCH_ALL_TYPES(src.type(), "scatter_mul", [&] {
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_mul", [&] {
DIM_APPLY3(scalar_t, src, int64_t, index, scalar_t, out, dim, {
for (i = 0; i < elems_per_row; i++) {
idx = index_data[i * index_stride];
......@@ -18,7 +18,7 @@ void scatter_mul(at::Tensor src, at::Tensor index, at::Tensor out,
void scatter_div(at::Tensor src, at::Tensor index, at::Tensor out,
int64_t dim) {
int64_t elems_per_row = index.size(dim), i, idx;
AT_DISPATCH_ALL_TYPES(src.type(), "scatter_div", [&] {
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_div", [&] {
DIM_APPLY3(scalar_t, src, int64_t, index, scalar_t, out, dim, {
for (i = 0; i < elems_per_row; i++) {
idx = index_data[i * index_stride];
......@@ -31,7 +31,7 @@ void scatter_div(at::Tensor src, at::Tensor index, at::Tensor out,
void scatter_max(at::Tensor src, at::Tensor index, at::Tensor out,
at::Tensor arg, int64_t dim) {
int64_t elems_per_row = index.size(dim), i, idx;
AT_DISPATCH_ALL_TYPES(src.type(), "scatter_max", [&] {
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_max", [&] {
DIM_APPLY4(scalar_t, src, int64_t, index, scalar_t, out, int64_t, arg, dim,
{
for (i = 0; i < elems_per_row; i++) {
......@@ -48,7 +48,7 @@ void scatter_max(at::Tensor src, at::Tensor index, at::Tensor out,
void scatter_min(at::Tensor src, at::Tensor index, at::Tensor out,
at::Tensor arg, int64_t dim) {
int64_t elems_per_row = index.size(dim), i, idx;
AT_DISPATCH_ALL_TYPES(src.type(), "scatter_min", [&] {
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_min", [&] {
DIM_APPLY4(scalar_t, src, int64_t, index, scalar_t, out, int64_t, arg, dim,
{
for (i = 0; i < elems_per_row; i++) {
......@@ -65,7 +65,7 @@ void scatter_min(at::Tensor src, at::Tensor index, at::Tensor out,
void index_backward(at::Tensor grad, at::Tensor index, at::Tensor arg,
at::Tensor out, int64_t dim) {
int64_t elems_per_row = index.size(dim), i, idx;
AT_DISPATCH_ALL_TYPES(grad.type(), "index_backward", [&] {
AT_DISPATCH_ALL_TYPES(grad.scalar_type(), "index_backward", [&] {
DIM_APPLY4(scalar_t, grad, int64_t, index, int64_t, arg, scalar_t, out, dim,
{
for (i = 0; i < elems_per_row; i++) {
......
......@@ -44,7 +44,7 @@ scatter_mul_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
void scatter_mul_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
int64_t dim) {
cudaSetDevice(src.get_device());
AT_DISPATCH_ALL_TYPES(src.type(), "scatter_mul_kernel", [&] {
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_mul_kernel", [&] {
KERNEL_RUN(scatter_mul_kernel, index.dim(), index.numel(),
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src),
at::cuda::detail::getTensorInfo<int64_t, int64_t>(index),
......@@ -71,7 +71,7 @@ scatter_div_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
void scatter_div_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
int64_t dim) {
cudaSetDevice(src.get_device());
AT_DISPATCH_ALL_TYPES(src.type(), "scatter_div_kernel", [&] {
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_div_kernel", [&] {
KERNEL_RUN(scatter_div_kernel, index.dim(), index.numel(),
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src),
at::cuda::detail::getTensorInfo<int64_t, int64_t>(index),
......@@ -117,7 +117,7 @@ scatter_max_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
void scatter_max_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
at::Tensor arg, int64_t dim) {
cudaSetDevice(src.get_device());
AT_DISPATCH_ALL_TYPES(src.type(), "scatter_max_kernel", [&] {
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_max_kernel", [&] {
auto src_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src);
auto index_info = at::cuda::detail::getTensorInfo<int64_t, int64_t>(index);
auto out_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(out);
......@@ -148,7 +148,7 @@ scatter_min_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
void scatter_min_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
at::Tensor arg, int64_t dim) {
cudaSetDevice(src.get_device());
AT_DISPATCH_ALL_TYPES(src.type(), "scatter_min_kernel", [&] {
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_min_kernel", [&] {
auto src_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src);
auto index_info = at::cuda::detail::getTensorInfo<int64_t, int64_t>(index);
auto out_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(out);
......@@ -184,7 +184,7 @@ index_backward_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> grad,
void index_backward_cuda(at::Tensor grad, at::Tensor index, at::Tensor arg,
at::Tensor out, int64_t dim) {
cudaSetDevice(grad.get_device());
AT_DISPATCH_ALL_TYPES(grad.type(), "index_backward_kernel", [&] {
AT_DISPATCH_ALL_TYPES(grad.scalar_type(), "index_backward_kernel", [&] {
KERNEL_RUN(index_backward_kernel, index.dim(), index.numel(),
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(grad),
at::cuda::detail::getTensorInfo<int64_t, int64_t>(index),
......
......@@ -20,7 +20,7 @@ if CUDA_HOME is not None:
['cuda/scatter.cpp', 'cuda/scatter_kernel.cu'])
]
__version__ = '1.1.2'
__version__ = '1.2.0'
url = 'https://github.com/rusty1s/pytorch_scatter'
install_requires = []
......
......@@ -7,7 +7,7 @@ from .std import scatter_std
from .max import scatter_max
from .min import scatter_min
__version__ = '1.1.2'
__version__ = '1.2.0'
__all__ = [
'scatter_add',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment