Commit a7671588 authored by rusty1s

update to 1.1.0

parent 32cce985
...
@@ -36,7 +36,7 @@ All included operations work on varying data types, are implemented both for CPU
 ## Installation
-Ensure that at least PyTorch 1.0.0 is installed and verify that `cuda/bin` and `cuda/include` are in your `$PATH` and `$CPATH` respectively, *e.g.*:
+Ensure that at least PyTorch 1.1.0 is installed and verify that `cuda/bin` and `cuda/include` are in your `$PATH` and `$CPATH` respectively, *e.g.*:
 ```
 $ python -c "import torch; print(torch.__version__)"
...
...
@@ -5,7 +5,7 @@
 void scatter_mul(at::Tensor src, at::Tensor index, at::Tensor out,
                  int64_t dim) {
   int64_t elems_per_row = index.size(dim), i, idx;
-  AT_DISPATCH_ALL_TYPES(src.type(), "scatter_mul", [&] {
+  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_mul", [&] {
     DIM_APPLY3(scalar_t, src, int64_t, index, scalar_t, out, dim, {
       for (i = 0; i < elems_per_row; i++) {
         idx = index_data[i * index_stride];
@@ -18,7 +18,7 @@ void scatter_mul(at::Tensor src, at::Tensor index, at::Tensor out,
 void scatter_div(at::Tensor src, at::Tensor index, at::Tensor out,
                  int64_t dim) {
   int64_t elems_per_row = index.size(dim), i, idx;
-  AT_DISPATCH_ALL_TYPES(src.type(), "scatter_div", [&] {
+  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_div", [&] {
     DIM_APPLY3(scalar_t, src, int64_t, index, scalar_t, out, dim, {
       for (i = 0; i < elems_per_row; i++) {
         idx = index_data[i * index_stride];
@@ -31,7 +31,7 @@ void scatter_div(at::Tensor src, at::Tensor index, at::Tensor out,
 void scatter_max(at::Tensor src, at::Tensor index, at::Tensor out,
                  at::Tensor arg, int64_t dim) {
   int64_t elems_per_row = index.size(dim), i, idx;
-  AT_DISPATCH_ALL_TYPES(src.type(), "scatter_max", [&] {
+  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_max", [&] {
     DIM_APPLY4(scalar_t, src, int64_t, index, scalar_t, out, int64_t, arg, dim,
                {
                  for (i = 0; i < elems_per_row; i++) {
@@ -48,7 +48,7 @@ void scatter_max(at::Tensor src, at::Tensor index, at::Tensor out,
 void scatter_min(at::Tensor src, at::Tensor index, at::Tensor out,
                  at::Tensor arg, int64_t dim) {
   int64_t elems_per_row = index.size(dim), i, idx;
-  AT_DISPATCH_ALL_TYPES(src.type(), "scatter_min", [&] {
+  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_min", [&] {
     DIM_APPLY4(scalar_t, src, int64_t, index, scalar_t, out, int64_t, arg, dim,
                {
                  for (i = 0; i < elems_per_row; i++) {
@@ -65,7 +65,7 @@ void scatter_min(at::Tensor src, at::Tensor index, at::Tensor out,
 void index_backward(at::Tensor grad, at::Tensor index, at::Tensor arg,
                     at::Tensor out, int64_t dim) {
   int64_t elems_per_row = index.size(dim), i, idx;
-  AT_DISPATCH_ALL_TYPES(grad.type(), "index_backward", [&] {
+  AT_DISPATCH_ALL_TYPES(grad.scalar_type(), "index_backward", [&] {
     DIM_APPLY4(scalar_t, grad, int64_t, index, int64_t, arg, scalar_t, out, dim,
                {
                  for (i = 0; i < elems_per_row; i++) {
...
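Every CPU hunk above is the same mechanical migration: `AT_DISPATCH_ALL_TYPES` is now handed `src.scalar_type()` (an `at::ScalarType`) instead of the `at::Type` returned by the deprecated `src.type()` accessor that PyTorch 1.1 started phasing out. As a point of reference, here is a minimal standalone sketch of the dispatch pattern; the function `double_values` is invented for illustration and is not part of this commit:

```
#include <torch/extension.h>

// Minimal sketch, not from this commit: writes twice the values of `src`
// into `out`, dispatching on the runtime dtype. Since PyTorch 1.1 the
// dispatch macros take an at::ScalarType (src.scalar_type()) rather than
// the deprecated at::Type returned by src.type().
void double_values(at::Tensor src, at::Tensor out) {
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "double_values", [&] {
    // Inside the lambda, `scalar_t` is bound to the concrete element type.
    auto src_data = src.data<scalar_t>();
    auto out_data = out.data<scalar_t>();
    for (int64_t i = 0; i < src.numel(); i++) {
      out_data[i] = src_data[i] * 2;
    }
  });
}
```

The macro instantiates the lambda once per supported dtype and binds `scalar_t` to the concrete C++ type (`float`, `int64_t`, ...), which is exactly what the `DIM_APPLY3`/`DIM_APPLY4` macros in the hunks above rely on.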
...
@@ -44,7 +44,7 @@ scatter_mul_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
 void scatter_mul_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
                       int64_t dim) {
   cudaSetDevice(src.get_device());
-  AT_DISPATCH_ALL_TYPES(src.type(), "scatter_mul_kernel", [&] {
+  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_mul_kernel", [&] {
     KERNEL_RUN(scatter_mul_kernel, index.dim(), index.numel(),
                at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src),
                at::cuda::detail::getTensorInfo<int64_t, int64_t>(index),
@@ -71,7 +71,7 @@ scatter_div_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
 void scatter_div_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
                       int64_t dim) {
   cudaSetDevice(src.get_device());
-  AT_DISPATCH_ALL_TYPES(src.type(), "scatter_div_kernel", [&] {
+  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_div_kernel", [&] {
     KERNEL_RUN(scatter_div_kernel, index.dim(), index.numel(),
                at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src),
                at::cuda::detail::getTensorInfo<int64_t, int64_t>(index),
@@ -117,7 +117,7 @@ scatter_max_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
 void scatter_max_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
                       at::Tensor arg, int64_t dim) {
   cudaSetDevice(src.get_device());
-  AT_DISPATCH_ALL_TYPES(src.type(), "scatter_max_kernel", [&] {
+  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_max_kernel", [&] {
     auto src_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src);
     auto index_info = at::cuda::detail::getTensorInfo<int64_t, int64_t>(index);
     auto out_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(out);
@@ -148,7 +148,7 @@ scatter_min_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
 void scatter_min_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
                       at::Tensor arg, int64_t dim) {
   cudaSetDevice(src.get_device());
-  AT_DISPATCH_ALL_TYPES(src.type(), "scatter_min_kernel", [&] {
+  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_min_kernel", [&] {
     auto src_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src);
     auto index_info = at::cuda::detail::getTensorInfo<int64_t, int64_t>(index);
     auto out_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(out);
@@ -184,7 +184,7 @@ index_backward_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> grad,
 void index_backward_cuda(at::Tensor grad, at::Tensor index, at::Tensor arg,
                          at::Tensor out, int64_t dim) {
   cudaSetDevice(grad.get_device());
-  AT_DISPATCH_ALL_TYPES(grad.type(), "index_backward_kernel", [&] {
+  AT_DISPATCH_ALL_TYPES(grad.scalar_type(), "index_backward_kernel", [&] {
     KERNEL_RUN(index_backward_kernel, index.dim(), index.numel(),
                at::cuda::detail::getTensorInfo<scalar_t, int64_t>(grad),
                at::cuda::detail::getTensorInfo<int64_t, int64_t>(index),
...
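The CUDA launchers get the identical one-line treatment, with the surrounding boilerplate (`cudaSetDevice`, the `TensorInfo` wrappers, `KERNEL_RUN`) untouched. `KERNEL_RUN` is a macro internal to this repository, so the hedged sketch below launches a trivial kernel by hand instead; `scale_cuda` and `scale_kernel` are invented names, and the kernel assumes a contiguous tensor:

```
#include <ATen/ATen.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>

// Hypothetical example, not from this commit: scales `src` in place,
// mirroring the structure of the scatter_*_cuda launchers above.
template <typename scalar_t>
__global__ void
scale_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
             scalar_t factor, int64_t numel) {
  const int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < numel)
    src.data[i] *= factor;  // linear indexing: assumes contiguity
}

void scale_cuda(at::Tensor src, double factor) {
  cudaSetDevice(src.get_device());  // run on the tensor's device
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scale_kernel", [&] {
    // TensorInfo bundles the raw pointer with sizes/strides so the kernel
    // can index the tensor without ATen on the device side.
    auto src_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src);
    const int64_t numel = src.numel();
    const int threads = 1024;
    const int blocks = static_cast<int>((numel + threads - 1) / threads);
    scale_kernel<scalar_t><<<blocks, threads>>>(
        src_info, static_cast<scalar_t>(factor), numel);
  });
}
```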
...
@@ -20,7 +20,7 @@ if CUDA_HOME is not None:
                       ['cuda/scatter.cpp', 'cuda/scatter_kernel.cu'])
     ]
-__version__ = '1.1.2'
+__version__ = '1.2.0'
 url = 'https://github.com/rusty1s/pytorch_scatter'
 install_requires = []
...
...
@@ -7,7 +7,7 @@ from .std import scatter_std
 from .max import scatter_max
 from .min import scatter_min
-__version__ = '1.1.2'
+__version__ = '1.2.0'
 __all__ = [
     'scatter_add',
...