Commit 19adb8e8 authored by rusty1s
Browse files

scatter offsets for 3 and 4 tensors

parent c9d2aded
template <typename a, typename b, int Dims> template <typename a, typename b, int Dims>
struct IndexToScatterOffsets { struct IndexToScatterOffsets3 {
static __device__ void compute(int i, const int dim, static __device__ void compute(int i, const int dim,
const TensorInfo<int64_t>& index, int* indexOffset, const TensorInfo<int64_t>& index, int* indexOffset,
const TensorInfo<a>& t1, int* t1Offset, const TensorInfo<a>& t1, int* t1Offset,
...@@ -19,7 +19,7 @@ struct IndexToScatterOffsets { ...@@ -19,7 +19,7 @@ struct IndexToScatterOffsets {
}; };
template <typename a, typename b> template <typename a, typename b>
struct IndexToScatterOffsets<a, b, -1> { struct IndexToScatterOffsets3<a, b, -1> {
static __device__ void compute(int i, const int dim, static __device__ void compute(int i, const int dim,
const TensorInfo<int64_t>& index, int* indexOffset, const TensorInfo<int64_t>& index, int* indexOffset,
const TensorInfo<a>& t1, int* t1Offset, const TensorInfo<a>& t1, int* t1Offset,
...@@ -37,3 +37,53 @@ struct IndexToScatterOffsets<a, b, -1> { ...@@ -37,3 +37,53 @@ struct IndexToScatterOffsets<a, b, -1> {
*t2Offset += indexValue * t2.stride[dim]; *t2Offset += indexValue * t2.stride[dim];
} }
}; };
// Turns a linear element id into element offsets for an index tensor and
// three further tensors, with the number of dimensions fixed at compile
// time through `Dims`.
//
// The linear id `i` is decomposed against `index.size`, walking from the last
// dimension down to the first. Each resulting coordinate is scaled by the
// matching stride and added to:
//   - *indexOffset and *t1Offset for every dimension;
//   - *t2Offset and *t3Offset for every dimension except `dim`.
// Along `dim`, t2 and t3 are addressed by the value read from
// index.data[*indexOffset] instead of by the coordinate.
//
// All four offsets are accumulated with `+=`, never assigned, so the caller
// is responsible for their initial values.
template <typename a, typename b, typename c, int Dims>
struct IndexToScatterOffsets4 {
  static __device__ void compute(int i, const int dim,
                                 const TensorInfo<int64_t>& index, int* indexOffset,
                                 const TensorInfo<a>& t1, int* t1Offset,
                                 const TensorInfo<b>& t2, int* t2Offset,
                                 const TensorInfo<c>& t3, int* t3Offset) {
    for (int d = Dims; d-- > 0;) {
      // Coordinate of the element along dimension d; the quotient that is
      // left over feeds the next (outer) dimension.
      const int coord = i % index.size[d];
      i /= index.size[d];

      *indexOffset += coord * index.stride[d];
      *t1Offset += coord * t1.stride[d];

      // The scatter dimension of t2/t3 is resolved after the loop.
      if (d == dim) continue;
      *t2Offset += coord * t2.stride[d];
      *t3Offset += coord * t3.stride[d];
    }

    // Position along `dim` comes from the index tensor. Only t2's extent is
    // checked here; t3 is addressed with the same value.
    const int64_t indexValue = index.data[*indexOffset];
    assert(indexValue >= 0 && indexValue < t2.size[dim]);
    *t2Offset += indexValue * t2.stride[dim];
    *t3Offset += indexValue * t3.stride[dim];
  }
};
// Specialization selected when `Dims` is -1: the number of dimensions is
// taken from `index.dims` at run time instead of from the template argument.
//
// The linear id `i` is decomposed against `index.size`, walking from the last
// dimension down to the first. Each resulting coordinate is scaled by the
// matching stride and added to:
//   - *indexOffset and *t1Offset for every dimension;
//   - *t2Offset and *t3Offset for every dimension except `dim`.
// Along `dim`, t2 and t3 are addressed by the value read from
// index.data[*indexOffset] instead of by the coordinate.
//
// All four offsets are accumulated with `+=`, never assigned, so the caller
// is responsible for their initial values.
template <typename a, typename b, typename c>
struct IndexToScatterOffsets4<a, b, c, -1> {
  static __device__ void compute(int i, const int dim,
                                 const TensorInfo<int64_t>& index, int* indexOffset,
                                 const TensorInfo<a>& t1, int* t1Offset,
                                 const TensorInfo<b>& t2, int* t2Offset,
                                 const TensorInfo<c>& t3, int* t3Offset) {
    for (int d = index.dims; d-- > 0;) {
      // Coordinate of the element along dimension d; the quotient that is
      // left over feeds the next (outer) dimension.
      const int coord = i % index.size[d];
      i /= index.size[d];

      *indexOffset += coord * index.stride[d];
      *t1Offset += coord * t1.stride[d];

      // The scatter dimension of t2/t3 is resolved after the loop.
      if (d == dim) continue;
      *t2Offset += coord * t2.stride[d];
      *t3Offset += coord * t3.stride[d];
    }

    // Position along `dim` comes from the index tensor. Only t2's extent is
    // checked here; t3 is addressed with the same value.
    const int64_t indexValue = index.data[*indexOffset];
    assert(indexValue >= 0 && indexValue < t2.size[dim]);
    *t2Offset += indexValue * t2.stride[dim];
    *t3Offset += indexValue * t3.stride[dim];
  }
};
...@@ -16,8 +16,8 @@ ...@@ -16,8 +16,8 @@
template<typename Real, int Dims> template<typename Real, int Dims>
__global__ void maxKernel(TensorInfo<Real> output, TensorInfo<int64_t> index, TensorInfo<Real> input, TensorInfo<int64_t> arg, const int dim, const int n) { __global__ void maxKernel(TensorInfo<Real> output, TensorInfo<int64_t> index, TensorInfo<Real> input, TensorInfo<int64_t> arg, const int dim, const int n) {
KERNEL_LOOP(i, n) { KERNEL_LOOP(i, n) {
int outputOffset = 0; int indexOffset = 0; int inputOffset = 0; int outputOffset = 0; int indexOffset = 0; int inputOffset = 0; int argOffset = 0;
IndexToScatterOffsets<Real, Real, Dims>::compute(i, dim, index, &indexOffset, input, &inputOffset, output, &outputOffset); IndexToScatterOffsets4<Real, Real, int64_t, Dims>::compute(i, dim, index, &indexOffset, input, &inputOffset, output, &outputOffset, arg, &argOffset);
atomicMax(&output.data[outputOffset], input.data[inputOffset]); atomicMax(&output.data[outputOffset], input.data[inputOffset]);
// TODO: Do something with arg. // TODO: Do something with arg.
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment