"vscode:/vscode.git/clone" did not exist on "43f5f42759b4226919761b635489ba2abd1b0ff9"
Commit 38c8b3ac authored by rusty1s

cuda kernel

parent a2f18da3
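// Type-generic atomic update helpers (included by the kernels further below as
// "atomics.cuh"). For scalar types without a native atomic instruction, the
// ATOMIC macro below generates atomicCAS-based read-modify-write loops on the
// enclosing, suitably aligned 32- or 64-bit word.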
#define ATOMIC(NAME) \
template <typename scalar, size_t size> struct Atomic##NAME##IntegerImpl; \
\
template <typename scalar> struct Atomic##NAME##IntegerImpl<scalar, 1> { \
inline __device__ void operator()(scalar *address, scalar val) { \
uint32_t *address_as_ui = (uint32_t *)(address - ((size_t)address & 3)); \
uint32_t old = *address_as_ui; \
uint32_t shift = ((size_t)address & 3) * 8; \
uint32_t sum; \
uint32_t assumed; \
\
do { \
assumed = old; \
sum = OP(val, scalar((old >> shift) & 0xff)); \
old = (old & ~(0x000000ff << shift)) | (sum << shift); \
old = atomicCAS(address_as_ui, assumed, old); \
} while (assumed != old); \
} \
}; \
\
template <typename scalar> struct Atomic##NAME##IntegerImpl<scalar, 2> { \
inline __device__ void operator()(scalar *address, scalar val) { \
uint32_t *address_as_ui = \
(uint32_t *)((char *)address - ((size_t)address & 2)); \
uint32_t old = *address_as_ui; \
uint32_t sum; \
uint32_t newval; \
uint32_t assumed; \
\
do { \
assumed = old; \
sum = OP(val, (size_t)address & 2 ? scalar(old >> 16) \
: scalar(old & 0xffff)); \
newval = (size_t)address & 2 ? (old & 0xffff) | (sum << 16) \
: (old & 0xffff0000) | sum; \
old = atomicCAS(address_as_ui, assumed, newval); \
} while (assumed != old); \
} \
}; \
\
template <typename scalar> struct Atomic##NAME##IntegerImpl<scalar, 4> { \
inline __device__ void operator()(scalar *address, scalar val) { \
uint32_t *address_as_ui = (uint32_t *)address; \
uint32_t old = *address_as_ui; \
uint32_t assumed; \
\
do { \
assumed = old; \
old = atomicCAS(address_as_ui, assumed, OP(val, (scalar)old)); \
} while (assumed != old); \
} \
}; \
\
template <typename scalar> struct Atomic##NAME##IntegerImpl<scalar, 8> { \
inline __device__ void operator()(scalar *address, scalar val) { \
unsigned long long *address_as_ull = (unsigned long long *)address; \
unsigned long long old = *address_as_ull; \
unsigned long long assumed; \
\
do { \
assumed = old; \
old = atomicCAS(address_as_ull, assumed, OP(val, (scalar)old)); \
} while (assumed != old); \
} \
}; \
\
template <typename scalar, size_t size> struct Atomic##NAME##DecimalImpl; \
\
template <typename scalar> struct Atomic##NAME##DecimalImpl<scalar, 4> { \
inline __device__ void operator()(scalar *address, scalar val) { \
int *address_as_i = (int *)address; \
int old = *address_as_i; \
int assumed; \
\
do { \
assumed = old; \
old = atomicCAS(address_as_i, assumed, \
__float_as_int(OP(val, __int_as_float(assumed)))); \
} while (assumed != old); \
} \
}; \
\
template <typename scalar> struct Atomic##NAME##DecimalImpl<scalar, 8> { \
inline __device__ void operator()(scalar *address, scalar val) { \
unsigned long long int *address_as_ull = \
(unsigned long long int *)address; \
unsigned long long int old = *address_as_ull; \
unsigned long long int assumed; \
\
do { \
assumed = old; \
old = atomicCAS( \
address_as_ull, assumed, \
__double_as_longlong(OP(val, __longlong_as_double(assumed)))); \
} while (assumed != old); \
} \
};
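// Each operation is generated by defining OP(X, Y) and expanding ATOMIC(NAME).
// For example, the next three lines instantiate AtomicAddIntegerImpl<scalar, 1|2|4|8>
// and AtomicAddDecimalImpl<scalar, 4|8> with OP(X, Y) = Y + X; the atomAdd
// overloads then select the implementation matching sizeof(scalar), falling
// back to the hardware atomicAdd where one exists.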
#define OP(X, Y) Y + X
ATOMIC(Add)
#undef OP
static inline __device__ void atomAdd(uint8_t *address, uint8_t val) {
AtomicAddIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val);
}
static inline __device__ void atomAdd(int8_t *address, int8_t val) {
AtomicAddIntegerImpl<int8_t, sizeof(int8_t)>()(address, val);
}
static inline __device__ void atomAdd(int16_t *address, int16_t val) {
AtomicAddIntegerImpl<int16_t, sizeof(int16_t)>()(address, val);
}
static inline __device__ void atomAdd(int32_t *address, int32_t val) {
atomicAdd(address, val);
}
static inline __device__ void atomAdd(int64_t *address, int64_t val) {
AtomicAddIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
}
static inline __device__ void atomAdd(float *address, float val) {
atomicAdd(address, val);
}
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000)
static inline __device__ void atomAdd(double *address, double val) {
AtomicAddDecimalImpl<double, sizeof(double)>()(address, val);
}
#else
static inline __device__ void atomAdd(double *address, double val) {
atomicAdd(address, val);
}
#endif
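// CUDA provides no native atomicMul or atomicDiv for any type, so every
// atomMul/atomDiv overload goes through the CAS-based implementations
// generated above.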
#define OP(X, Y) Y * X
ATOMIC(Mul)
#undef OP
static inline __device__ void atomMul(uint8_t *address, uint8_t val) {
AtomicMulIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val);
}
static inline __device__ void atomMul(int8_t *address, int8_t val) {
AtomicMulIntegerImpl<int8_t, sizeof(int8_t)>()(address, val);
}
static inline __device__ void atomMul(int16_t *address, int16_t val) {
AtomicMulIntegerImpl<int16_t, sizeof(int16_t)>()(address, val);
}
static inline __device__ void atomMul(int32_t *address, int32_t val) {
AtomicMulIntegerImpl<int32_t, sizeof(int32_t)>()(address, val);
}
static inline __device__ void atomMul(int64_t *address, int64_t val) {
AtomicMulIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
}
static inline __device__ void atomMul(float *address, float val) {
AtomicMulDecimalImpl<float, sizeof(float)>()(address, val);
}
static inline __device__ void atomMul(double *address, double val) {
AtomicMulDecimalImpl<double, sizeof(double)>()(address, val);
}
#define OP(X, Y) Y / X
ATOMIC(Div)
#undef OP
static inline __device__ void atomDiv(uint8_t *address, uint8_t val) {
AtomicDivIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val);
}
static inline __device__ void atomDiv(int8_t *address, int8_t val) {
AtomicDivIntegerImpl<int8_t, sizeof(int8_t)>()(address, val);
}
static inline __device__ void atomDiv(int16_t *address, int16_t val) {
AtomicDivIntegerImpl<int16_t, sizeof(int16_t)>()(address, val);
}
static inline __device__ void atomDiv(int32_t *address, int32_t val) {
AtomicDivIntegerImpl<int32_t, sizeof(int32_t)>()(address, val);
}
static inline __device__ void atomDiv(int64_t *address, int64_t val) {
AtomicDivIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
}
static inline __device__ void atomDiv(float *address, float val) {
AtomicDivDecimalImpl<float, sizeof(float)>()(address, val);
}
static inline __device__ void atomDiv(double *address, double val) {
AtomicDivDecimalImpl<double, sizeof(double)>()(address, val);
}
#define OP(X, Y) max(Y, X)
ATOMIC(Max)
#undef OP
static inline __device__ void atomMax(uint8_t *address, uint8_t val) {
AtomicMaxIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val);
}
static inline __device__ void atomMax(int8_t *address, int8_t val) {
AtomicMaxIntegerImpl<int8_t, sizeof(int8_t)>()(address, val);
}
static inline __device__ void atomMax(int16_t *address, int16_t val) {
AtomicMaxIntegerImpl<int16_t, sizeof(int16_t)>()(address, val);
}
static inline __device__ void atomMax(int32_t *address, int32_t val) {
atomicMax(address, val);
}
static inline __device__ void atomMax(int64_t *address, int64_t val) {
AtomicMaxIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
}
static inline __device__ void atomMax(float *address, float val) {
AtomicMaxDecimalImpl<float, sizeof(float)>()(address, val);
}
static inline __device__ void atomMax(double *address, double val) {
AtomicMaxDecimalImpl<double, sizeof(double)>()(address, val);
}
#define OP(X, Y) min(Y, X)
ATOMIC(Min)
#undef OP
static inline __device__ void atomMin(uint8_t *address, uint8_t val) {
AtomicMinIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val);
}
static inline __device__ void atomMin(int8_t *address, int8_t val) {
AtomicMinIntegerImpl<int8_t, sizeof(int8_t)>()(address, val);
}
static inline __device__ void atomMin(int16_t *address, int16_t val) {
AtomicMinIntegerImpl<int16_t, sizeof(int16_t)>()(address, val);
}
static inline __device__ void atomMin(int32_t *address, int32_t val) {
atomicMin(address, val);
}
static inline __device__ void atomMin(int64_t *address, int64_t val) {
AtomicMinIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
}
static inline __device__ void atomMin(float *address, float val) {
AtomicMinDecimalImpl<float, sizeof(float)>()(address, val);
}
static inline __device__ void atomMin(double *address, double val) {
AtomicMinDecimalImpl<double, sizeof(double)>()(address, val);
}
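// Illustrative sketch only, not part of this commit: any kernel can use the
// helpers above to update types that lack hardware atomics. The kernel name
// and the parameters `buf`, `vals` and `n` are placeholders.
//
//   __global__ void running_max(float *buf, const float *vals, size_t n) {
//     size_t i = blockIdx.x * blockDim.x + threadIdx.x;
//     if (i < n)
//       atomMax(buf, vals[i]); // CAS loop via AtomicMaxDecimalImpl<float, 4>
//   }
//
// Here `buf` would have to be pre-initialized to an identity value such as
// -FLT_MAX before the launch.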
#pragma once
#include <ATen/ATen.h>
#include <ATen/cuda/detail/TensorInfo.cuh>
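// IndexToScatterOffsets3/4 translate a flat element index i into offsets into
// the participating tensors: every dimension contributes curDimIndex * stride,
// except that along `dim` the offset into the scatter target (t2, and t3 in
// the four-tensor variant) is taken from the value stored in `index`. The Dims
// template parameter lets the compiler unroll the loop for 1-3 dimensions; the
// -1 specialization handles arbitrary dimensionality at runtime.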
template <typename scalar1, typename scalar2, int64_t Dims>
struct IndexToScatterOffsets3 {
static __device__ void
compute(int64_t i, const int64_t dim,
const at::cuda::detail::TensorInfo<int64_t, int64_t> &index,
int64_t *indexOffset,
const at::cuda::detail::TensorInfo<scalar1, int64_t> &t1,
int64_t *t1Offset,
const at::cuda::detail::TensorInfo<scalar2, int64_t> &t2,
int64_t *t2Offset) {
for (int64_t d = Dims - 1; d >= 0; d--) {
int64_t curDimIndex = i % index.sizes[d];
*indexOffset += curDimIndex * index.strides[d];
*t1Offset += curDimIndex * t1.strides[d];
if (d != dim) {
*t2Offset += curDimIndex * t2.strides[d];
}
i /= index.sizes[d];
}
int64_t indexValue = index.data[*indexOffset];
*t2Offset += indexValue * t2.strides[dim];
}
};
template <typename scalar1, typename scalar2>
struct IndexToScatterOffsets3<scalar1, scalar2, -1> {
static __device__ void
compute(int64_t i, const int64_t dim,
const at::cuda::detail::TensorInfo<int64_t, int64_t> &index,
int64_t *indexOffset,
const at::cuda::detail::TensorInfo<scalar1, int64_t> &t1,
int64_t *t1Offset,
const at::cuda::detail::TensorInfo<scalar2, int64_t> &t2,
int64_t *t2Offset) {
for (int64_t d = index.dims - 1; d >= 0; d--) {
int64_t curDimIndex = i % index.sizes[d];
*indexOffset += curDimIndex * index.strides[d];
*t1Offset += curDimIndex * t1.strides[d];
if (d != dim) {
*t2Offset += curDimIndex * t2.strides[d];
}
i /= index.sizes[d];
}
int64_t indexValue = index.data[*indexOffset];
*t2Offset += indexValue * t2.strides[dim];
}
};
template <typename scalar1, typename scalar2, typename scalar3, int64_t Dims>
struct IndexToScatterOffsets4 {
static __device__ void
compute(int64_t i, const int64_t dim,
const at::cuda::detail::TensorInfo<int64_t, int64_t> &index,
int64_t *indexOffset,
const at::cuda::detail::TensorInfo<scalar1, int64_t> &t1,
int64_t *t1Offset,
const at::cuda::detail::TensorInfo<scalar2, int64_t> &t2,
int64_t *t2Offset,
const at::cuda::detail::TensorInfo<scalar3, int64_t> &t3,
int64_t *t3Offset) {
for (int64_t d = Dims - 1; d >= 0; d--) {
int64_t curDimIndex = i % index.sizes[d];
*indexOffset += curDimIndex * index.strides[d];
*t1Offset += curDimIndex * t1.strides[d];
if (d != dim) {
*t2Offset += curDimIndex * t2.strides[d];
*t3Offset += curDimIndex * t3.strides[d];
}
i /= index.sizes[d];
}
int64_t indexValue = index.data[*indexOffset];
*t2Offset += indexValue * t2.strides[dim];
*t3Offset += indexValue * t3.strides[dim];
}
};
template <typename scalar1, typename scalar2, typename scalar3>
struct IndexToScatterOffsets4<scalar1, scalar2, scalar3, -1> {
static __device__ void
compute(int64_t i, const int64_t dim,
const at::cuda::detail::TensorInfo<int64_t, int64_t> &index,
int64_t *indexOffset,
const at::cuda::detail::TensorInfo<scalar1, int64_t> &t1,
int64_t *t1Offset,
const at::cuda::detail::TensorInfo<scalar2, int64_t> &t2,
int64_t *t2Offset,
const at::cuda::detail::TensorInfo<scalar3, int64_t> &t3,
int64_t *t3Offset) {
for (int64_t d = index.dims - 1; d >= 0; d--) {
int64_t curDimIndex = i % index.sizes[d];
*indexOffset += curDimIndex * index.strides[d];
*t1Offset += curDimIndex * t1.strides[d];
if (d != dim) {
*t2Offset += curDimIndex * t2.strides[d];
*t3Offset += curDimIndex * t3.strides[d];
}
i /= index.sizes[d];
}
int64_t indexValue = index.data[*indexOffset];
*t2Offset += indexValue * t2.strides[dim];
*t3Offset += indexValue * t3.strides[dim];
}
};
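// Python-facing entry points of the extension: each wrapper checks that all
// tensors live on the GPU and forwards to the corresponding *_cuda function
// declared below and defined in the kernel file that follows.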
#include <torch/torch.h>
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")
void scatter_mul_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
int64_t dim);
void scatter_div_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
int64_t dim);
void scatter_max_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
at::Tensor arg, int64_t dim);
void scatter_min_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
at::Tensor arg, int64_t dim);
void index_backward_cuda(at::Tensor grad, at::Tensor index, at::Tensor arg,
at::Tensor out, int64_t dim);
void scatter_mul(at::Tensor src, at::Tensor index, at::Tensor out,
int64_t dim) {
CHECK_CUDA(src);
CHECK_CUDA(index);
CHECK_CUDA(out);
scatter_mul_cuda(src, index, out, dim);
}
void scatter_div(at::Tensor src, at::Tensor index, at::Tensor out,
int64_t dim) {
CHECK_CUDA(src);
CHECK_CUDA(index);
CHECK_CUDA(out);
scatter_div_cuda(src, index, out, dim);
}
void scatter_max(at::Tensor src, at::Tensor index, at::Tensor out,
at::Tensor arg, int64_t dim) {
CHECK_CUDA(src);
CHECK_CUDA(index);
CHECK_CUDA(out);
CHECK_CUDA(arg);
scatter_max_cuda(src, index, out, arg, dim);
}
void scatter_min(at::Tensor src, at::Tensor index, at::Tensor out,
at::Tensor arg, int64_t dim) {
CHECK_CUDA(src);
CHECK_CUDA(index);
CHECK_CUDA(out);
CHECK_CUDA(arg);
scatter_min_cuda(src, index, out, arg, dim);
}
void index_backward(at::Tensor grad, at::Tensor index, at::Tensor arg,
at::Tensor out, int64_t dim) {
CHECK_CUDA(grad);
CHECK_CUDA(index);
CHECK_CUDA(arg);
CHECK_CUDA(out);
index_backward_cuda(grad, index, arg, out, dim);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("scatter_mul", &scatter_mul, "Scatter Mul (CUDA)");
m.def("scatter_div", &scatter_div, "Scatter Div (CUDA)");
m.def("scatter_max", &scatter_max, "Scatter Max (CUDA)");
m.def("scatter_min", &scatter_min, "Scatter Min (CUDA)");
m.def("index_backward", &index_backward, "Index Backward (CUDA)");
}
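// CUDA kernel implementations. AT_DISPATCH_ALL_TYPES instantiates each kernel
// for the tensor's scalar type, and the KERNEL_RUN macro additionally
// specializes on the dimensionality of `index`, launching BLOCKS(N) blocks of
// THREADS threads.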
#include <ATen/ATen.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include "atomics.cuh"
#include "index.cuh"
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
#define KERNEL_RUN(NAME, DIMS, N, ...) \
[&] { \
switch (DIMS) { \
case 1: \
NAME<scalar_t, 1><<<BLOCKS(N), THREADS>>>(__VA_ARGS__, N); \
break; \
case 2: \
NAME<scalar_t, 2><<<BLOCKS(N), THREADS>>>(__VA_ARGS__, N); \
break; \
case 3: \
NAME<scalar_t, 3><<<BLOCKS(N), THREADS>>>(__VA_ARGS__, N); \
break; \
default: \
NAME<scalar_t, -1><<<BLOCKS(N), THREADS>>>(__VA_ARGS__, N); \
} \
}()
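// All kernels below follow the same grid-stride pattern: each thread iterates
// over elements of `index`, computes the corresponding tensor offsets via
// IndexToScatterOffsets3/4, and resolves concurrent writes to `out` with the
// atom* helpers from atomics.cuh.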
template <typename scalar_t, int64_t Dims>
__global__ void
scatter_mul_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
at::cuda::detail::TensorInfo<int64_t, int64_t> index,
at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
int64_t dim, size_t numel) {
const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = blockDim.x * gridDim.x;
for (ptrdiff_t i = idx; i < numel; i += stride) {
int64_t srcOffset = 0, indexOffset = 0, outOffset = 0;
IndexToScatterOffsets3<scalar_t, scalar_t, Dims>::compute(
i, dim, index, &indexOffset, src, &srcOffset, out, &outOffset);
atomMul(&out.data[outOffset], src.data[srcOffset]);
}
}
void scatter_mul_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
int64_t dim) {
AT_DISPATCH_ALL_TYPES(src.type(), "scatter_mul_kernel", [&] {
KERNEL_RUN(scatter_mul_kernel, index.dim(), index.numel(),
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src),
at::cuda::detail::getTensorInfo<int64_t, int64_t>(index),
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(out), dim);
});
}
template <typename scalar_t, int64_t Dims>
__global__ void
scatter_div_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
at::cuda::detail::TensorInfo<int64_t, int64_t> index,
at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
int64_t dim, size_t numel) {
const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = blockDim.x * gridDim.x;
for (ptrdiff_t i = idx; i < numel; i += stride) {
int64_t srcOffset = 0, indexOffset = 0, outOffset = 0;
IndexToScatterOffsets3<scalar_t, scalar_t, Dims>::compute(
i, dim, index, &indexOffset, src, &srcOffset, out, &outOffset);
atomDiv(&out.data[outOffset], src.data[srcOffset]);
}
}
void scatter_div_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
int64_t dim) {
AT_DISPATCH_ALL_TYPES(src.type(), "scatter_div_kernel", [&] {
KERNEL_RUN(scatter_div_kernel, index.dim(), index.numel(),
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src),
at::cuda::detail::getTensorInfo<int64_t, int64_t>(index),
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(out), dim);
});
}
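// arg_kernel is a second pass run after scatter_max/scatter_min: wherever a
// source element equals the reduced value already stored in `out`, its
// position along `dim` is recorded in `arg` (the argmax/argmin used by the
// backward pass).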
template <typename scalar_t, int64_t Dims>
__global__ void arg_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
at::cuda::detail::TensorInfo<int64_t, int64_t> index,
at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
at::cuda::detail::TensorInfo<int64_t, int64_t> arg,
int64_t dim, size_t numel) {
const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = blockDim.x * gridDim.x;
for (ptrdiff_t i = idx; i < numel; i += stride) {
int64_t srcOffset = 0, indexOffset = 0, outOffset = 0, argOffset = 0;
IndexToScatterOffsets4<scalar_t, scalar_t, int64_t, Dims>::compute(
i, dim, index, &indexOffset, src, &srcOffset, out, &outOffset, arg,
&argOffset);
if (src.data[srcOffset] == out.data[outOffset]) {
arg.data[argOffset] = (srcOffset / src.strides[dim]) % src.sizes[dim];
}
}
}
template <typename scalar_t, int64_t Dims>
__global__ void
scatter_max_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
at::cuda::detail::TensorInfo<int64_t, int64_t> index,
at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
int64_t dim, size_t numel) {
const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = blockDim.x * gridDim.x;
for (ptrdiff_t i = idx; i < numel; i += stride) {
int64_t srcOffset = 0, indexOffset = 0, outOffset = 0;
IndexToScatterOffsets3<scalar_t, scalar_t, Dims>::compute(
i, dim, index, &indexOffset, src, &srcOffset, out, &outOffset);
atomMax(&out.data[outOffset], src.data[srcOffset]);
}
}
void scatter_max_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
at::Tensor arg, int64_t dim) {
AT_DISPATCH_ALL_TYPES(src.type(), "scatter_max_kernel", [&] {
auto src_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src);
auto index_info = at::cuda::detail::getTensorInfo<int64_t, int64_t>(index);
auto out_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(out);
KERNEL_RUN(scatter_max_kernel, index.dim(), index.numel(), src_info,
index_info, out_info, dim);
KERNEL_RUN(arg_kernel, index.dim(), index.numel(), src_info, index_info,
out_info, at::cuda::detail::getTensorInfo<int64_t, int64_t>(arg),
dim);
});
}
template <typename scalar_t, int64_t Dims>
__global__ void
scatter_min_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
at::cuda::detail::TensorInfo<int64_t, int64_t> index,
at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
int64_t dim, size_t numel) {
const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = blockDim.x * gridDim.x;
for (ptrdiff_t i = idx; i < numel; i += stride) {
int64_t srcOffset = 0, indexOffset = 0, outOffset = 0;
IndexToScatterOffsets3<scalar_t, scalar_t, Dims>::compute(
i, dim, index, &indexOffset, src, &srcOffset, out, &outOffset);
atomMin(&out.data[outOffset], src.data[srcOffset]);
}
}
void scatter_min_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
at::Tensor arg, int64_t dim) {
AT_DISPATCH_ALL_TYPES(src.type(), "scatter_min_kernel", [&] {
auto src_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src);
auto index_info = at::cuda::detail::getTensorInfo<int64_t, int64_t>(index);
auto out_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(out);
KERNEL_RUN(scatter_min_kernel, index.dim(), index.numel(), src_info,
index_info, out_info, dim);
KERNEL_RUN(arg_kernel, index.dim(), index.numel(), src_info, index_info,
out_info, at::cuda::detail::getTensorInfo<int64_t, int64_t>(arg),
dim);
});
}
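// index_backward routes gradients for the max/min reductions: a gradient
// element is written to `out` only if its position along `dim` matches the
// argmax/argmin recorded in `arg` during the forward pass.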
template <typename scalar_t, int64_t Dims>
__global__ void
index_backward_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> grad,
at::cuda::detail::TensorInfo<int64_t, int64_t> index,
at::cuda::detail::TensorInfo<int64_t, int64_t> arg,
at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
int64_t dim, size_t numel) {
const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = blockDim.x * gridDim.x;
for (ptrdiff_t i = idx; i < numel; i += stride) {
int64_t gradOffset = 0, indexOffset = 0, argOffset = 0, outOffset = 0;
IndexToScatterOffsets4<scalar_t, int64_t, scalar_t, Dims>::compute(
i, dim, index, &indexOffset, out, &outOffset, arg, &argOffset, grad,
&gradOffset);
if (arg.data[argOffset] ==
(outOffset / out.strides[dim]) % out.sizes[dim]) {
out.data[outOffset] = grad.data[gradOffset];
}
}
}
void index_backward_cuda(at::Tensor grad, at::Tensor index, at::Tensor arg,
at::Tensor out, int64_t dim) {
AT_DISPATCH_ALL_TYPES(grad.type(), "index_backward_kernel", [&] {
KERNEL_RUN(index_backward_kernel, index.dim(), index.numel(),
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(grad),
at::cuda::detail::getTensorInfo<int64_t, int64_t>(index),
at::cuda::detail::getTensorInfo<int64_t, int64_t>(arg),
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(out), dim);
});
}
@@ -5,9 +5,8 @@ import torch
from torch.autograd import gradcheck
import torch_scatter
-from .utils import devices, tensor
+from .utils import grad_dtypes as dtypes, devices, tensor
-dtypes = [torch.float, torch.double]
funcs = ['add', 'sub', 'mul', 'div', 'mean']
indices = [2, 0, 1, 1, 0]
......
@@ -3,6 +3,9 @@ from torch.testing import get_all_dtypes
dtypes = get_all_dtypes()
dtypes.remove(torch.half)
dtypes.remove(torch.short) # PyTorch scatter does not work on short types.
grad_dtypes = [torch.float, torch.double]
devices = [torch.device('cpu')]
if torch.cuda.is_available(): # pragma: no cover
......
-from .utils.gen import gen
+from torch_scatter.utils.gen import gen
def scatter_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
......
from torch.autograd import Function
-from .utils.ext import get_func
-from .utils.gen import gen
+from torch_scatter.utils.ext import get_func
+from torch_scatter.utils.gen import gen
class ScatterDiv(Function):
......
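// What follows appears to be the previous THC-based implementation
// (THCAtomics-style macros, a hand-rolled TensorInfo, and the generic/common.cu
// and generic/kernel.cu templates) that this commit supersedes with the
// ATen/pybind11 version above.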
#define ATOMIC_(NAME) \
template <typename T, size_t n> \
struct TH_CONCAT_3(Atomic, NAME, IntegerImpl); \
\
template<typename T> \
struct TH_CONCAT_3(Atomic, NAME, IntegerImpl)<T, 1> { \
inline __device__ void operator()(T *address, T val) { \
uint32_t *address_as_ui = (uint32_t *) (address - ((size_t) address & 3)); \
uint32_t old = *address_as_ui; \
uint32_t shift = ((size_t) address & 3) * 8; \
uint32_t res; \
uint32_t assumed; \
\
do { \
assumed = old; \
res = OP(val, T((old >> shift) & 0xff)); \
old = (old & ~(0x000000ff << shift)) | (res << shift); \
old = atomicCAS(address_as_ui, assumed, old); \
} while (assumed != old); \
} \
}; \
\
template<typename T> \
struct TH_CONCAT_3(Atomic, NAME, IntegerImpl)<T, 2> { \
inline __device__ void operator()(T *address, T val) { \
uint32_t *address_as_ui = (uint32_t *) ((char *) address - ((size_t) address & 2)); \
uint32_t old = *address_as_ui; \
uint32_t res; \
uint32_t newval; \
uint32_t assumed; \
\
do { \
assumed = old; \
res = OP(val, (size_t) address & 2 ? T(old >> 16) : T(old & 0xffff)); \
newval = (size_t) address & 2 ? (old & 0xffff) | (res << 16) : (old & 0xffff0000) | res; \
old = atomicCAS(address_as_ui, assumed, newval); \
} while (assumed != old); \
} \
}; \
\
template<typename T> \
struct TH_CONCAT_3(Atomic, NAME, IntegerImpl)<T, 4> { \
inline __device__ void operator()(T *address, T val) { \
uint32_t *address_as_ui = (uint32_t *) address; \
uint32_t old = *address_as_ui; \
uint32_t assumed; \
\
do { \
assumed = old; \
old = atomicCAS(address_as_ui, assumed, OP(val, (T) old)); \
} while (assumed != old); \
} \
}; \
\
template<typename T> \
struct TH_CONCAT_3(Atomic, NAME, IntegerImpl)<T, 8> { \
inline __device__ void operator()(T *address, T val) { \
unsigned long long *address_as_ull = (unsigned long long *) address; \
unsigned long long old = *address_as_ull; \
unsigned long long assumed; \
\
do { \
assumed = old; \
old = atomicCAS(address_as_ull, assumed, OP(val, (T) old)); \
} while (assumed != old); \
} \
}; \
\
template <typename T, size_t n> \
struct TH_CONCAT_3(Atomic, NAME, DecimalImpl); \
\
template <typename T> \
struct TH_CONCAT_3(Atomic, NAME, DecimalImpl)<T, 4> { \
inline __device__ void operator()(T *address, T val) { \
int *address_as_i = (int *) address; \
int old = *address_as_i; \
int assumed; \
\
do { \
assumed = old; \
old = atomicCAS(address_as_i, assumed, __float_as_int(OP(val, __int_as_float(assumed)))); \
} while (assumed != old); \
} \
}; \
\
template <typename T> \
struct TH_CONCAT_3(Atomic, NAME, DecimalImpl)<T, 8> { \
inline __device__ void operator()(T *address, T val) { \
unsigned long long int *address_as_ull = (unsigned long long int *) address; \
unsigned long long int old = *address_as_ull; \
unsigned long long int assumed; \
\
do { \
assumed = old; \
old = atomicCAS(address_as_ull, assumed, __double_as_longlong(OP(val, __longlong_as_double(assumed)))); \
} while (assumed != old); \
} \
};
#define OP(X, Y) Y + X
ATOMIC_(Add)
#undef OP
static inline __device__ void atomAdd(uint8_t *address, uint8_t val) { AtomicAddIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val); }
static inline __device__ void atomAdd( int8_t *address, int8_t val) { AtomicAddIntegerImpl< int8_t, sizeof( int8_t)>()(address, val); }
static inline __device__ void atomAdd(int16_t *address, int16_t val) { AtomicAddIntegerImpl<int16_t, sizeof(int16_t)>()(address, val); }
static inline __device__ void atomAdd(int32_t *address, int32_t val) { atomicAdd(address, val); }
static inline __device__ void atomAdd(int64_t *address, int64_t val) { AtomicAddIntegerImpl<int64_t, sizeof(int64_t)>()(address, val); }
static inline __device__ void atomAdd( float *address, float val) { atomicAdd(address, val); }
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000)
static inline __device__ void atomAdd( double *address, double val) { AtomicAddDecimalImpl< double, sizeof( double)>()(address, val); }
#else
static inline __device__ void atomAdd( double *address, double val) { atomicAdd(address, val); }
#endif
#define OP(X, Y) Y * X
ATOMIC_(Mul)
#undef OP
static inline __device__ void atomMul(uint8_t *address, uint8_t val) { AtomicMulIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val); }
static inline __device__ void atomMul( int8_t *address, int8_t val) { AtomicMulIntegerImpl< int8_t, sizeof( int8_t)>()(address, val); }
static inline __device__ void atomMul(int16_t *address, int16_t val) { AtomicMulIntegerImpl<int16_t, sizeof(int16_t)>()(address, val); }
static inline __device__ void atomMul(int32_t *address, int32_t val) { AtomicMulIntegerImpl<int32_t, sizeof(int32_t)>()(address, val); }
static inline __device__ void atomMul(int64_t *address, int64_t val) { AtomicMulIntegerImpl<int64_t, sizeof(int64_t)>()(address, val); }
static inline __device__ void atomMul( float *address, float val) { AtomicMulDecimalImpl< float, sizeof( float)>()(address, val); }
static inline __device__ void atomMul( double *address, double val) { AtomicMulDecimalImpl< double, sizeof( double)>()(address, val); }
#define OP(X, Y) Y / X
ATOMIC_(Div)
#undef OP
static inline __device__ void atomDiv(uint8_t *address, uint8_t val) { AtomicDivIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val); }
static inline __device__ void atomDiv( int8_t *address, int8_t val) { AtomicDivIntegerImpl< int8_t, sizeof( int8_t)>()(address, val); }
static inline __device__ void atomDiv(int16_t *address, int16_t val) { AtomicDivIntegerImpl<int16_t, sizeof(int16_t)>()(address, val); }
static inline __device__ void atomDiv(int32_t *address, int32_t val) { AtomicDivIntegerImpl<int32_t, sizeof(int32_t)>()(address, val); }
static inline __device__ void atomDiv(int64_t *address, int64_t val) { AtomicDivIntegerImpl<int64_t, sizeof(int64_t)>()(address, val); }
static inline __device__ void atomDiv( float *address, float val) { AtomicDivDecimalImpl< float, sizeof( float)>()(address, val); }
static inline __device__ void atomDiv( double *address, double val) { AtomicDivDecimalImpl< double, sizeof( double)>()(address, val); }
#define OP(X, Y) max(Y, X)
ATOMIC_(Max)
#undef OP
static inline __device__ void atomMax(uint8_t *address, uint8_t val) { AtomicMaxIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val); }
static inline __device__ void atomMax( int8_t *address, int8_t val) { AtomicMaxIntegerImpl< int8_t, sizeof( int8_t)>()(address, val); }
static inline __device__ void atomMax(int16_t *address, int16_t val) { AtomicMaxIntegerImpl<int16_t, sizeof(int16_t)>()(address, val); }
static inline __device__ void atomMax(int32_t *address, int32_t val) { atomicMax(address, val); }
static inline __device__ void atomMax(int64_t *address, int64_t val) { AtomicMaxIntegerImpl<int64_t, sizeof(int64_t)>()(address, val); }
static inline __device__ void atomMax( float *address, float val) { AtomicMaxDecimalImpl< float, sizeof( float)>()(address, val); }
static inline __device__ void atomMax( double *address, double val) { AtomicMaxDecimalImpl< double, sizeof( double)>()(address, val); }
#define OP(X, Y) min(Y, X)
ATOMIC_(Min)
#undef OP
static inline __device__ void atomMin(uint8_t *address, uint8_t val) { AtomicMinIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val); }
static inline __device__ void atomMin( int8_t *address, int8_t val) { AtomicMinIntegerImpl< int8_t, sizeof( int8_t)>()(address, val); }
static inline __device__ void atomMin(int16_t *address, int16_t val) { AtomicMinIntegerImpl<int16_t, sizeof(int16_t)>()(address, val); }
static inline __device__ void atomMin(int32_t *address, int32_t val) { atomicMin(address, val); }
static inline __device__ void atomMin(int64_t *address, int64_t val) { AtomicMinIntegerImpl<int64_t, sizeof(int64_t)>()(address, val); }
static inline __device__ void atomMin( float *address, float val) { AtomicMinDecimalImpl< float, sizeof( float)>()(address, val); }
static inline __device__ void atomMin( double *address, double val) { AtomicMinDecimalImpl< double, sizeof( double)>()(address, val); }
template <typename a, typename b, int Dims>
struct IndexToScatterOffsets3 {
static __device__ void compute(int i, const int dim,
const TensorInfo<int64_t>& index, int* indexOffset,
const TensorInfo<a>& t1, int* t1Offset,
const TensorInfo<b>& t2, int* t2Offset) {
int curDimIndex;
for (int d = Dims - 1; d >= 0; d--) {
curDimIndex = i % index.size[d];
*indexOffset += curDimIndex * index.stride[d];
*t1Offset += curDimIndex * t1.stride[d];
if (d != dim) *t2Offset += curDimIndex * t2.stride[d];
i /= index.size[d];
}
int64_t indexValue = index.data[*indexOffset];
assert(indexValue >= 0 && indexValue < t2.size[dim]);
*t2Offset += indexValue * t2.stride[dim];
}
};
template <typename a, typename b>
struct IndexToScatterOffsets3<a, b, -1> {
static __device__ void compute(int i, const int dim,
const TensorInfo<int64_t>& index, int* indexOffset,
const TensorInfo<a>& t1, int* t1Offset,
const TensorInfo<b>& t2, int* t2Offset) {
int curDimIndex;
for (int d = index.dims - 1; d >= 0; d--) {
curDimIndex = i % index.size[d];
*indexOffset += curDimIndex * index.stride[d];
*t1Offset += curDimIndex * t1.stride[d];
if (d != dim) *t2Offset += curDimIndex * t2.stride[d];
i /= index.size[d];
}
int64_t indexValue = index.data[*indexOffset];
assert(indexValue >= 0 && indexValue < t2.size[dim]);
*t2Offset += indexValue * t2.stride[dim];
}
};
template <typename a, typename b, typename c, int Dims>
struct IndexToScatterOffsets4 {
static __device__ void compute(int i, const int dim,
const TensorInfo<int64_t>& index, int* indexOffset,
const TensorInfo<a>& t1, int* t1Offset,
const TensorInfo<b>& t2, int* t2Offset,
const TensorInfo<c>& t3, int* t3Offset) {
int curDimIndex;
for (int d = Dims - 1; d >= 0; d--) {
curDimIndex = i % index.size[d];
*indexOffset += curDimIndex * index.stride[d];
*t1Offset += curDimIndex * t1.stride[d];
if (d != dim) {
*t2Offset += curDimIndex * t2.stride[d];
*t3Offset += curDimIndex * t3.stride[d];
}
i /= index.size[d];
}
int64_t indexValue = index.data[*indexOffset];
assert(indexValue >= 0 && indexValue < t2.size[dim]);
*t2Offset += indexValue * t2.stride[dim];
*t3Offset += indexValue * t3.stride[dim];
}
};
template <typename a, typename b, typename c>
struct IndexToScatterOffsets4<a, b, c, -1> {
static __device__ void compute(int i, const int dim,
const TensorInfo<int64_t>& index, int* indexOffset,
const TensorInfo<a>& t1, int* t1Offset,
const TensorInfo<b>& t2, int* t2Offset,
const TensorInfo<c>& t3, int* t3Offset) {
int curDimIndex;
for (int d = index.dims - 1; d >= 0; d--) {
curDimIndex = i % index.size[d];
*indexOffset += curDimIndex * index.stride[d];
*t1Offset += curDimIndex * t1.stride[d];
if (d != dim) {
*t2Offset += curDimIndex * t2.stride[d];
*t3Offset += curDimIndex * t3.stride[d];
}
i /= index.size[d];
}
int64_t indexValue = index.data[*indexOffset];
assert(indexValue >= 0 && indexValue < t2.size[dim]);
*t2Offset += indexValue * t2.stride[dim];
*t3Offset += indexValue * t3.stride[dim];
}
};
const int MAX_DIMS = 25;
const int NUM_THREADS = 1024;
inline int GET_BLOCKS(const int n) {
return (n + NUM_THREADS - 1) / NUM_THREADS;
}
template<typename T>
struct TensorInfo {
TensorInfo(T *t, int d, int sz[MAX_DIMS], int st[MAX_DIMS]) {
data = t; dims = d;
for (int i = 0; i < dims; i++) {
size[i] = sz[i];
stride[i] = st[i];
}
}
T *data;
int dims;
int size[MAX_DIMS];
int stride[MAX_DIMS];
};
#define KERNEL_LOOP(I, N) \
for (int I = blockIdx.x * blockDim.x + threadIdx.x; I < N; I += blockDim.x * gridDim.x)
#define KERNEL_RUN(NAME, DIMS, N, ...) { \
int grid = GET_BLOCKS(N); \
cudaStream_t stream = THCState_getCurrentStream(state); \
switch (DIMS) { \
case 1: NAME<real, 1><<<grid, NUM_THREADS, 0, stream>>>(__VA_ARGS__, N); break; \
case 2: NAME<real, 2><<<grid, NUM_THREADS, 0, stream>>>(__VA_ARGS__, N); break; \
case 3: NAME<real, 3><<<grid, NUM_THREADS, 0, stream>>>(__VA_ARGS__, N); break; \
default: NAME<real, -1><<<grid, NUM_THREADS, 0, stream>>>(__VA_ARGS__, N); break; \
} \
THCudaCheck(cudaGetLastError()); \
}
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/common.cu"
#else
void thc_(check)(THCState *state, THCTensor *output, THCudaLongTensor *index, THCTensor *input) {
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, output, input));
THCAssertSameGPU(THCudaLongTensor_checkGPU(state, 1, index));
THArgCheck(THCTensor_(nDimension)(state, output) <= MAX_DIMS, 1, "Tensor too large or too many dimensions");
}
TensorInfo<real> thc_(getTensorInfo)(THCState *state, THCTensor *tensor) {
real *data = THCTensor_(data)(state, tensor);
int dims = THCTensor_(nDimension)(state, tensor);
int size[MAX_DIMS]; int stride[MAX_DIMS];
for (int i = 0; i < dims; i++) {
size[i] = THCTensor_(size)(state, tensor, i);
stride[i] = THCTensor_(stride)(state, tensor, i);
}
return TensorInfo<real>(data, dims, size, stride);
}
#endif
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/kernel.cu"
#else
void scatter_(mul)(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input) {
thc_(check)(state, output, index, input);
const int n = THCudaLongTensor_nElement(state, index);
TensorInfo<real> outputInfo = thc_(getTensorInfo)(state, output);
TensorInfo<int64_t> indexInfo = thc_getTensorInfo_Long(state, index);
TensorInfo<real> inputInfo = thc_(getTensorInfo)(state, input);
KERNEL_RUN(mulKernel, indexInfo.dims, n, outputInfo, indexInfo, inputInfo, dim)
}
void scatter_(div)(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input) {
thc_(check)(state, output, index, input);
const int n = THCudaLongTensor_nElement(state, index);
TensorInfo<real> outputInfo = thc_(getTensorInfo)(state, output);
TensorInfo<int64_t> indexInfo = thc_getTensorInfo_Long(state, index);
TensorInfo<real> inputInfo = thc_(getTensorInfo)(state, input);
KERNEL_RUN(divKernel, indexInfo.dims, n, outputInfo, indexInfo, inputInfo, dim)
}
void scatter_(mean)(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input, THCTensor *count) {
thc_(check)(state, output, index, input);
const int n = THCudaLongTensor_nElement(state, index);
TensorInfo<real> outputInfo = thc_(getTensorInfo)(state, output);
TensorInfo<int64_t> indexInfo = thc_getTensorInfo_Long(state, index);
TensorInfo<real> inputInfo = thc_(getTensorInfo)(state, input);
TensorInfo<real> countInfo = thc_(getTensorInfo)(state, count);
KERNEL_RUN(meanKernel, indexInfo.dims, n, outputInfo, indexInfo, inputInfo, countInfo, dim)
}
void scatter_(max)(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input, THCudaLongTensor *arg) {
thc_(check)(state, output, index, input);
const int n = THCudaLongTensor_nElement(state, index);
TensorInfo<real> outputInfo = thc_(getTensorInfo)(state, output);
TensorInfo<int64_t> indexInfo = thc_getTensorInfo_Long(state, index);
TensorInfo<real> inputInfo = thc_(getTensorInfo)(state, input);
TensorInfo<int64_t> argInfo = thc_getTensorInfo_Long(state, arg);
KERNEL_RUN(maxKernel, indexInfo.dims, n, outputInfo, indexInfo, inputInfo, dim)
KERNEL_RUN(argKernel, indexInfo.dims, n, outputInfo, indexInfo, inputInfo, argInfo, dim)
}
void scatter_(min)(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input, THCudaLongTensor *arg) {
thc_(check)(state, output, index, input);
const int n = THCudaLongTensor_nElement(state, index);
TensorInfo<real> outputInfo = thc_(getTensorInfo)(state, output);
TensorInfo<int64_t> indexInfo = thc_getTensorInfo_Long(state, index);
TensorInfo<real> inputInfo = thc_(getTensorInfo)(state, input);
TensorInfo<int64_t> argInfo = thc_getTensorInfo_Long(state, arg);
KERNEL_RUN(minKernel, indexInfo.dims, n, outputInfo, indexInfo, inputInfo, dim)
KERNEL_RUN(argKernel, indexInfo.dims, n, outputInfo, indexInfo, inputInfo, argInfo, dim)
}
void index_backward(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *grad, THCudaLongTensor *arg) {
thc_(check)(state, output, index, grad);
const int n = THCudaLongTensor_nElement(state, index);
TensorInfo<real> outputInfo = thc_(getTensorInfo)(state, output);
TensorInfo<int64_t> indexInfo = thc_getTensorInfo_Long(state, index);
TensorInfo<real> gradInfo = thc_(getTensorInfo)(state, grad);
TensorInfo<int64_t> argInfo = thc_getTensorInfo_Long(state, arg);
KERNEL_RUN(indexBackwardKernel, indexInfo.dims, n, outputInfo, indexInfo, gradInfo, argInfo, dim)
}
#endif
#include <THC.h>
#include "kernel.h"
#include "common.cuh"
#include "THCIndex.cuh"
#include "THCAtomics.cuh"
#define scatter_(NAME) TH_CONCAT_4(scatter_, NAME, _kernel_, Real)
#define index_backward TH_CONCAT_2(index_backward_kernel_, Real)
#define thc_(NAME) TH_CONCAT_4(thc_, NAME, _, Real)
#include "generic/common.cu"
#include "THCGenerateAllTypes.h"
template<typename Real, int Dims>
__global__ void mulKernel(TensorInfo<Real> output, TensorInfo<int64_t> index, TensorInfo<Real> input, const int dim, const int n) {
KERNEL_LOOP(i, n) {
int outputOffset = 0; int indexOffset = 0; int inputOffset = 0;
IndexToScatterOffsets3<Real, Real, Dims>::compute(i, dim, index, &indexOffset, input, &inputOffset, output, &outputOffset);
atomMul(&output.data[outputOffset], input.data[inputOffset]);
}
}
template<typename Real, int Dims>
__global__ void divKernel(TensorInfo<Real> output, TensorInfo<int64_t> index, TensorInfo<Real> input, const int dim, const int n) {
KERNEL_LOOP(i, n) {
int outputOffset = 0; int indexOffset = 0; int inputOffset = 0;
IndexToScatterOffsets3<Real, Real, Dims>::compute(i, dim, index, &indexOffset, input, &inputOffset, output, &outputOffset);
atomDiv(&output.data[outputOffset], input.data[inputOffset]);
}
}
template<typename Real, int Dims>
__global__ void meanKernel(TensorInfo<Real> output, TensorInfo<int64_t> index, TensorInfo<Real> input, TensorInfo<Real> count, const int dim, const int n) {
KERNEL_LOOP(i, n) {
int outputOffset = 0; int indexOffset = 0; int inputOffset = 0; int countOffset = 0;
IndexToScatterOffsets4<Real, Real, Real, Dims>::compute(i, dim, index, &indexOffset, input, &inputOffset, output, &outputOffset, count, &countOffset);
atomAdd(&output.data[outputOffset], input.data[inputOffset]);
atomAdd(&count.data[countOffset], 1);
}
}
template<typename Real, int Dims>
__global__ void maxKernel(TensorInfo<Real> output, TensorInfo<int64_t> index, TensorInfo<Real> input, const int dim, const int n) {
KERNEL_LOOP(i, n) {
int outputOffset = 0; int indexOffset = 0; int inputOffset = 0;
IndexToScatterOffsets3<Real, Real, Dims>::compute(i, dim, index, &indexOffset, input, &inputOffset, output, &outputOffset);
atomMax(&output.data[outputOffset], input.data[inputOffset]);
}
}
template<typename Real, int Dims>
__global__ void minKernel(TensorInfo<Real> output, TensorInfo<int64_t> index, TensorInfo<Real> input, const int dim, const int n) {
KERNEL_LOOP(i, n) {
int outputOffset = 0; int indexOffset = 0; int inputOffset = 0;
IndexToScatterOffsets3<Real, Real, Dims>::compute(i, dim, index, &indexOffset, input, &inputOffset, output, &outputOffset);
atomMin(&output.data[outputOffset], input.data[inputOffset]);
}
}
template<typename Real, int Dims>
__global__ void argKernel(TensorInfo<Real> output, TensorInfo<int64_t> index, TensorInfo<Real> input, TensorInfo<int64_t> arg, const int dim, const int n) {
KERNEL_LOOP(i, n) {
int outputOffset = 0; int indexOffset = 0; int inputOffset = 0; int argOffset = 0;
IndexToScatterOffsets4<Real, Real, int64_t, Dims>::compute(i, dim, index, &indexOffset, input, &inputOffset, output, &outputOffset, arg, &argOffset);
if (input.data[inputOffset] == output.data[outputOffset]) {
arg.data[argOffset] = (inputOffset / input.stride[dim]) % input.size[dim];
}
}
}
template<typename Real, int Dims>
__global__ void indexBackwardKernel(TensorInfo<Real> output, TensorInfo<int64_t> index, TensorInfo<Real> grad, TensorInfo<int64_t> arg, const int dim, const int n) {
KERNEL_LOOP(i, n) {
int outputOffset = 0; int indexOffset = 0; int gradOffset = 0; int argOffset = 0;
IndexToScatterOffsets4<Real, Real, int64_t, Dims>::compute(i, dim, index, &indexOffset, output, &outputOffset, grad, &gradOffset, arg, &argOffset);
if (arg.data[argOffset] == (outputOffset / output.stride[dim]) % output.size[dim]) {
output.data[outputOffset] = grad.data[gradOffset];
}
}
}
#include "generic/kernel.cu"
#include "THCGenerateFloatType.h"
#include "generic/kernel.cu"
#include "THCGenerateDoubleType.h"
#include "generic/kernel.cu"
#include "THCGenerateByteType.h"
#include "generic/kernel.cu"
#include "THCGenerateCharType.h"
#include "generic/kernel.cu"
#include "THCGenerateShortType.h"
#include "generic/kernel.cu"
#include "THCGenerateIntType.h"
#include "generic/kernel.cu"
#include "THCGenerateLongType.h"
#ifdef __cplusplus
extern "C" {
#endif
void scatter_mul_kernel_Float (THCState *state, int dim, THCudaTensor *output, THCudaLongTensor *index, THCudaTensor *input);
void scatter_mul_kernel_Double(THCState *state, int dim, THCudaDoubleTensor *output, THCudaLongTensor *index, THCudaDoubleTensor *input);
void scatter_mul_kernel_Byte (THCState *state, int dim, THCudaByteTensor *output, THCudaLongTensor *index, THCudaByteTensor *input);
void scatter_mul_kernel_Char (THCState *state, int dim, THCudaCharTensor *output, THCudaLongTensor *index, THCudaCharTensor *input);
void scatter_mul_kernel_Short (THCState *state, int dim, THCudaShortTensor *output, THCudaLongTensor *index, THCudaShortTensor *input);
void scatter_mul_kernel_Int (THCState *state, int dim, THCudaIntTensor *output, THCudaLongTensor *index, THCudaIntTensor *input);
void scatter_mul_kernel_Long (THCState *state, int dim, THCudaLongTensor *output, THCudaLongTensor *index, THCudaLongTensor *input);
void scatter_div_kernel_Float (THCState *state, int dim, THCudaTensor *output, THCudaLongTensor *index, THCudaTensor *input);
void scatter_div_kernel_Double(THCState *state, int dim, THCudaDoubleTensor *output, THCudaLongTensor *index, THCudaDoubleTensor *input);
void scatter_div_kernel_Byte (THCState *state, int dim, THCudaByteTensor *output, THCudaLongTensor *index, THCudaByteTensor *input);
void scatter_div_kernel_Char (THCState *state, int dim, THCudaCharTensor *output, THCudaLongTensor *index, THCudaCharTensor *input);
void scatter_div_kernel_Short (THCState *state, int dim, THCudaShortTensor *output, THCudaLongTensor *index, THCudaShortTensor *input);
void scatter_div_kernel_Int (THCState *state, int dim, THCudaIntTensor *output, THCudaLongTensor *index, THCudaIntTensor *input);
void scatter_div_kernel_Long (THCState *state, int dim, THCudaLongTensor *output, THCudaLongTensor *index, THCudaLongTensor *input);
void scatter_mean_kernel_Float (THCState *state, int dim, THCudaTensor *output, THCudaLongTensor *index, THCudaTensor *input, THCudaTensor *count);
void scatter_mean_kernel_Double(THCState *state, int dim, THCudaDoubleTensor *output, THCudaLongTensor *index, THCudaDoubleTensor *input, THCudaDoubleTensor *count);
void scatter_mean_kernel_Byte (THCState *state, int dim, THCudaByteTensor *output, THCudaLongTensor *index, THCudaByteTensor *input, THCudaByteTensor *count);
void scatter_mean_kernel_Char (THCState *state, int dim, THCudaCharTensor *output, THCudaLongTensor *index, THCudaCharTensor *input, THCudaCharTensor *count);
void scatter_mean_kernel_Short (THCState *state, int dim, THCudaShortTensor *output, THCudaLongTensor *index, THCudaShortTensor *input, THCudaShortTensor *count);
void scatter_mean_kernel_Int (THCState *state, int dim, THCudaIntTensor *output, THCudaLongTensor *index, THCudaIntTensor *input, THCudaIntTensor *count);
void scatter_mean_kernel_Long (THCState *state, int dim, THCudaLongTensor *output, THCudaLongTensor *index, THCudaLongTensor *input, THCudaLongTensor *count);
void scatter_max_kernel_Float (THCState *state, int dim, THCudaTensor *output, THCudaLongTensor *index, THCudaTensor *input, THCudaLongTensor *arg);
void scatter_max_kernel_Double(THCState *state, int dim, THCudaDoubleTensor *output, THCudaLongTensor *index, THCudaDoubleTensor *input, THCudaLongTensor *arg);
void scatter_max_kernel_Byte (THCState *state, int dim, THCudaByteTensor *output, THCudaLongTensor *index, THCudaByteTensor *input, THCudaLongTensor *arg);
void scatter_max_kernel_Char (THCState *state, int dim, THCudaCharTensor *output, THCudaLongTensor *index, THCudaCharTensor *input, THCudaLongTensor *arg);
void scatter_max_kernel_Short (THCState *state, int dim, THCudaShortTensor *output, THCudaLongTensor *index, THCudaShortTensor *input, THCudaLongTensor *arg);
void scatter_max_kernel_Int (THCState *state, int dim, THCudaIntTensor *output, THCudaLongTensor *index, THCudaIntTensor *input, THCudaLongTensor *arg);
void scatter_max_kernel_Long (THCState *state, int dim, THCudaLongTensor *output, THCudaLongTensor *index, THCudaLongTensor *input, THCudaLongTensor *arg);
void scatter_min_kernel_Float (THCState *state, int dim, THCudaTensor *output, THCudaLongTensor *index, THCudaTensor *input, THCudaLongTensor *arg);
void scatter_min_kernel_Double(THCState *state, int dim, THCudaDoubleTensor *output, THCudaLongTensor *index, THCudaDoubleTensor *input, THCudaLongTensor *arg);
void scatter_min_kernel_Byte (THCState *state, int dim, THCudaByteTensor *output, THCudaLongTensor *index, THCudaByteTensor *input, THCudaLongTensor *arg);
void scatter_min_kernel_Char (THCState *state, int dim, THCudaCharTensor *output, THCudaLongTensor *index, THCudaCharTensor *input, THCudaLongTensor *arg);
void scatter_min_kernel_Short (THCState *state, int dim, THCudaShortTensor *output, THCudaLongTensor *index, THCudaShortTensor *input, THCudaLongTensor *arg);
void scatter_min_kernel_Int (THCState *state, int dim, THCudaIntTensor *output, THCudaLongTensor *index, THCudaIntTensor *input, THCudaLongTensor *arg);
void scatter_min_kernel_Long (THCState *state, int dim, THCudaLongTensor *output, THCudaLongTensor *index, THCudaLongTensor *input, THCudaLongTensor *arg);
void index_backward_kernel_Float (THCState *state, int dim, THCudaTensor *output, THCudaLongTensor *index, THCudaTensor *grad, THCudaLongTensor *arg);
void index_backward_kernel_Double(THCState *state, int dim, THCudaDoubleTensor *output, THCudaLongTensor *index, THCudaDoubleTensor *grad, THCudaLongTensor *arg);
void index_backward_kernel_Byte (THCState *state, int dim, THCudaByteTensor *output, THCudaLongTensor *index, THCudaByteTensor *grad, THCudaLongTensor *arg);
void index_backward_kernel_Char (THCState *state, int dim, THCudaCharTensor *output, THCudaLongTensor *index, THCudaCharTensor *grad, THCudaLongTensor *arg);
void index_backward_kernel_Short (THCState *state, int dim, THCudaShortTensor *output, THCudaLongTensor *index, THCudaShortTensor *grad, THCudaLongTensor *arg);
void index_backward_kernel_Int (THCState *state, int dim, THCudaIntTensor *output, THCudaLongTensor *index, THCudaIntTensor *grad, THCudaLongTensor *arg);
void index_backward_kernel_Long (THCState *state, int dim, THCudaLongTensor *output, THCudaLongTensor *index, THCudaLongTensor *grad, THCudaLongTensor *arg);
#ifdef __cplusplus
}
#endif
from torch.autograd import Function
-from .utils.ext import get_func
-from .utils.gen import gen
+from torch_scatter.utils.ext import get_func
+from torch_scatter.utils.gen import gen
class ScatterMax(Function):
......
import torch
-from .add import scatter_add
+from torch_scatter import scatter_add
def scatter_mean(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
......
from torch.autograd import Function
-from .utils.ext import get_func
-from .utils.gen import gen
+from torch_scatter.utils.ext import get_func
+from torch_scatter.utils.gen import gen
class ScatterMin(Function):
......
from torch.autograd import Function
-from .utils.ext import get_func
-from .utils.gen import gen
+from torch_scatter.utils.ext import get_func
+from torch_scatter.utils.gen import gen
class ScatterMul(Function):
......
-from .add import scatter_add
+from torch_scatter import scatter_add
def scatter_sub(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
......