Commit 3e06f342 authored by rusty1s

argmax impl

parent bdf2563a
@@ -51,3 +51,4 @@ def test_scatter_cuda_max(str):
     _, arg_output = scatter_max_(output, index, input, dim=1)
     print(output)
+    print(arg_output)
@@ -25,13 +25,23 @@ struct TensorInfo {
   for (int I = blockIdx.x * blockDim.x + threadIdx.x; I < N; I += blockDim.x * gridDim.x)

 /* #define KERNEL_RUN(NAME, DIMS, N, PARAMS) \ */
-#define KERNEL_RUN(NAME, DIMS, N, ...) \
+#define KERNEL_RUN(NAME, DIMS, N, ...) { \
   int grid = GET_BLOCKS(N); \
   cudaStream_t stream = THCState_getCurrentStream(state); \
   switch (DIMS) { \
     case 1: NAME<real, 1><<<grid, NUM_THREADS, 0, stream>>>(__VA_ARGS__, N); break; \
     case 2: NAME<real, 2><<<grid, NUM_THREADS, 0, stream>>>(__VA_ARGS__, N); break; \
     case 3: NAME<real, 3><<<grid, NUM_THREADS, 0, stream>>>(__VA_ARGS__, N); break; \
     default: NAME<real, -1><<<grid, NUM_THREADS, 0, stream>>>(__VA_ARGS__, N); break; \
   } \
-  THCudaCheck(cudaGetLastError());
+  THCudaCheck(cudaGetLastError()); \
+}
+
+static inline __device__ bool eq(uint8_t a, uint8_t b) { return a == b; }
+static inline __device__ bool eq( int8_t a,  int8_t b) { return a == b; }
+static inline __device__ bool eq(int16_t a, int16_t b) { return a == b; }
+static inline __device__ bool eq(int32_t a, int32_t b) { return a == b; }
+static inline __device__ bool eq(int64_t a, int64_t b) { return a == b; }
+static inline __device__ bool eq(  float a,   float b) { return a == b; }
+static inline __device__ bool eq( double a,  double b) { return a == b; }
+static inline __device__ bool eq(   half a,    half b) { return __half2float(a) == __half2float(b); }
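The braces added around the KERNEL_RUN body (together with the trailing backslash after THCudaCheck) let a host function expand the macro more than once: each expansion now opens its own scope, so its local `grid` and `stream` variables do not collide with those of a second expansion. A minimal, hypothetical sketch of that scoping pattern; the `RUN`, `fill`, and launch parameters below are illustrative only, not part of this extension:

#include <cstdio>

// Hypothetical macro: each expansion declares `grid`, so two back-to-back
// expansions in one function only compile because every expansion opens
// its own block.
#define RUN(KERNEL, N, ...) {                 \
  int grid = ((N) + 255) / 256;               \
  KERNEL<<<grid, 256>>>(__VA_ARGS__, N);      \
}

__global__ void fill(float *x, float v, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) x[i] = v;
}

int main() {
  const int n = 1024;
  float *x;
  cudaMalloc(&x, n * sizeof(float));
  RUN(fill, n, x, 0.0f)   // first expansion: `grid` lives in its own scope
  RUN(fill, n, x, 1.0f)   // second expansion: no redefinition of `grid`
  cudaDeviceSynchronize();
  cudaFree(x);
  return 0;
}

The new `eq` overloads give the argument kernel one comparison helper for every generated scalar type; the `half` overload presumably goes through `__half2float` because `half` values cannot be compared with a plain `==` in this setup.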
@@ -26,13 +26,21 @@ void scatter_(max)(THCState *state, int dim, THCTensor *output, THCudaLongTensor
   TensorInfo<real> inputInfo = thc_(getTensorInfo)(state, input);
   TensorInfo<int64_t> argInfo = thc_getTensorInfo_Long(state, arg);
-  KERNEL_RUN(maxKernel, indexInfo.dims, n, outputInfo, indexInfo, inputInfo, argInfo, dim)
-  /* KERNEL_RUN(argKernel, indexInfo.dims, n, outputInfo, indexInfo, dim) */
+  KERNEL_RUN(maxKernel, indexInfo.dims, n, outputInfo, indexInfo, inputInfo, dim)
+  KERNEL_RUN(argKernel, indexInfo.dims, n, outputInfo, indexInfo, inputInfo, argInfo, dim)
 }

 void scatter_(min)(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input, THCudaLongTensor *arg) {
   thc_(check)(state, output, index, input);
+  printf("min");
+  const int n = THCudaLongTensor_nElement(state, index);
+  TensorInfo<real> outputInfo = thc_(getTensorInfo)(state, output);
+  TensorInfo<int64_t> indexInfo = thc_getTensorInfo_Long(state, index);
+  TensorInfo<real> inputInfo = thc_(getTensorInfo)(state, input);
+  TensorInfo<int64_t> argInfo = thc_getTensorInfo_Long(state, arg);
+  KERNEL_RUN(minKernel, indexInfo.dims, n, outputInfo, indexInfo, inputInfo, dim)
+  KERNEL_RUN(argKernel, indexInfo.dims, n, outputInfo, indexInfo, inputInfo, argInfo, dim)
 }

 void index_backward(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *grad, THCudaLongTensor *arg) {
@@ -14,12 +14,29 @@
 #include "THCGenerateAllTypes.h"

 template<typename Real, int Dims>
-__global__ void maxKernel(TensorInfo<Real> output, TensorInfo<int64_t> index, TensorInfo<Real> input, TensorInfo<int64_t> arg, const int dim, const int n) {
+__global__ void maxKernel(TensorInfo<Real> output, TensorInfo<int64_t> index, TensorInfo<Real> input, const int dim, const int n) {
+  KERNEL_LOOP(i, n) {
+    int outputOffset = 0; int indexOffset = 0; int inputOffset = 0;
+    IndexToScatterOffsets3<Real, Real, Dims>::compute(i, dim, index, &indexOffset, input, &inputOffset, output, &outputOffset);
+    atomicMax(&output.data[outputOffset], input.data[inputOffset]);
+  }
+}
+
+template<typename Real, int Dims>
+__global__ void minKernel(TensorInfo<Real> output, TensorInfo<int64_t> index, TensorInfo<Real> input, const int dim, const int n) {
+  KERNEL_LOOP(i, n) {
+    int outputOffset = 0; int indexOffset = 0; int inputOffset = 0;
+    IndexToScatterOffsets3<Real, Real, Dims>::compute(i, dim, index, &indexOffset, input, &inputOffset, output, &outputOffset);
+    atomicMin(&output.data[outputOffset], input.data[inputOffset]);
+  }
+}
+
+template<typename Real, int Dims>
+__global__ void argKernel(TensorInfo<Real> output, TensorInfo<int64_t> index, TensorInfo<Real> input, TensorInfo<int64_t> arg, const int dim, const int n) {
   KERNEL_LOOP(i, n) {
     int outputOffset = 0; int indexOffset = 0; int inputOffset = 0; int argOffset = 0;
     IndexToScatterOffsets4<Real, Real, int64_t, Dims>::compute(i, dim, index, &indexOffset, input, &inputOffset, output, &outputOffset, arg, &argOffset);
-    atomicMax(&output.data[outputOffset], input.data[inputOffset]);
-    // TODO: Do something with arg.
+    if (eq(input.data[inputOffset], output.data[outputOffset])) arg.data[argOffset] = inputOffset % input.size[dim];
   }
 }
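Read together with the host functions above, these kernels implement argmax (and argmin) in two passes: maxKernel/minKernel first scatter the winning value into `output` with an atomic, and argKernel then revisits every input element and, wherever `eq` says its value matches the winner, stores its position along `dim` (`inputOffset % input.size[dim]`) into `arg`. A self-contained sketch of that two-pass idea on a flat int array; names such as `scatterMaxPass` are illustrative only, not the extension's API:

#include <cstdio>
#include <climits>

// Pass 1: scatter the maximum value per output slot using an atomic.
__global__ void scatterMaxPass(const int *src, const int *idx, int *out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) atomicMax(&out[idx[i]], src[i]);
}

// Pass 2: any element equal to the winning value writes back its own index.
__global__ void scatterArgPass(const int *src, const int *idx, const int *out,
                               int *arg, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n && src[i] == out[idx[i]]) arg[idx[i]] = i;
}

int main() {
  const int n = 6, m = 2;
  int src[n] = {3, 7, 2, 9, 1, 5};
  int idx[n] = {0, 0, 0, 1, 1, 1};
  int init[m] = {INT_MIN, INT_MIN};   // identity element for max
  int *dSrc, *dIdx, *dOut, *dArg;
  cudaMalloc(&dSrc, n * sizeof(int));
  cudaMalloc(&dIdx, n * sizeof(int));
  cudaMalloc(&dOut, m * sizeof(int));
  cudaMalloc(&dArg, m * sizeof(int));
  cudaMemcpy(dSrc, src, n * sizeof(int), cudaMemcpyHostToDevice);
  cudaMemcpy(dIdx, idx, n * sizeof(int), cudaMemcpyHostToDevice);
  cudaMemcpy(dOut, init, m * sizeof(int), cudaMemcpyHostToDevice);
  scatterMaxPass<<<1, 32>>>(dSrc, dIdx, dOut, n);
  scatterArgPass<<<1, 32>>>(dSrc, dIdx, dOut, dArg, n);
  int out[m], arg[m];
  cudaMemcpy(out, dOut, m * sizeof(int), cudaMemcpyDeviceToHost);
  cudaMemcpy(arg, dArg, m * sizeof(int), cudaMemcpyDeviceToHost);
  printf("out = [%d %d], arg = [%d %d]\n", out[0], out[1], arg[0], arg[1]);  // out = [7 9], arg = [1 3]
  cudaFree(dSrc); cudaFree(dIdx); cudaFree(dOut); cudaFree(dArg);
  return 0;
}

One property worth noting: if several inputs tie for the winning value of a slot, whichever matching thread writes last supplies the reported index, so the result is valid but not deterministic among ties.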