Commit 6884ab18 authored by rusty1s

no warnings

parent a49a26d0
@@ -3,6 +3,7 @@
 #include <torch/torch.h>

 #define DIM_APPLY3(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, DIM, CODE) \
+  [&] { \
   TYPE1 *TENSOR1##_data = TENSOR1.data<TYPE1>(); \
   auto TENSOR1##_size = TENSOR1.size(DIM); \
   auto TENSOR1##_stride = TENSOR1.stride(DIM); \
@@ -16,7 +17,7 @@
   auto TENSOR3##_stride = TENSOR3.stride(DIM); \
   \
   auto dims = TENSOR1.dim(); \
-  auto zeros = at::zeros(torch::CPU(at::kLong), {dims}); \
+  auto zeros = at::zeros(dims, torch::CPU(at::kLong)); \
   auto counter = zeros.data<int64_t>(); \
   bool has_finished = false; \
   \
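Note on the at::zeros lines in the hunk above: the commit only swaps the argument order so that the sizes come first and the CPU long type second, matching the newer ATen factory signature and, given the commit message, plausibly silencing a warning on the old (type, sizes) form. A minimal sketch of the new-style call, assuming a libtorch from the era of this commit (around PyTorch 0.4) where torch::CPU(at::kLong) and Tensor::data<T>() are still available; the function name is made up for illustration:

#include <torch/torch.h>

// Hedged sketch, not project code: allocate a 1-D int64 counter tensor the
// way the updated macros do, with the sizes first and the type second.
void counter_sketch(const at::Tensor &self) {
  auto dims = self.dim();
  auto zeros = at::zeros(dims, torch::CPU(at::kLong));  // 1-D, length `dims`
  int64_t *counter = zeros.data<int64_t>();  // kLong tensors store int64_t
  counter[0] = 0;  // the macros use these as per-dimension loop counters
}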
@@ -52,10 +53,12 @@
       } else \
         break; \
     } \
-  }
+  } \
+  }()

 #define DIM_APPLY4(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, TYPE4, \
                    TENSOR4, DIM, CODE) \
+  [&] { \
   TYPE1 *TENSOR1##_data = TENSOR1.data<TYPE1>(); \
   auto TENSOR1##_size = TENSOR1.size(DIM); \
   auto TENSOR1##_stride = TENSOR1.stride(DIM); \
@@ -73,7 +76,7 @@
   auto TENSOR4##_stride = TENSOR4.stride(DIM); \
   \
   auto dims = TENSOR1.dim(); \
-  auto zeros = at::zeros(torch::CPU(at::kLong), {dims}); \
+  auto zeros = at::zeros(dims, torch::CPU(at::kLong)); \
   auto counter = zeros.data<int64_t>(); \
   bool has_finished = false; \
   \
@@ -111,4 +114,5 @@
       } else \
         break; \
     } \
-  }
+  } \
+  }()
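Note on the macro changes above: both DIM_APPLY3 and DIM_APPLY4 are now wrapped in [&] { ... }(), an immediately invoked lambda, so each expansion is a single expression rather than a bare block; call sites can then end with a normal semicolon (see the }) to }); changes below) and every helper the macro declares stays scoped inside the lambda. A self-contained sketch of that pattern with a toy macro, not the project's:

#include <cstdio>

// Toy macro in the same shape as DIM_APPLY3 after this commit: the body is an
// immediately invoked lambda, so APPLY_N(...) is an expression and the caller
// writes APPLY_N(...); with a normal trailing semicolon.
#define APPLY_N(TENSORISH, N, CODE) \
  [&] { \
    auto TENSORISH##_size = (N); /* locals stay scoped inside the lambda */ \
    for (int i = 0; i < TENSORISH##_size; i++) { \
      CODE \
    } \
  }()

int main() {
  int base = 10;
  APPLY_N(base, 3, { printf("%d\n", base + i); });  // expression, then ';'
  return 0;
}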
@@ -11,7 +11,7 @@ void scatter_mul(at::Tensor src, at::Tensor index, at::Tensor out,
         idx = index_data[i * index_stride];
         out_data[idx * out_stride] *= src_data[i * src_stride];
       }
-    })
+    });
   });
 }
@@ -24,7 +24,7 @@ void scatter_div(at::Tensor src, at::Tensor index, at::Tensor out,
         idx = index_data[i * index_stride];
         out_data[idx * out_stride] /= src_data[i * src_stride];
       }
-    })
+    });
   });
 }
@@ -41,7 +41,7 @@ void scatter_max(at::Tensor src, at::Tensor index, at::Tensor out,
           arg_data[idx * arg_stride] = i;
         }
       }
-    })
+    });
   });
 }
@@ -58,7 +58,7 @@ void scatter_min(at::Tensor src, at::Tensor index, at::Tensor out,
          arg_data[idx * arg_stride] = i;
        }
      }
-    })
+    });
   });
 }
@@ -74,7 +74,7 @@ void index_backward(at::Tensor grad, at::Tensor index, at::Tensor arg,
          out_data[i * out_stride] = grad_data[idx * grad_stride];
        }
      }
-    })
+    });
   });
 }
...
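Note on the kernels above: only the closing of each DIM_APPLY3/DIM_APPLY4 invocation changes, from }) to });, which the expression-style macros now require. The update rules themselves are untouched; scatter_mul, for example, combines every source element into the output slot chosen by index, out[index[i]] *= src[i] along the scattered dimension. A tiny self-contained 1-D illustration of that rule with plain arrays and unit strides (not project code):

#include <cstdio>

int main() {
  // 1-D version of the update shown in the scatter_mul hunk:
  // out[index[i]] *= src[i] for every source position i.
  double src[5] = {2.0, 3.0, 4.0, 5.0, 6.0};
  long index[5] = {0, 1, 1, 2, 0};
  double out[3] = {1.0, 1.0, 1.0};  // multiplicative identity as the base

  for (int i = 0; i < 5; i++)
    out[index[i]] *= src[i];

  for (int j = 0; j < 3; j++)
    printf("out[%d] = %g\n", j, out[j]);  // prints 12, 12, 5
  return 0;
}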
@@ -4,7 +4,11 @@ from setuptools import setup, find_packages
 import torch.cuda
 from torch.utils.cpp_extension import CppExtension, CUDAExtension

-ext_modules = [CppExtension('scatter_cpu', ['cpu/scatter.cpp'])]
+ext_modules = [
+    CppExtension(
+        'scatter_cpu', ['cpu/scatter.cpp'],
+        extra_compile_args=['-Wno-unused-variable'])
+]
 cmdclass = {'build_ext': torch.utils.cpp_extension.BuildExtension}

 if torch.cuda.is_available():
...
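Note on the setup.py change: the scatter_cpu extension is now built with -Wno-unused-variable, which fits the commit message, presumably because the DIM_APPLY macros declare a _size and a _stride local for every tensor argument and a given kernel body may leave some of them unread. A standalone C++ illustration of that class of warning with a toy macro (not the project's):

// Toy macro, not project code. Compile with: g++ -Wall -c unused_demo.cpp
// Both *_stride locals below are never read, so -Wall reports them via
// -Wunused-variable; -Wno-unused-variable, now passed by setup.py for the
// scatter_cpu extension, silences exactly this class of warning.
#define BIND_SIZE_AND_STRIDE(NAME, SIZE, STRIDE) \
  auto NAME##_size = (SIZE); \
  auto NAME##_stride = (STRIDE);

int sum_sizes() {
  BIND_SIZE_AND_STRIDE(src, 4, 1)
  BIND_SIZE_AND_STRIDE(index, 8, 2)
  return src_size + index_size;  // src_stride and index_stride stay unused
}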