Commit 19adb8e8 authored by rusty1s

scatter offsets for 3 and 4 tensors

parent c9d2aded
template <typename a, typename b, int Dims>
-struct IndexToScatterOffsets {
+struct IndexToScatterOffsets3 {
  static __device__ void compute(int i, const int dim,
      const TensorInfo<int64_t>& index, int* indexOffset,
      const TensorInfo<a>& t1, int* t1Offset,
@@ -19,7 +19,7 @@ struct IndexToScatterOffsets {
};

template <typename a, typename b>
-struct IndexToScatterOffsets<a, b, -1> {
+struct IndexToScatterOffsets3<a, b, -1> {
  static __device__ void compute(int i, const int dim,
      const TensorInfo<int64_t>& index, int* indexOffset,
      const TensorInfo<a>& t1, int* t1Offset,
@@ -37,3 +37,53 @@ struct IndexToScatterOffsets<a, b, -1> {
    *t2Offset += indexValue * t2.stride[dim];
  }
};

template <typename a, typename b, typename c, int Dims>
struct IndexToScatterOffsets4 {
  static __device__ void compute(int i, const int dim,
      const TensorInfo<int64_t>& index, int* indexOffset,
      const TensorInfo<a>& t1, int* t1Offset,
      const TensorInfo<b>& t2, int* t2Offset,
      const TensorInfo<c>& t3, int* t3Offset) {
    int curDimIndex;
    for (int d = Dims - 1; d >= 0; d--) {
      curDimIndex = i % index.size[d];
      *indexOffset += curDimIndex * index.stride[d];
      *t1Offset += curDimIndex * t1.stride[d];
      if (d != dim) {
        *t2Offset += curDimIndex * t2.stride[d];
        *t3Offset += curDimIndex * t3.stride[d];
      }
      i /= index.size[d];
    }
    int64_t indexValue = index.data[*indexOffset];
    assert(indexValue >= 0 && indexValue < t2.size[dim]);
    *t2Offset += indexValue * t2.stride[dim];
    *t3Offset += indexValue * t3.stride[dim];
  }
};

template <typename a, typename b, typename c>
struct IndexToScatterOffsets4<a, b, c, -1> {
  static __device__ void compute(int i, const int dim,
      const TensorInfo<int64_t>& index, int* indexOffset,
      const TensorInfo<a>& t1, int* t1Offset,
      const TensorInfo<b>& t2, int* t2Offset,
      const TensorInfo<c>& t3, int* t3Offset) {
    int curDimIndex;
    for (int d = index.dims - 1; d >= 0; d--) {
      curDimIndex = i % index.size[d];
      *indexOffset += curDimIndex * index.stride[d];
      *t1Offset += curDimIndex * t1.stride[d];
      if (d != dim) {
        *t2Offset += curDimIndex * t2.stride[d];
        *t3Offset += curDimIndex * t3.stride[d];
      }
      i /= index.size[d];
    }
    int64_t indexValue = index.data[*indexOffset];
    assert(indexValue >= 0 && indexValue < t2.size[dim]);
    *t2Offset += indexValue * t2.stride[dim];
    *t3Offset += indexValue * t3.stride[dim];
  }
};
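The two specializations above map the kernel's linear loop counter `i` to offsets in four tensors at once: `index` and `t1` (the input) are addressed by `i`'s own multi-dimensional coordinate, while `t2` and `t3` (here: output and arg) get their coordinate along `dim` replaced by the value read from `index`. The `Dims` parameter only controls whether the loop bound is a compile-time constant the compiler can unroll; the `-1` specialization falls back to the runtime rank `index.dims`. Below is a minimal host-side sketch of the same arithmetic, assuming a contiguous 2x3 example in which all four tensors happen to share the same sizes and strides; the shapes, strides, and index values are hypothetical and purely for illustration.

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical contiguous 2x3 tensors, scattering along dim = 1.
  const int dims = 2, dim = 1;
  const int size[2] = {2, 3};
  const int stride[2] = {3, 1};
  // Index values must satisfy 0 <= value < size along dim of the scattered-into tensor.
  const int64_t index[6] = {2, 0, 1, 1, 1, 0};

  for (int i = 0; i < 6; i++) {
    int indexOffset = 0, t1Offset = 0, t2Offset = 0, t3Offset = 0;
    int rest = i;
    // Same decomposition as IndexToScatterOffsets4<..., -1>::compute.
    for (int d = dims - 1; d >= 0; d--) {
      const int curDimIndex = rest % size[d];
      indexOffset += curDimIndex * stride[d];
      t1Offset += curDimIndex * stride[d];
      if (d != dim) {  // along dim, t2/t3 are addressed via the index value instead
        t2Offset += curDimIndex * stride[d];
        t3Offset += curDimIndex * stride[d];
      }
      rest /= size[d];
    }
    const int64_t indexValue = index[indexOffset];
    assert(indexValue >= 0 && indexValue < size[dim]);
    t2Offset += indexValue * stride[dim];
    t3Offset += indexValue * stride[dim];
    printf("i=%d -> t1Offset=%d, t2Offset/t3Offset=%d\n", i, t1Offset, t2Offset);
  }
  return 0;
}
```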
@@ -16,8 +16,8 @@
template<typename Real, int Dims>
__global__ void maxKernel(TensorInfo<Real> output, TensorInfo<int64_t> index, TensorInfo<Real> input, TensorInfo<int64_t> arg, const int dim, const int n) {
  KERNEL_LOOP(i, n) {
-    int outputOffset = 0; int indexOffset = 0; int inputOffset = 0;
-    IndexToScatterOffsets<Real, Real, Dims>::compute(i, dim, index, &indexOffset, input, &inputOffset, output, &outputOffset);
+    int outputOffset = 0; int indexOffset = 0; int inputOffset = 0; int argOffset = 0;
+    IndexToScatterOffsets4<Real, Real, int64_t, Dims>::compute(i, dim, index, &indexOffset, input, &inputOffset, output, &outputOffset, arg, &argOffset);
    atomicMax(&output.data[outputOffset], input.data[inputOffset]);
    // TODO: Do something with arg.
  }
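`maxKernel` now also derives an `argOffset` into the `arg` tensor, but per the TODO this commit does not write to it yet. One possible follow-up, shown purely as a hypothetical sketch (the kernel name `argKernel`, the second-pass approach, and the tie-breaking behaviour are assumptions, not taken from this commit), is a separate pass that runs after `maxKernel` and records, for every input element whose value equals the maximum already scattered into `output`, its position along `dim`:

```cuda
// Hypothetical second pass (not part of this commit): after maxKernel has
// filled `output` with the maxima, store in `arg` the coordinate along `dim`
// of an input element that attains each maximum.
template <typename Real, int Dims>
__global__ void argKernel(TensorInfo<Real> output, TensorInfo<int64_t> index,
                          TensorInfo<Real> input, TensorInfo<int64_t> arg,
                          const int dim, const int n) {
  KERNEL_LOOP(i, n) {
    int outputOffset = 0; int indexOffset = 0; int inputOffset = 0; int argOffset = 0;
    IndexToScatterOffsets4<Real, Real, int64_t, Dims>::compute(
        i, dim, index, &indexOffset, input, &inputOffset,
        output, &outputOffset, arg, &argOffset);
    if (input.data[inputOffset] == output.data[outputOffset]) {
      // Recover i's coordinate along `dim` with the same decomposition used
      // in IndexToScatterOffsets4::compute.
      int pos = i;
      for (int d = index.dims - 1; d > dim; d--) pos /= index.size[d];
      pos %= index.size[dim];
      arg.data[argOffset] = pos;
    }
  }
}
```

Running the argument computation as a second pass avoids having to update the maximum and its argument atomically as a pair; ties between equal input values are then resolved arbitrarily by whichever thread writes last.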