Commit 6884ab18 authored by rusty1s's avatar rusty1s
Browse files

no warnings

parent a49a26d0
...@@ -3,112 +3,116 @@ ...@@ -3,112 +3,116 @@
#include <torch/torch.h> #include <torch/torch.h>
#define DIM_APPLY3(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, DIM, CODE) \ #define DIM_APPLY3(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, DIM, CODE) \
TYPE1 *TENSOR1##_data = TENSOR1.data<TYPE1>(); \ [&] { \
auto TENSOR1##_size = TENSOR1.size(DIM); \ TYPE1 *TENSOR1##_data = TENSOR1.data<TYPE1>(); \
auto TENSOR1##_stride = TENSOR1.stride(DIM); \ auto TENSOR1##_size = TENSOR1.size(DIM); \
auto TENSOR1##_stride = TENSOR1.stride(DIM); \
\ \
TYPE2 *TENSOR2##_data = TENSOR2.data<TYPE2>(); \ TYPE2 *TENSOR2##_data = TENSOR2.data<TYPE2>(); \
auto TENSOR2##_size = TENSOR2.size(DIM); \ auto TENSOR2##_size = TENSOR2.size(DIM); \
auto TENSOR2##_stride = TENSOR2.stride(DIM); \ auto TENSOR2##_stride = TENSOR2.stride(DIM); \
\ \
TYPE3 *TENSOR3##_data = TENSOR3.data<TYPE3>(); \ TYPE3 *TENSOR3##_data = TENSOR3.data<TYPE3>(); \
auto TENSOR3##_size = TENSOR3.size(DIM); \ auto TENSOR3##_size = TENSOR3.size(DIM); \
auto TENSOR3##_stride = TENSOR3.stride(DIM); \ auto TENSOR3##_stride = TENSOR3.stride(DIM); \
\ \
auto dims = TENSOR1.dim(); \ auto dims = TENSOR1.dim(); \
auto zeros = at::zeros(torch::CPU(at::kLong), {dims}); \ auto zeros = at::zeros(dims, torch::CPU(at::kLong)); \
auto counter = zeros.data<int64_t>(); \ auto counter = zeros.data<int64_t>(); \
bool has_finished = false; \ bool has_finished = false; \
\ \
while (!has_finished) { \ while (!has_finished) { \
CODE; \ CODE; \
if (dims == 1) \ if (dims == 1) \
break; \ break; \
\ \
for (int64_t cur_dim = 0; cur_dim < dims; cur_dim++) { \ for (int64_t cur_dim = 0; cur_dim < dims; cur_dim++) { \
if (cur_dim == DIM) { \ if (cur_dim == DIM) { \
if (cur_dim == dims - 1) { \ if (cur_dim == dims - 1) { \
has_finished = true; \ has_finished = true; \
break; \ break; \
} \
continue; \
} \ } \
continue; \
} \
\ \
counter[cur_dim]++; \ counter[cur_dim]++; \
TENSOR1##_data += TENSOR1.stride(cur_dim); \ TENSOR1##_data += TENSOR1.stride(cur_dim); \
TENSOR2##_data += TENSOR2.stride(cur_dim); \ TENSOR2##_data += TENSOR2.stride(cur_dim); \
TENSOR3##_data += TENSOR3.stride(cur_dim); \ TENSOR3##_data += TENSOR3.stride(cur_dim); \
\ \
if (counter[cur_dim] == TENSOR1.size(cur_dim)) { \ if (counter[cur_dim] == TENSOR1.size(cur_dim)) { \
if (cur_dim == dims - 1) { \ if (cur_dim == dims - 1) { \
has_finished = true; \ has_finished = true; \
break; \
} else { \
TENSOR1##_data -= counter[cur_dim] * TENSOR1.stride(cur_dim); \
TENSOR2##_data -= counter[cur_dim] * TENSOR2.stride(cur_dim); \
TENSOR3##_data -= counter[cur_dim] * TENSOR3.stride(cur_dim); \
counter[cur_dim] = 0; \
} \
} else \
break; \ break; \
} else { \ } \
TENSOR1##_data -= counter[cur_dim] * TENSOR1.stride(cur_dim); \
TENSOR2##_data -= counter[cur_dim] * TENSOR2.stride(cur_dim); \
TENSOR3##_data -= counter[cur_dim] * TENSOR3.stride(cur_dim); \
counter[cur_dim] = 0; \
} \
} else \
break; \
} \ } \
} }()
#define DIM_APPLY4(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, TYPE4, \ #define DIM_APPLY4(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, TYPE4, \
TENSOR4, DIM, CODE) \ TENSOR4, DIM, CODE) \
TYPE1 *TENSOR1##_data = TENSOR1.data<TYPE1>(); \ [&] { \
auto TENSOR1##_size = TENSOR1.size(DIM); \ TYPE1 *TENSOR1##_data = TENSOR1.data<TYPE1>(); \
auto TENSOR1##_stride = TENSOR1.stride(DIM); \ auto TENSOR1##_size = TENSOR1.size(DIM); \
auto TENSOR1##_stride = TENSOR1.stride(DIM); \
\ \
TYPE2 *TENSOR2##_data = TENSOR2.data<TYPE2>(); \ TYPE2 *TENSOR2##_data = TENSOR2.data<TYPE2>(); \
auto TENSOR2##_size = TENSOR2.size(DIM); \ auto TENSOR2##_size = TENSOR2.size(DIM); \
auto TENSOR2##_stride = TENSOR2.stride(DIM); \ auto TENSOR2##_stride = TENSOR2.stride(DIM); \
\ \
TYPE3 *TENSOR3##_data = TENSOR3.data<TYPE3>(); \ TYPE3 *TENSOR3##_data = TENSOR3.data<TYPE3>(); \
auto TENSOR3##_size = TENSOR3.size(DIM); \ auto TENSOR3##_size = TENSOR3.size(DIM); \
auto TENSOR3##_stride = TENSOR3.stride(DIM); \ auto TENSOR3##_stride = TENSOR3.stride(DIM); \
\ \
TYPE4 *TENSOR4##_data = TENSOR4.data<TYPE4>(); \ TYPE4 *TENSOR4##_data = TENSOR4.data<TYPE4>(); \
auto TENSOR4##_size = TENSOR4.size(DIM); \ auto TENSOR4##_size = TENSOR4.size(DIM); \
auto TENSOR4##_stride = TENSOR4.stride(DIM); \ auto TENSOR4##_stride = TENSOR4.stride(DIM); \
\ \
auto dims = TENSOR1.dim(); \ auto dims = TENSOR1.dim(); \
auto zeros = at::zeros(torch::CPU(at::kLong), {dims}); \ auto zeros = at::zeros(dims, torch::CPU(at::kLong)); \
auto counter = zeros.data<int64_t>(); \ auto counter = zeros.data<int64_t>(); \
bool has_finished = false; \ bool has_finished = false; \
\ \
while (!has_finished) { \ while (!has_finished) { \
CODE; \ CODE; \
if (dims == 1) \ if (dims == 1) \
break; \ break; \
\ \
for (int64_t cur_dim = 0; cur_dim < dims; cur_dim++) { \ for (int64_t cur_dim = 0; cur_dim < dims; cur_dim++) { \
if (cur_dim == DIM) { \ if (cur_dim == DIM) { \
if (cur_dim == dims - 1) { \ if (cur_dim == dims - 1) { \
has_finished = true; \ has_finished = true; \
break; \ break; \
} \
continue; \
} \ } \
continue; \
} \
\ \
counter[cur_dim]++; \ counter[cur_dim]++; \
TENSOR1##_data += TENSOR1.stride(cur_dim); \ TENSOR1##_data += TENSOR1.stride(cur_dim); \
TENSOR2##_data += TENSOR2.stride(cur_dim); \ TENSOR2##_data += TENSOR2.stride(cur_dim); \
TENSOR3##_data += TENSOR3.stride(cur_dim); \ TENSOR3##_data += TENSOR3.stride(cur_dim); \
TENSOR4##_data += TENSOR4.stride(cur_dim); \ TENSOR4##_data += TENSOR4.stride(cur_dim); \
\ \
if (counter[cur_dim] == TENSOR1.size(cur_dim)) { \ if (counter[cur_dim] == TENSOR1.size(cur_dim)) { \
if (cur_dim == dims - 1) { \ if (cur_dim == dims - 1) { \
has_finished = true; \ has_finished = true; \
break; \
} else { \
TENSOR1##_data -= counter[cur_dim] * TENSOR1.stride(cur_dim); \
TENSOR2##_data -= counter[cur_dim] * TENSOR2.stride(cur_dim); \
TENSOR3##_data -= counter[cur_dim] * TENSOR3.stride(cur_dim); \
TENSOR4##_data -= counter[cur_dim] * TENSOR4.stride(cur_dim); \
counter[cur_dim] = 0; \
} \
} else \
break; \ break; \
} else { \ } \
TENSOR1##_data -= counter[cur_dim] * TENSOR1.stride(cur_dim); \
TENSOR2##_data -= counter[cur_dim] * TENSOR2.stride(cur_dim); \
TENSOR3##_data -= counter[cur_dim] * TENSOR3.stride(cur_dim); \
TENSOR4##_data -= counter[cur_dim] * TENSOR4.stride(cur_dim); \
counter[cur_dim] = 0; \
} \
} else \
break; \
} \ } \
} }()
...@@ -11,7 +11,7 @@ void scatter_mul(at::Tensor src, at::Tensor index, at::Tensor out, ...@@ -11,7 +11,7 @@ void scatter_mul(at::Tensor src, at::Tensor index, at::Tensor out,
idx = index_data[i * index_stride]; idx = index_data[i * index_stride];
out_data[idx * out_stride] *= src_data[i * src_stride]; out_data[idx * out_stride] *= src_data[i * src_stride];
} }
}) });
}); });
} }
...@@ -24,7 +24,7 @@ void scatter_div(at::Tensor src, at::Tensor index, at::Tensor out, ...@@ -24,7 +24,7 @@ void scatter_div(at::Tensor src, at::Tensor index, at::Tensor out,
idx = index_data[i * index_stride]; idx = index_data[i * index_stride];
out_data[idx * out_stride] /= src_data[i * src_stride]; out_data[idx * out_stride] /= src_data[i * src_stride];
} }
}) });
}); });
} }
...@@ -41,7 +41,7 @@ void scatter_max(at::Tensor src, at::Tensor index, at::Tensor out, ...@@ -41,7 +41,7 @@ void scatter_max(at::Tensor src, at::Tensor index, at::Tensor out,
arg_data[idx * arg_stride] = i; arg_data[idx * arg_stride] = i;
} }
} }
}) });
}); });
} }
...@@ -58,7 +58,7 @@ void scatter_min(at::Tensor src, at::Tensor index, at::Tensor out, ...@@ -58,7 +58,7 @@ void scatter_min(at::Tensor src, at::Tensor index, at::Tensor out,
arg_data[idx * arg_stride] = i; arg_data[idx * arg_stride] = i;
} }
} }
}) });
}); });
} }
...@@ -74,7 +74,7 @@ void index_backward(at::Tensor grad, at::Tensor index, at::Tensor arg, ...@@ -74,7 +74,7 @@ void index_backward(at::Tensor grad, at::Tensor index, at::Tensor arg,
out_data[i * out_stride] = grad_data[idx * grad_stride]; out_data[i * out_stride] = grad_data[idx * grad_stride];
} }
} }
}) });
}); });
} }
......
...@@ -4,7 +4,11 @@ from setuptools import setup, find_packages ...@@ -4,7 +4,11 @@ from setuptools import setup, find_packages
import torch.cuda import torch.cuda
from torch.utils.cpp_extension import CppExtension, CUDAExtension from torch.utils.cpp_extension import CppExtension, CUDAExtension
ext_modules = [CppExtension('scatter_cpu', ['cpu/scatter.cpp'])] ext_modules = [
CppExtension(
'scatter_cpu', ['cpu/scatter.cpp'],
extra_compile_args=['-Wno-unused-variable'])
]
cmdclass = {'build_ext': torch.utils.cpp_extension.BuildExtension} cmdclass = {'build_ext': torch.utils.cpp_extension.BuildExtension}
if torch.cuda.is_available(): if torch.cuda.is_available():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment