Commit c9d2aded authored by rusty1s's avatar rusty1s
Browse files

added scatter index offsets

parent 3bed6293
template <typename a, typename b, int Dims>
struct IndexToScatterOffsets {
static __device__ void compute(int i, const int dim,
const TensorInfo<int64_t>& index, int* indexOffset,
const TensorInfo<a>& t1, int* t1Offset,
const TensorInfo<b>& t2, int* t2Offset) {
int curDimIndex;
for (int d = Dims - 1; d >= 0; d--) {
curDimIndex = i % index.size[d];
*indexOffset += curDimIndex * index.stride[d];
*t1Offset += curDimIndex * t1.stride[d];
if (d != dim) *t2Offset += curDimIndex * t2.stride[d];
i /= index.size[d];
}
int64_t indexValue = index.data[*indexOffset];
assert(indexValue >= 0 && indexValue < t2.size[dim]);
*t2Offset += indexValue * t2.stride[dim];
}
};
template <typename a, typename b>
struct IndexToScatterOffsets<a, b, -1> {
static __device__ void compute(int i, const int dim,
const TensorInfo<int64_t>& index, int* indexOffset,
const TensorInfo<a>& t1, int* t1Offset,
const TensorInfo<b>& t2, int* t2Offset) {
int curDimIndex;
for (int d = index.dims - 1; d >= 0; d--) {
curDimIndex = i % index.size[d];
*indexOffset += curDimIndex * index.stride[d];
*t1Offset += curDimIndex * t1.stride[d];
if (d != dim) *t2Offset += curDimIndex * t2.stride[d];
i /= index.size[d];
}
int64_t indexValue = index.data[*indexOffset];
assert(indexValue >= 0 && indexValue < t2.size[dim]);
*t2Offset += indexValue * t2.stride[dim];
}
};
#include <THC/THC.h>
#include "THCAtomics.cuh"
#include "kernel.h"
#include "common.cuh"
#include "THCIndex.cuh"
#include "THCAtomics.cuh"
#define scatter_(NAME) TH_CONCAT_4(scatter_, NAME, _kernel_, Real)
#define index_backward TH_CONCAT_2(index_backward_kernel_, Real)
#define check TH_CONCAT_2(check_kernel_, Real)
#define thc_(NAME) TH_CONCAT_4(thc_, NAME, _, Real)
#include "generic/common.cu"
......@@ -16,23 +16,8 @@
template<typename Real, int Dims>
__global__ void maxKernel(TensorInfo<Real> output, TensorInfo<int64_t> index, TensorInfo<Real> input, TensorInfo<int64_t> arg, const int dim, const int n) {
KERNEL_LOOP(i, n) {
int outputOffset = 0; int indexOffset = 0; int inputOffset = 0; int argOffset = 0;
int curDimIndex;
for (int d = index.dims - 1; d >= 0; d--) {
curDimIndex = i % index.size[d];
indexOffset += curDimIndex * index.stride[d];
inputOffset += curDimIndex * input.stride[d];
if (d != dim) {
outputOffset += curDimIndex * output.stride[d];
argOffset += curDimIndex * arg.stride[d];
}
i /= index.size[d];
}
int64_t indexValue = index.data[indexOffset];
assert(indexValue >= 0 && indexValue < output.size[dim]);
outputOffset += indexValue * output.stride[dim];
argOffset += indexValue * arg.stride[dim];
int outputOffset = 0; int indexOffset = 0; int inputOffset = 0;
IndexToScatterOffsets<Real, Real, Dims>::compute(i, dim, index, &indexOffset, input, &inputOffset, output, &outputOffset);
atomicMax(&output.data[outputOffset], input.data[inputOffset]);
// TODO: Do something with arg.
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment