Commit 66320686 authored by rusty1s
Browse files

mul, div, mean

parent 3e06f342
...@@ -4,17 +4,36 @@ ...@@ -4,17 +4,36 @@
// Host-side launcher for the scatter "mul" operation.
//
// Validates the tensors via thc_(check), builds TensorInfo descriptors for
// output, index and input, and dispatches mulKernel through KERNEL_RUN with
// one work item per element of `index`.
//
// state  - THC state handle forwarded to every THC call.
// dim    - dimension forwarded to the kernel's offset computation.
// output - tensor that the kernel updates in place.
// index  - long tensor; its element count determines the amount of work.
// input  - tensor providing the values combined into `output`.
//
// NOTE(review): the element count is stored in an `int` because the kernel
// takes `const int n`; confirm index tensors never exceed INT_MAX elements.
void scatter_(mul)(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input) {
  thc_(check)(state, output, index, input);

  const int n = THCudaLongTensor_nElement(state, index);
  TensorInfo<real> outputInfo = thc_(getTensorInfo)(state, output);
  TensorInfo<int64_t> indexInfo = thc_getTensorInfo_Long(state, index);
  TensorInfo<real> inputInfo = thc_(getTensorInfo)(state, input);

  // KERNEL_RUN is a project macro (defined elsewhere); it is invoked without a
  // trailing semicolon, exactly as in the rest of this file.
  KERNEL_RUN(mulKernel, indexInfo.dims, n, outputInfo, indexInfo, inputInfo, dim)
}
// Host-side launcher for the scatter "div" operation.
//
// Validates the tensors via thc_(check), builds TensorInfo descriptors for
// output, index and input, and dispatches divKernel through KERNEL_RUN with
// one work item per element of `index`.
//
// state  - THC state handle forwarded to every THC call.
// dim    - dimension forwarded to the kernel's offset computation.
// output - tensor that the kernel updates in place.
// index  - long tensor; its element count determines the amount of work.
// input  - tensor providing the values combined into `output`.
//
// NOTE(review): the element count is stored in an `int` because the kernel
// takes `const int n`; confirm index tensors never exceed INT_MAX elements.
void scatter_(div)(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input) {
  thc_(check)(state, output, index, input);

  const int n = THCudaLongTensor_nElement(state, index);
  TensorInfo<real> outputInfo = thc_(getTensorInfo)(state, output);
  TensorInfo<int64_t> indexInfo = thc_getTensorInfo_Long(state, index);
  TensorInfo<real> inputInfo = thc_(getTensorInfo)(state, input);

  // KERNEL_RUN is a project macro (defined elsewhere); it is invoked without a
  // trailing semicolon, exactly as in the rest of this file.
  KERNEL_RUN(divKernel, indexInfo.dims, n, outputInfo, indexInfo, inputInfo, dim)
}
// Host-side launcher for the scatter "mean" operation.
//
// Validates output, index and input via thc_(check), builds TensorInfo
// descriptors for all four tensors, and dispatches meanKernel through
// KERNEL_RUN with one work item per element of `index`.
//
// state  - THC state handle forwarded to every THC call.
// dim    - dimension forwarded to the kernel's offset computation.
// output - tensor that the kernel accumulates input values into.
// index  - long tensor; its element count determines the amount of work.
// input  - tensor providing the values accumulated into `output`.
// count  - tensor that the kernel increments once per processed element.
//
// NOTE(review): `count` is not passed to thc_(check); confirm its shape is
// validated by the caller.
// NOTE(review): the element count is stored in an `int` because the kernel
// takes `const int n`; confirm index tensors never exceed INT_MAX elements.
void scatter_(mean)(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input, THCTensor *count) {
  thc_(check)(state, output, index, input);

  const int n = THCudaLongTensor_nElement(state, index);
  TensorInfo<real> outputInfo = thc_(getTensorInfo)(state, output);
  TensorInfo<int64_t> indexInfo = thc_getTensorInfo_Long(state, index);
  TensorInfo<real> inputInfo = thc_(getTensorInfo)(state, input);
  TensorInfo<real> countInfo = thc_(getTensorInfo)(state, count);

  // KERNEL_RUN is a project macro (defined elsewhere); it is invoked without a
  // trailing semicolon, exactly as in the rest of this file.
  KERNEL_RUN(meanKernel, indexInfo.dims, n, outputInfo, indexInfo, inputInfo, countInfo, dim)
}
void scatter_(max)(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input, THCudaLongTensor *arg) { void scatter_(max)(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input, THCudaLongTensor *arg) {
......
...@@ -13,6 +13,34 @@ ...@@ -13,6 +13,34 @@
#include "generic/common.cu" #include "generic/common.cu"
#include "THCGenerateAllTypes.h" #include "THCGenerateAllTypes.h"
// Scatter-multiply kernel.
//
// For every work item `i` produced by KERNEL_LOOP(i, n) (project macro,
// defined elsewhere), IndexToScatterOffsets3 resolves the element offsets into
// `index`, `input` and `output` for dimension `dim`. The input element is then
// combined into the output element with atomicMul.
//
// NOTE(review): atomicMul is not a CUDA built-in; presumably a project helper.
// Confirm it is defined for every `Real` this kernel is instantiated with.
template <typename Real, int Dims>
__global__ void mulKernel(TensorInfo<Real> output, TensorInfo<int64_t> index, TensorInfo<Real> input, const int dim, const int n) {
  KERNEL_LOOP(i, n) {
    int dstOff = 0;
    int idxOff = 0;
    int srcOff = 0;
    IndexToScatterOffsets3<Real, Real, Dims>::compute(i, dim,
                                                      index, &idxOff,
                                                      input, &srcOff,
                                                      output, &dstOff);
    atomicMul(&output.data[dstOff], input.data[srcOff]);
  }
}
// Scatter-divide kernel.
//
// For every work item `i` produced by KERNEL_LOOP(i, n) (project macro,
// defined elsewhere), IndexToScatterOffsets3 resolves the element offsets into
// `index`, `input` and `output` for dimension `dim`. The output element is
// then updated from the input element with atomicDiv.
//
// NOTE(review): atomicDiv is not a CUDA built-in; presumably a project helper.
// Confirm it is defined for every `Real` this kernel is instantiated with.
template <typename Real, int Dims>
__global__ void divKernel(TensorInfo<Real> output, TensorInfo<int64_t> index, TensorInfo<Real> input, const int dim, const int n) {
  KERNEL_LOOP(i, n) {
    int dstOff = 0;
    int idxOff = 0;
    int srcOff = 0;
    IndexToScatterOffsets3<Real, Real, Dims>::compute(i, dim,
                                                      index, &idxOff,
                                                      input, &srcOff,
                                                      output, &dstOff);
    atomicDiv(&output.data[dstOff], input.data[srcOff]);
  }
}
// Scatter-mean accumulation kernel.
//
// For every work item `i` produced by KERNEL_LOOP(i, n) (project macro,
// defined elsewhere), IndexToScatterOffsets4 resolves the element offsets into
// `index`, `input`, `output` and `count` for dimension `dim`. The input
// element is atomically added to the output element, and the matching count
// element is atomically incremented by 1.
//
// This kernel only accumulates sums and counts; no division happens here.
// NOTE(review): presumably the caller divides output by count afterwards;
// verify against the call site.
template <typename Real, int Dims>
__global__ void meanKernel(TensorInfo<Real> output, TensorInfo<int64_t> index, TensorInfo<Real> input, TensorInfo<Real> count, const int dim, const int n) {
  KERNEL_LOOP(i, n) {
    int dstOff = 0;
    int idxOff = 0;
    int srcOff = 0;
    int cntOff = 0;
    IndexToScatterOffsets4<Real, Real, Real, Dims>::compute(i, dim,
                                                            index, &idxOff,
                                                            input, &srcOff,
                                                            output, &dstOff,
                                                            count, &cntOff);
    // Accumulate the value first, then record that one more element landed here.
    atomicAdd(&output.data[dstOff], input.data[srcOff]);
    atomicAdd(&count.data[cntOff], 1);
  }
}
template<typename Real, int Dims> template<typename Real, int Dims>
__global__ void maxKernel(TensorInfo<Real> output, TensorInfo<int64_t> index, TensorInfo<Real> input, const int dim, const int n) { __global__ void maxKernel(TensorInfo<Real> output, TensorInfo<int64_t> index, TensorInfo<Real> input, const int dim, const int n) {
KERNEL_LOOP(i, n) { KERNEL_LOOP(i, n) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment