"examples/vscode:/vscode.git/clone" did not exist on "6fedb29f1113e734c47360a04a2d52312a1dd7bc"
Commit 522de767 authored by rusty1s's avatar rusty1s
Browse files

pytorch 1.0.0 update

parent e58d83e7
#pragma once #pragma once
#include <torch/torch.h> #include <torch/extension.h>
#define DIM_APPLY3(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, DIM, CODE) \ #define DIM_APPLY3(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, DIM, CODE) \
[&] { \ [&] { \
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
auto TENSOR3##_stride = TENSOR3.stride(DIM); \ auto TENSOR3##_stride = TENSOR3.stride(DIM); \
\ \
auto dims = TENSOR1.dim(); \ auto dims = TENSOR1.dim(); \
auto zeros = at::zeros(dims, torch::CPU(at::kLong)); \ auto zeros = at::zeros(dims, TENSOR1.options().dtype(at::kLong)); \
auto counter = zeros.data<int64_t>(); \ auto counter = zeros.data<int64_t>(); \
bool has_finished = false; \ bool has_finished = false; \
\ \
...@@ -76,7 +76,7 @@ ...@@ -76,7 +76,7 @@
auto TENSOR4##_stride = TENSOR4.stride(DIM); \ auto TENSOR4##_stride = TENSOR4.stride(DIM); \
\ \
auto dims = TENSOR1.dim(); \ auto dims = TENSOR1.dim(); \
auto zeros = at::zeros(dims, torch::CPU(at::kLong)); \ auto zeros = at::zeros(dims, TENSOR1.options().dtype(at::kLong)); \
auto counter = zeros.data<int64_t>(); \ auto counter = zeros.data<int64_t>(); \
bool has_finished = false; \ bool has_finished = false; \
\ \
......
#include <torch/torch.h> #include <torch/extension.h>
#include "dim_apply.h" #include "dim_apply.h"
......
...@@ -15,7 +15,7 @@ if CUDA_HOME is not None: ...@@ -15,7 +15,7 @@ if CUDA_HOME is not None:
['cuda/scatter.cpp', 'cuda/scatter_kernel.cu']) ['cuda/scatter.cpp', 'cuda/scatter_kernel.cu'])
] ]
__version__ = '1.0.4' __version__ = '1.0.5'
url = 'https://github.com/rusty1s/pytorch_scatter' url = 'https://github.com/rusty1s/pytorch_scatter'
install_requires = [] install_requires = []
......
...@@ -7,7 +7,7 @@ from .std import scatter_std ...@@ -7,7 +7,7 @@ from .std import scatter_std
from .max import scatter_max from .max import scatter_max
from .min import scatter_min from .min import scatter_min
__version__ = '1.0.4' __version__ = '1.0.5'
__all__ = [ __all__ = [
'scatter_add', 'scatter_add',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment