Commit 5be6d63a authored by rusty1s's avatar rusty1s
Browse files

scatter kernel done

parent 5e2d0f1f
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
#pragma once
#include <torch/extension.h>
#include "compat.h"
#define DIM_APPLY3(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, DIM, CODE) \
[&] { \
TYPE1 *TENSOR1##_data = TENSOR1.DATA_PTR<TYPE1>(); \
auto TENSOR1##_size = TENSOR1.size(DIM); \
auto TENSOR1##_stride = TENSOR1.stride(DIM); \
\
TYPE2 *TENSOR2##_data = TENSOR2.DATA_PTR<TYPE2>(); \
auto TENSOR2##_size = TENSOR2.size(DIM); \
auto TENSOR2##_stride = TENSOR2.stride(DIM); \
\
TYPE3 *TENSOR3##_data = TENSOR3.DATA_PTR<TYPE3>(); \
auto TENSOR3##_size = TENSOR3.size(DIM); \
auto TENSOR3##_stride = TENSOR3.stride(DIM); \
\
auto dims = TENSOR1.dim(); \
auto zeros = torch::zeros(dims, TENSOR1.options().dtype(torch::kLong)); \
auto counter = zeros.DATA_PTR<int64_t>(); \
bool has_finished = false; \
\
while (!has_finished) { \
CODE; \
if (dims == 1) \
break; \
\
for (int64_t cur_dim = 0; cur_dim < dims; cur_dim++) { \
if (cur_dim == DIM) { \
if (cur_dim == dims - 1) { \
has_finished = true; \
break; \
} \
continue; \
} \
\
counter[cur_dim]++; \
TENSOR1##_data += TENSOR1.stride(cur_dim); \
TENSOR2##_data += TENSOR2.stride(cur_dim); \
TENSOR3##_data += TENSOR3.stride(cur_dim); \
\
if (counter[cur_dim] == TENSOR1.size(cur_dim)) { \
if (cur_dim == dims - 1) { \
has_finished = true; \
break; \
} else { \
TENSOR1##_data -= counter[cur_dim] * TENSOR1.stride(cur_dim); \
TENSOR2##_data -= counter[cur_dim] * TENSOR2.stride(cur_dim); \
TENSOR3##_data -= counter[cur_dim] * TENSOR3.stride(cur_dim); \
counter[cur_dim] = 0; \
} \
} else \
break; \
} \
} \
}()
#define DIM_APPLY4(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, TYPE4, \
TENSOR4, DIM, CODE) \
[&] { \
TYPE1 *TENSOR1##_data = TENSOR1.DATA_PTR<TYPE1>(); \
auto TENSOR1##_size = TENSOR1.size(DIM); \
auto TENSOR1##_stride = TENSOR1.stride(DIM); \
\
TYPE2 *TENSOR2##_data = TENSOR2.DATA_PTR<TYPE2>(); \
auto TENSOR2##_size = TENSOR2.size(DIM); \
auto TENSOR2##_stride = TENSOR2.stride(DIM); \
\
TYPE3 *TENSOR3##_data = TENSOR3.DATA_PTR<TYPE3>(); \
auto TENSOR3##_size = TENSOR3.size(DIM); \
auto TENSOR3##_stride = TENSOR3.stride(DIM); \
\
TYPE4 *TENSOR4##_data = TENSOR4.DATA_PTR<TYPE4>(); \
auto TENSOR4##_size = TENSOR4.size(DIM); \
auto TENSOR4##_stride = TENSOR4.stride(DIM); \
\
auto dims = TENSOR1.dim(); \
auto zeros = torch::zeros(dims, TENSOR1.options().dtype(torch::kLong)); \
auto counter = zeros.DATA_PTR<int64_t>(); \
bool has_finished = false; \
\
while (!has_finished) { \
CODE; \
if (dims == 1) \
break; \
\
for (int64_t cur_dim = 0; cur_dim < dims; cur_dim++) { \
if (cur_dim == DIM) { \
if (cur_dim == dims - 1) { \
has_finished = true; \
break; \
} \
continue; \
} \
\
counter[cur_dim]++; \
TENSOR1##_data += TENSOR1.stride(cur_dim); \
TENSOR2##_data += TENSOR2.stride(cur_dim); \
TENSOR3##_data += TENSOR3.stride(cur_dim); \
TENSOR4##_data += TENSOR4.stride(cur_dim); \
\
if (counter[cur_dim] == TENSOR1.size(cur_dim)) { \
if (cur_dim == dims - 1) { \
has_finished = true; \
break; \
} else { \
TENSOR1##_data -= counter[cur_dim] * TENSOR1.stride(cur_dim); \
TENSOR2##_data -= counter[cur_dim] * TENSOR2.stride(cur_dim); \
TENSOR3##_data -= counter[cur_dim] * TENSOR3.stride(cur_dim); \
TENSOR4##_data -= counter[cur_dim] * TENSOR4.stride(cur_dim); \
counter[cur_dim] = 0; \
} \
} else \
break; \
} \
} \
}()
#pragma once
#include <torch/extension.h>
#include "compat.h"
#define MAX_TENSORINFO_DIMS 25
template <typename scalar_t> struct TensorInfo {
TensorInfo(scalar_t *p, int dim, int sz[MAX_TENSORINFO_DIMS],
int st[MAX_TENSORINFO_DIMS]) {
data = p;
dims = dim;
AT_ASSERT(dims < MAX_TENSORINFO_DIMS);
for (int i = 0; i < dim; ++i) {
sizes[i] = sz[i];
strides[i] = st[i];
}
}
scalar_t *data;
int dims;
int sizes[MAX_TENSORINFO_DIMS];
int strides[MAX_TENSORINFO_DIMS];
};
template <typename scalar_t>
TensorInfo<scalar_t> getTensorInfo(const torch::Tensor &tensor) {
int sizes[MAX_TENSORINFO_DIMS];
int strides[MAX_TENSORINFO_DIMS];
int dims = tensor.dim();
for (int i = 0; i < dims; ++i) {
sizes[i] = tensor.size(i);
strides[i] = tensor.stride(i);
}
return TensorInfo<scalar_t>(tensor.DATA_PTR<scalar_t>(), dims, sizes,
strides);
}
template <typename scalar_t> struct IndexToOffset {
static inline int get(int idx, const TensorInfo<scalar_t> &info) {
int offset = 0;
for (int i = info.dims - 1; i >= 0; --i) {
offset += (idx % info.sizes[i]) * info.strides[i];
idx /= info.sizes[i];
}
return offset;
}
};
template <typename scalar_t> struct IndexPtrToOffset {
static inline int get(int idx, const TensorInfo<scalar_t> &info) {
int offset = idx % (info.sizes[info.dims - 1] - 1);
offset *= info.strides[info.dims - 1];
idx /= info.sizes[info.dims - 1] - 1;
for (int i = info.dims - 2; i >= 0; --i) {
offset += (idx % info.sizes[i]) * info.strides[i];
idx /= info.sizes[i];
}
return offset;
}
};
#pragma once
#include <torch/extension.h>
enum ReductionType { SUM, MEAN, MIN, MAX };
const std::map<std::string, ReductionType> reduce2REDUCE = {
{"sum", SUM}, {"add", SUM}, {"mean", MEAN}, {"min", MIN}, {"max", MAX},
};
#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...) \
[&] { \
switch (reduce2REDUCE.at(reduce)) { \
case SUM: { \
const ReductionType REDUCE = SUM; \
return __VA_ARGS__(); \
} \
case MEAN: { \
const ReductionType REDUCE = MEAN; \
return __VA_ARGS__(); \
} \
case MIN: { \
const ReductionType REDUCE = MIN; \
return __VA_ARGS__(); \
} \
case MAX: { \
const ReductionType REDUCE = MAX; \
return __VA_ARGS__(); \
} \
} \
}()
template <typename scalar_t, ReductionType REDUCE> struct Reducer {
static inline scalar_t init() {
if (REDUCE == MIN)
return std::numeric_limits<scalar_t>::max();
else if (REDUCE == MAX)
return std::numeric_limits<scalar_t>::lowest();
else
return (scalar_t)0;
}
static inline void update(scalar_t *val, scalar_t new_val, int64_t *arg,
int64_t new_arg) {
if (REDUCE == SUM || REDUCE == MEAN)
*val = *val + new_val;
else if ((REDUCE == MIN && new_val < *val) ||
(REDUCE == MAX && new_val > *val)) {
*val = new_val;
*arg = new_arg;
}
}
static inline void write(scalar_t *address, scalar_t val,
int64_t *arg_address, int64_t arg, int count) {
if (REDUCE == SUM)
*address = val;
else if (REDUCE == MEAN)
*address = val / (count > 0 ? count : (scalar_t)1);
else if (REDUCE == MIN || REDUCE == MAX) {
if (count > 0) {
*address = val;
*arg_address = arg;
} else
*address = (scalar_t)0;
}
}
};
#include <torch/script.h>
#include "dim_apply.h"
#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")
void scatter_mul(torch::Tensor src, torch::Tensor index, torch::Tensor out,
int64_t dim) {
CHECK_CPU(src);
CHECK_CPU(index);
CHECK_CPU(out);
int64_t elems_per_row = index.size(dim), i, idx;
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_mul", [&] {
DIM_APPLY3(scalar_t, src, int64_t, index, scalar_t, out, dim, {
for (i = 0; i < elems_per_row; i++) {
idx = index_data[i * index_stride];
out_data[idx * out_stride] *= src_data[i * src_stride];
}
});
});
}
void scatter_div(torch::Tensor src, torch::Tensor index, torch::Tensor out,
int64_t dim) {
CHECK_CPU(src);
CHECK_CPU(index);
CHECK_CPU(out);
int64_t elems_per_row = index.size(dim), i, idx;
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_div", [&] {
DIM_APPLY3(scalar_t, src, int64_t, index, scalar_t, out, dim, {
for (i = 0; i < elems_per_row; i++) {
idx = index_data[i * index_stride];
out_data[idx * out_stride] /= src_data[i * src_stride];
}
});
});
}
void scatter_max(torch::Tensor src, torch::Tensor index, torch::Tensor out,
torch::Tensor arg, int64_t dim) {
CHECK_CPU(src);
CHECK_CPU(index);
CHECK_CPU(out);
int64_t elems_per_row = index.size(dim), i, idx;
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_max", [&] {
DIM_APPLY4(scalar_t, src, int64_t, index, scalar_t, out, int64_t, arg, dim,
{
for (i = 0; i < elems_per_row; i++) {
idx = index_data[i * index_stride];
if (src_data[i * src_stride] >= out_data[idx * out_stride]) {
out_data[idx * out_stride] = src_data[i * src_stride];
arg_data[idx * arg_stride] = i;
}
}
});
});
}
void scatter_min(torch::Tensor src, torch::Tensor index, torch::Tensor out,
torch::Tensor arg, int64_t dim) {
CHECK_CPU(src);
CHECK_CPU(index);
CHECK_CPU(out);
CHECK_CPU(arg);
int64_t elems_per_row = index.size(dim), i, idx;
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_min", [&] {
DIM_APPLY4(scalar_t, src, int64_t, index, scalar_t, out, int64_t, arg, dim,
{
for (i = 0; i < elems_per_row; i++) {
idx = index_data[i * index_stride];
if (src_data[i * src_stride] <= out_data[idx * out_stride]) {
out_data[idx * out_stride] = src_data[i * src_stride];
arg_data[idx * arg_stride] = i;
}
}
});
});
}
static auto registry =
torch::RegisterOperators("torch_scatter_cpu::scatter_mul", &scatter_mul)
.op("torch_scatter_cpu::scatter_div", &scatter_div)
.op("torch_scatter_cpu::scatter_max", &scatter_max)
.op("torch_scatter_cpu::scatter_min", &scatter_min);
#include <torch/script.h>
#include "segment_coo_impl.h"
static auto registry =
torch::RegisterOperators("torch_scatter_cpu::segment_coo", &segment_coo)
.op("torch_scatter_cpu::gather_coo", &gather_coo);
#pragma once
#include <torch/extension.h>
#include "compat.h"
#include "index_info.h"
#include "reducer.h"
#include "utils.h"
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out, std::string reduce) {
CHECK_CPU(src);
CHECK_CPU(index);
if (optional_out.has_value())
CHECK_CPU(optional_out.value());
CHECK_INPUT(src.dim() >= index.dim());
// Broadcasting `index` via `expand`.
auto sizes = index.sizes().vec();
for (int i = 0; i < index.dim(); i++)
sizes[i] = src.size(i);
index = index.expand(sizes);
auto dim = index.dim() - 1;
src = src.contiguous();
torch::Tensor out;
if (optional_out.has_value()) {
out = optional_out.value().contiguous();
for (int i = 0; i < out.dim(); i++)
if (i != dim)
CHECK_INPUT(src.size(i) == out.size(i));
} else {
sizes = src.sizes().vec();
sizes[dim] = *index.max().DATA_PTR<int64_t>();
out = torch::empty(sizes, src.options());
}
torch::optional<torch::Tensor> arg_out = torch::nullopt;
int64_t *arg_out_data = nullptr;
if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
arg_out = torch::full_like(out, src.size(dim), index.options());
arg_out_data = arg_out.value().DATA_PTR<int64_t>();
}
auto B = index.numel() / src.size(dim);
auto E = src.size(dim);
auto K = src.numel() / index.numel();
auto N = out.size(dim);
auto index_info = getTensorInfo<int64_t>(index);
auto stride = index_info.strides[index_info.dims - 1];
std::vector<int64_t> args(K);
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_coo", [&] {
auto src_data = src.DATA_PTR<scalar_t>();
auto out_data = out.DATA_PTR<scalar_t>();
std::vector<scalar_t> vals(K);
int64_t idx, next_idx, row_start;
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
if (!optional_out.has_value())
out.fill_(Reducer<scalar_t, REDUCE>::init());
for (auto b = 0; b < B; b++) {
auto offset = IndexToOffset<int64_t>::get(b * E, index_info);
idx = index_info.data[offset];
for (auto k = 0; k < K; k++)
vals[k] = out_data[b * N * K + k];
row_start = 0;
for (auto e = 0; e < E; e++) {
for (auto k = 0; k < K; k++)
Reducer<scalar_t, REDUCE>::update(
&vals[k], src_data[b * E * K + e * K + k], &args[k], e);
if (e == E - 1) {
for (auto k = 0; k < K; k++)
Reducer<scalar_t, REDUCE>::write(
out_data + b * N * K + idx * K + k, vals[k],
arg_out_data + b * N * K + idx * K + k, args[k],
e + 1 - row_start);
} else {
next_idx = index_info.data[offset + (e + 1) * stride];
assert(idx <= next_idx);
if (idx != next_idx) {
for (auto k = 0; k < K; k++) {
Reducer<scalar_t, REDUCE>::write(
out_data + b * N * K + idx * K + k, vals[k],
arg_out_data + b * N * K + idx * K + k, args[k],
e + 1 - row_start);
vals[k] = out_data[b * N * K + next_idx * K + k];
}
row_start = e + 1;
}
idx = next_idx;
}
}
}
if (!optional_out.has_value() && (REDUCE == MIN || REDUCE == MAX)) {
out.masked_fill_(out == Reducer<scalar_t, REDUCE>::init(), (scalar_t)0);
}
});
});
return std::make_tuple(out, arg_out);
}
torch::Tensor gather_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out) {
CHECK_CPU(src);
CHECK_CPU(index);
if (optional_out.has_value())
CHECK_CPU(optional_out.value());
CHECK_INPUT(src.dim() >= index.dim());
for (auto i = 0; i < index.dim() - 1; i++)
CHECK_INPUT(src.size(i) == index.size(i));
auto dim = index.dim() - 1;
src = src.contiguous();
torch::Tensor out;
if (optional_out.has_value()) {
out = optional_out.value().contiguous();
for (auto i = 0; i < src.dim(); i++)
if (i != dim)
CHECK_INPUT(src.size(i) == out.size(i));
} else {
auto sizes = src.sizes().vec();
sizes[dim] = index.size(dim);
out = torch::empty(sizes, src.options());
}
auto B = index.numel() / out.size(dim);
auto E = index.size(dim);
auto K = out.numel() / index.numel();
auto N = src.size(dim);
auto index_info = getTensorInfo<int64_t>(index);
auto stride = index_info.strides[index_info.dims - 1];
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_coo", [&] {
auto src_data = src.DATA_PTR<scalar_t>();
auto out_data = out.DATA_PTR<scalar_t>();
std::vector<scalar_t> vals(K);
int64_t idx, next_idx;
for (auto b = 0; b < B; b++) {
auto offset = IndexToOffset<int64_t>::get(b * E, index_info);
idx = index_info.data[offset];
for (auto k = 0; k < K; k++)
vals[k] = src_data[b * N * K + idx * K + k];
for (auto e = 0; e < E; e++) {
for (auto k = 0; k < K; k++)
out_data[b * E * K + e * K + k] = vals[k];
if (e < E - 1) {
next_idx = index_info.data[offset + (e + 1) * stride];
CHECK_INPUT(idx <= next_idx);
if (idx != next_idx) {
idx = next_idx;
for (auto k = 0; k < K; k++)
vals[k] = src_data[b * N * K + idx * K + k];
}
}
}
}
});
return out;
}
#include <torch/script.h>
#include "segment_csr_impl.h"
using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;
class SegmentSumCSR : public torch::autograd::Function<SegmentSumCSR> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable indptr,
torch::optional<Variable> optional_out) {
ctx->saved_data["src_shape"] = src.sizes();
auto result = segment_csr(src, indptr, optional_out, "sum");
auto out = std::get<0>(result);
ctx->save_for_backward({indptr});
return {out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto indptr = saved[0];
auto src_shape = ctx->saved_data["src_shape"].toIntVector();
auto grad_in = torch::empty(src_shape, grad_out.options());
gather_csr(grad_out, indptr, grad_in);
return {grad_in, Variable(), Variable()};
}
};
torch::Tensor segment_sum_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out) {
return SegmentSumCSR::apply(src, indptr, optional_out)[0];
}
static auto registry =
torch::RegisterOperators("torch_scatter_cpu::segment_csr", &segment_csr)
.op("torch_scatter_cpu::gather_csr", &gather_csr)
.op("torch_scatter_cpu::segment_sum_csr", &segment_sum_csr);
#pragma once
#include <torch/extension.h>
#include "compat.h"
#include "index_info.h"
#include "reducer.h"
#include "utils.h"
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out, std::string reduce) {
CHECK_CPU(src);
CHECK_CPU(indptr);
if (optional_out.has_value())
CHECK_CPU(optional_out.value());
CHECK_INPUT(src.dim() >= indptr.dim());
auto sizes = indptr.sizes().vec();
for (auto i = 0; i < indptr.dim() - 1; i++)
sizes[i] = src.size(i);
indptr = indptr.expand(sizes);
auto dim = indptr.dim() - 1;
src = src.contiguous();
torch::Tensor out;
if (optional_out.has_value()) {
out = optional_out.value().contiguous();
for (int i = 0; i < out.dim(); i++)
if (i != dim)
CHECK_INPUT(src.size(i) == out.size(i));
CHECK_INPUT(out.size(dim) == indptr.size(dim) - 1);
} else {
sizes = src.sizes().vec();
sizes[dim] = indptr.size(dim) - 1;
out = torch::empty(sizes, src.options());
}
torch::optional<torch::Tensor> arg_out = torch::nullopt;
int64_t *arg_out_data = nullptr;
if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
arg_out = torch::full(out.sizes(), src.size(dim), indptr.options());
arg_out_data = arg_out.value().DATA_PTR<int64_t>();
}
auto N = out.size(dim) * (indptr.numel() / indptr.size(-1));
auto K = out.numel() / N;
auto E = src.size(dim);
auto indptr_info = getTensorInfo<int64_t>(indptr);
auto stride = indptr_info.strides[indptr_info.dims - 1];
std::vector<int64_t> args(K);
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_csr", [&] {
auto src_data = src.DATA_PTR<scalar_t>();
auto out_data = out.DATA_PTR<scalar_t>();
std::vector<scalar_t> vals(K);
int64_t row_start, row_end;
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
for (auto n = 0; n < N; n++) {
auto offset = IndexPtrToOffset<int64_t>::get(n, indptr_info);
row_start = indptr_info.data[offset];
row_end = indptr_info.data[offset + stride];
offset = (n / (indptr.size(-1) - 1)) * E * K;
for (auto k = 0; k < K; k++)
vals[k] = Reducer<scalar_t, REDUCE>::init();
for (auto e = row_start; e < row_end; e++) {
CHECK_INPUT(e < E);
for (auto k = 0; k < K; k++)
Reducer<scalar_t, REDUCE>::update(
&vals[k], src_data[offset + e * K + k], &args[k], e);
}
for (auto k = 0; k < K; k++)
Reducer<scalar_t, REDUCE>::write(out_data + n * K + k, vals[k],
arg_out_data + n * K + k, args[k],
row_end - row_start);
}
});
});
return std::make_tuple(out, arg_out);
}
torch::Tensor gather_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out) {
CHECK_CPU(src);
CHECK_CPU(indptr);
if (optional_out.has_value())
CHECK_CPU(optional_out.value());
CHECK_INPUT(src.dim() >= indptr.dim());
for (auto i = 0; i < indptr.dim() - 1; i++)
CHECK_INPUT(src.size(i) == indptr.size(i));
auto dim = indptr.dim() - 1;
CHECK_INPUT(src.size(dim) == indptr.size(dim) - 1);
src = src.contiguous();
torch::Tensor out;
if (optional_out.has_value()) {
out = optional_out.value().contiguous();
for (auto i = 0; i < out.dim(); i++)
if (i != dim)
CHECK_INPUT(src.size(i) == out.size(i));
} else {
auto sizes = src.sizes().vec();
sizes[dim] = *indptr.flatten()[-1].DATA_PTR<int64_t>();
out = torch::empty(sizes, src.options());
}
auto N = src.size(dim) * (indptr.numel() / indptr.size(-1));
auto K = src.numel() / N;
auto E = out.size(dim);
auto indptr_info = getTensorInfo<int64_t>(indptr);
auto stride = indptr_info.strides[indptr_info.dims - 1];
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_csr", [&] {
auto src_data = src.DATA_PTR<scalar_t>();
auto out_data = out.DATA_PTR<scalar_t>();
std::vector<scalar_t> vals(K);
int64_t row_start, row_end;
for (int n = 0; n < N; n++) {
auto offset = IndexPtrToOffset<int64_t>::get(n, indptr_info);
row_start = indptr_info.data[offset];
row_end = indptr_info.data[offset + stride];
for (auto k = 0; k < K; k++)
vals[k] = src_data[n * K + k];
offset = (n / (indptr.size(-1) - 1)) * E * K;
for (auto e = row_start; e < row_end; e++)
for (auto k = 0; k < K; k++)
out_data[offset + e * K + k] = vals[k];
}
});
return out;
}
#pragma once
#include <torch/extension.h>
#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
#include "scatter_cuda.h"
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include "reducer.cuh"
#include "utils.cuh"
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
template <typename scalar_t, ReductionType REDUCE>
__global__ void
scatter_kernel(const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> index_info,
scalar_t *out_data, int E, int K, int N, int numel) {
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
int b = thread_idx / (E * K);
int k = thread_idx % K;
if (thread_idx < numel) {
int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
thread_idx, index_info);
int64_t idx = index_info.data[offset];
Reducer<scalar_t, REDUCE>::atomic_write(out_data + b * N * K + idx * K + k,
src_data[thread_idx]);
}
}
template <typename scalar_t>
__global__ void
scatter_arg_kernel(const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> index_info,
const scalar_t *out_data, int64_t *arg_out_data, int E,
int K, int N, int numel) {
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
int b = thread_idx / (E * K);
int e = (thread_idx / K) % E;
int k = thread_idx % K;
if (thread_idx < numel) {
int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
thread_idx, index_info);
int64_t idx = index_info.data[offset];
if (src_data[thread_idx] == out_data[b * N * K + idx * K + k]) {
arg_out_data[b * N * K + idx * K + k] = e;
}
}
}
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size, std::string reduce) {
return std::make_tuple(src, optional_out);
CHECK_CUDA(src);
CHECK_CUDA(index);
if (optional_out.has_value())
CHECK_CUDA(optional_out.value());
cudaSetDevice(src.get_device());
CHECK_INPUT(src.dim() == index.dim());
for (auto i = 0; i < index.dim() - 1; i++)
CHECK_INPUT(src.size(i) >= index.size(i));
if (dim < 0)
dim = src.dim() + dim;
src = src.contiguous();
torch::Tensor out;
if (optional_out.has_value()) {
out = optional_out.value().contiguous();
for (auto i = 0; i < out.dim(); i++)
if (i != dim)
CHECK_INPUT(src.size(i) == out.size(i));
} else {
auto sizes = src.sizes().vec();
if (dim_size.has_value())
sizes[dim] = dim_size.value();
else {
auto d_size = index.max().data_ptr<int64_t>();
auto h_size = (int64_t *)malloc(sizeof(int64_t));
cudaMemcpy(h_size, d_size, sizeof(int64_t), cudaMemcpyDeviceToHost);
sizes[dim] = 1 + *h_size;
}
out = torch::empty(sizes, src.options());
}
torch::optional<torch::Tensor> arg_out = torch::nullopt;
int64_t *arg_out_data = nullptr;
if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
arg_out = torch::full_like(out, src.size(dim), index.options());
arg_out_data = arg_out.value().data_ptr<int64_t>();
}
auto B = 1;
for (auto i = 0; i < dim; i++)
B *= src.size(i);
auto E = src.size(dim);
auto K = src.numel() / (B * E);
auto N = out.size(dim);
auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter", [&] {
auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
if (!optional_out.has_value())
out.fill_(Reducer<scalar_t, REDUCE>::init());
scatter_kernel<scalar_t, REDUCE>
<<<BLOCKS(src.numel()), THREADS, 0, stream>>>(
src_data, index_info, out_data, E, K, N, src.numel());
if (!optional_out.has_value() && (REDUCE == MIN || REDUCE == MAX))
out.masked_fill_(out == Reducer<scalar_t, REDUCE>::init(), (scalar_t)0);
if (REDUCE == MIN || REDUCE == MAX)
scatter_arg_kernel<scalar_t>
<<<BLOCKS(src.numel()), THREADS, 0, stream>>>(
src_data, index_info, out_data, arg_out_data, E, K, N,
src.numel());
});
});
return std::make_tuple(out, arg_out);
}
#pragma once
#define ATOMIC(NAME) \
template <typename scalar, size_t size> struct Atomic##NAME##IntegerImpl; \
\
template <typename scalar> struct Atomic##NAME##IntegerImpl<scalar, 1> { \
inline __device__ void operator()(scalar *address, scalar val) { \
uint32_t *address_as_ui = (uint32_t *)(address - ((size_t)address & 3)); \
uint32_t old = *address_as_ui; \
uint32_t shift = ((size_t)address & 3) * 8; \
uint32_t sum; \
uint32_t assumed; \
\
do { \
assumed = old; \
sum = OP(val, scalar((old >> shift) & 0xff)); \
old = (old & ~(0x000000ff << shift)) | (sum << shift); \
old = atomicCAS(address_as_ui, assumed, old); \
} while (assumed != old); \
} \
}; \
\
template <typename scalar> struct Atomic##NAME##IntegerImpl<scalar, 2> { \
inline __device__ void operator()(scalar *address, scalar val) { \
uint32_t *address_as_ui = \
(uint32_t *)((char *)address - ((size_t)address & 2)); \
uint32_t old = *address_as_ui; \
uint32_t sum; \
uint32_t newval; \
uint32_t assumed; \
\
do { \
assumed = old; \
sum = OP(val, (size_t)address & 2 ? scalar(old >> 16) \
: scalar(old & 0xffff)); \
newval = (size_t)address & 2 ? (old & 0xffff) | (sum << 16) \
: (old & 0xffff0000) | sum; \
old = atomicCAS(address_as_ui, assumed, newval); \
} while (assumed != old); \
} \
}; \
\
template <typename scalar> struct Atomic##NAME##IntegerImpl<scalar, 4> { \
inline __device__ void operator()(scalar *address, scalar val) { \
uint32_t *address_as_ui = (uint32_t *)address; \
uint32_t old = *address_as_ui; \
uint32_t assumed; \
\
do { \
assumed = old; \
old = atomicCAS(address_as_ui, assumed, OP(val, (scalar)old)); \
} while (assumed != old); \
} \
}; \
\
template <typename scalar> struct Atomic##NAME##IntegerImpl<scalar, 8> { \
inline __device__ void operator()(scalar *address, scalar val) { \
unsigned long long *address_as_ull = (unsigned long long *)address; \
unsigned long long old = *address_as_ull; \
unsigned long long assumed; \
\
do { \
assumed = old; \
old = atomicCAS(address_as_ull, assumed, OP(val, (scalar)old)); \
} while (assumed != old); \
} \
}; \
\
template <typename scalar, size_t size> struct Atomic##NAME##DecimalImpl; \
\
template <typename scalar> struct Atomic##NAME##DecimalImpl<scalar, 4> { \
inline __device__ void operator()(scalar *address, scalar val) { \
int *address_as_i = (int *)address; \
int old = *address_as_i; \
int assumed; \
\
do { \
assumed = old; \
old = atomicCAS(address_as_i, assumed, \
__float_as_int(OP(val, __int_as_float(assumed)))); \
} while (assumed != old); \
} \
}; \
\
template <typename scalar> struct Atomic##NAME##DecimalImpl<scalar, 8> { \
inline __device__ void operator()(scalar *address, scalar val) { \
unsigned long long int *address_as_ull = \
(unsigned long long int *)address; \
unsigned long long int old = *address_as_ull; \
unsigned long long int assumed; \
\
do { \
assumed = old; \
old = atomicCAS( \
address_as_ull, assumed, \
__double_as_longlong(OP(val, __longlong_as_double(assumed)))); \
} while (assumed != old); \
} \
};
#define OP(X, Y) Y + X
ATOMIC(Add)
#undef OP
static inline __device__ void atomAdd(uint8_t *address, uint8_t val) {
AtomicAddIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val);
}
static inline __device__ void atomAdd(int8_t *address, int8_t val) {
AtomicAddIntegerImpl<int8_t, sizeof(int8_t)>()(address, val);
}
static inline __device__ void atomAdd(int16_t *address, int16_t val) {
AtomicAddIntegerImpl<int16_t, sizeof(int16_t)>()(address, val);
}
static inline __device__ void atomAdd(int32_t *address, int32_t val) {
atomicAdd(address, val);
}
static inline __device__ void atomAdd(int64_t *address, int64_t val) {
AtomicAddIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
}
static inline __device__ void atomAdd(float *address, float val) {
atomicAdd(address, val);
}
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000)
static inline __device__ void atomAdd(double *address, double val) {
AtomicAddDecimalImpl<double, sizeof(double)>()(address, val);
}
#else
static inline __device__ void atomAdd(double *address, double val) {
atomicAdd(address, val);
}
#endif
#define OP(X, Y) Y *X
ATOMIC(Mul)
#undef OP
static inline __device__ void atomMul(uint8_t *address, uint8_t val) {
AtomicMulIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val);
}
static inline __device__ void atomMul(int8_t *address, int8_t val) {
AtomicMulIntegerImpl<int8_t, sizeof(int8_t)>()(address, val);
}
static inline __device__ void atomMul(int16_t *address, int16_t val) {
AtomicMulIntegerImpl<int16_t, sizeof(int16_t)>()(address, val);
}
static inline __device__ void atomMul(int32_t *address, int32_t val) {
AtomicMulIntegerImpl<int32_t, sizeof(int32_t)>()(address, val);
}
static inline __device__ void atomMul(int64_t *address, int64_t val) {
AtomicMulIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
}
static inline __device__ void atomMul(float *address, float val) {
AtomicMulDecimalImpl<float, sizeof(float)>()(address, val);
}
static inline __device__ void atomMul(double *address, double val) {
AtomicMulDecimalImpl<double, sizeof(double)>()(address, val);
}
#define OP(X, Y) Y / X
ATOMIC(Div)
#undef OP
static inline __device__ void atomDiv(uint8_t *address, uint8_t val) {
AtomicDivIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val);
}
static inline __device__ void atomDiv(int8_t *address, int8_t val) {
AtomicDivIntegerImpl<int8_t, sizeof(int8_t)>()(address, val);
}
static inline __device__ void atomDiv(int16_t *address, int16_t val) {
AtomicDivIntegerImpl<int16_t, sizeof(int16_t)>()(address, val);
}
static inline __device__ void atomDiv(int32_t *address, int32_t val) {
AtomicDivIntegerImpl<int32_t, sizeof(int32_t)>()(address, val);
}
static inline __device__ void atomDiv(int64_t *address, int64_t val) {
AtomicDivIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
}
static inline __device__ void atomDiv(float *address, float val) {
AtomicDivDecimalImpl<float, sizeof(float)>()(address, val);
}
static inline __device__ void atomDiv(double *address, double val) {
AtomicDivDecimalImpl<double, sizeof(double)>()(address, val);
}
#define OP(X, Y) max(Y, X)
ATOMIC(Max)
#undef OP
static inline __device__ void atomMax(uint8_t *address, uint8_t val) {
AtomicMaxIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val);
}
static inline __device__ void atomMax(int8_t *address, int8_t val) {
AtomicMaxIntegerImpl<int8_t, sizeof(int8_t)>()(address, val);
}
static inline __device__ void atomMax(int16_t *address, int16_t val) {
AtomicMaxIntegerImpl<int16_t, sizeof(int16_t)>()(address, val);
}
static inline __device__ void atomMax(int32_t *address, int32_t val) {
atomicMax(address, val);
}
static inline __device__ void atomMax(int64_t *address, int64_t val) {
AtomicMaxIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
}
static inline __device__ void atomMax(float *address, float val) {
AtomicMaxDecimalImpl<float, sizeof(float)>()(address, val);
}
static inline __device__ void atomMax(double *address, double val) {
AtomicMaxDecimalImpl<double, sizeof(double)>()(address, val);
}
#define OP(X, Y) min(Y, X)
ATOMIC(Min)
#undef OP
static inline __device__ void atomMin(uint8_t *address, uint8_t val) {
AtomicMinIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val);
}
static inline __device__ void atomMin(int8_t *address, int8_t val) {
AtomicMinIntegerImpl<int8_t, sizeof(int8_t)>()(address, val);
}
static inline __device__ void atomMin(int16_t *address, int16_t val) {
AtomicMinIntegerImpl<int16_t, sizeof(int16_t)>()(address, val);
}
static inline __device__ void atomMin(int32_t *address, int32_t val) {
atomicMin(address, val);
}
static inline __device__ void atomMin(int64_t *address, int64_t val) {
AtomicMinIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
}
static inline __device__ void atomMin(float *address, float val) {
AtomicMinDecimalImpl<float, sizeof(float)>()(address, val);
}
static inline __device__ void atomMin(double *address, double val) {
AtomicMinDecimalImpl<double, sizeof(double)>()(address, val);
}
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
#include <torch/script.h>
#define CHECK_CUDA(x) \
AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> out_opt);
torch::Tensor gather_coo_cuda(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> out_opt);
torch::Tensor gather_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> out_opt) {
CHECK_CUDA(src);
CHECK_CUDA(indptr);
if (out_opt.has_value())
CHECK_CUDA(out_opt.value());
return gather_csr_cuda(src, indptr, out_opt);
}
torch::Tensor gather_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> out_opt) {
CHECK_CUDA(src);
CHECK_CUDA(index);
if (out_opt.has_value())
CHECK_CUDA(out_opt.value());
return gather_coo_cuda(src, index, out_opt);
}
static auto registry =
torch::RegisterOperators("torch_scatter_cuda::gather_csr", &gather_csr)
.op("torch_scatter_cuda::gather_coo", &gather_coo);
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <torch/extension.h>
#include "compat.cuh"
#include "indptr.cuh"
#define THREADS 256
#define BLOCKS(TB, N) (TB * N + THREADS - 1) / THREADS
template <typename scalar_t, int TB>
__global__ void
gather_csr_kernel(const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
scalar_t *out_data, size_t N, size_t E) {
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
int row_idx = thread_idx / TB;
int lane_idx = thread_idx % TB;
if (row_idx < N) {
int offset = IndexPtrToOffset<int64_t>::get(row_idx, indptr_info);
int row_start = __ldg(indptr_info.data + offset);
int row_end = __ldg(indptr_info.data + offset +
indptr_info.strides[indptr_info.dims - 1]);
scalar_t val = __ldg(src_data + row_idx);
offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E;
for (int out_idx = row_start + lane_idx; out_idx < row_end; out_idx += TB) {
out_data[offset + out_idx] = val; // "Mostly" coalesced.
}
}
}
template <typename scalar_t>
__global__ void gather_csr_broadcast_kernel(
const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
scalar_t *out_data, size_t N, size_t K, size_t E) {
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
int row_idx = thread_idx / K;
int lane_idx = thread_idx % K;
if (thread_idx < N * K) {
int offset = IndexPtrToOffset<int64_t>::get(row_idx, indptr_info);
int row_start = __ldg(indptr_info.data + offset);
int row_end = __ldg(indptr_info.data + offset +
indptr_info.strides[indptr_info.dims - 1]);
scalar_t val = src_data[thread_idx]; // Coalesced.
offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E * K;
for (int out_idx = row_start; out_idx < row_end; out_idx++) {
out_data[offset + K * out_idx + lane_idx] = val; // "Mostly" coalesced.
}
}
}
torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> out_opt) {
cudaSetDevice(src.get_device());
AT_ASSERTM(src.dim() >= indptr.dim(), "Input mismatch");
for (int i = 0; i < indptr.dim() - 1; i++)
AT_ASSERTM(src.size(i) == indptr.size(i), "Input mismatch");
src = src.contiguous();
auto gather_dim = indptr.dim() - 1;
AT_ASSERTM(src.size(gather_dim) == indptr.size(gather_dim) - 1,
"Input mismatch");
torch::Tensor out;
if (out_opt.has_value()) {
out = out_opt.value().contiguous();
for (int i = 0; i < out.dim(); i++)
if (i != gather_dim)
AT_ASSERTM(src.size(i) == out.size(i), "Input mismatch");
} else {
auto d_gather_size = indptr.flatten()[-1].DATA_PTR<int64_t>();
auto h_gather_size = (int64_t *)malloc(sizeof(int64_t));
cudaMemcpy(h_gather_size, d_gather_size, sizeof(int64_t),
cudaMemcpyDeviceToHost);
auto sizes = src.sizes().vec();
sizes[gather_dim] = *h_gather_size;
out = at::empty(sizes, src.options());
}
auto N = src.size(gather_dim) * (indptr.numel() / indptr.size(-1));
auto K = src.numel() / N;
auto E = out.size(gather_dim);
auto indptr_info = at::cuda::detail::getTensorInfo<int64_t, int>(indptr);
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_csr_kernel", [&] {
auto src_data = src.DATA_PTR<scalar_t>();
auto out_data = out.DATA_PTR<scalar_t>();
if (K == 1) {
gather_csr_kernel<scalar_t, 4><<<BLOCKS(1, 4 * N), THREADS, 0, stream>>>(
src_data, indptr_info, out_data, N, E);
} else {
gather_csr_broadcast_kernel<scalar_t>
<<<BLOCKS(1, N * K), THREADS, 0, stream>>>(src_data, indptr_info,
out_data, N, K, E);
}
});
return out;
}
template <typename scalar_t>
__global__ void
gather_coo_kernel(const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> index_info,
scalar_t *out_data, size_t E, size_t N) {
int row_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (row_idx < E) {
int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
row_idx, index_info);
int row = index_info.data[offset];
offset = (row_idx / index_info.sizes[index_info.dims - 1]) * N;
scalar_t val = __ldg(src_data + offset + row);
out_data[row_idx] = val;
}
}
template <typename scalar_t>
__global__ void gather_coo_broadcast_kernel(
const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> index_info,
scalar_t *out_data, size_t E, size_t K, size_t N) {
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
int row_idx = thread_idx / K;
int col_idx = thread_idx % K;
if (thread_idx < E * K) {
int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
row_idx, index_info);
int row = index_info.data[offset];
offset = (row_idx / index_info.sizes[index_info.dims - 1]) * N * K;
scalar_t val = __ldg(src_data + offset + K * row + col_idx);
out_data[thread_idx] = val;
}
}
torch::Tensor gather_coo_cuda(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> out_opt) {
cudaSetDevice(src.get_device());
AT_ASSERTM(src.dim() >= index.dim(), "Input mismatch");
for (int i = 0; i < index.dim() - 1; i++)
AT_ASSERTM(src.size(i) == index.size(i), "Input mismatch");
src = src.contiguous();
auto gather_dim = index.dim() - 1;
torch::Tensor out;
if (out_opt.has_value()) {
out = out_opt.value().contiguous();
for (int i = 0; i < index.dim(); i++)
AT_ASSERTM(out.size(i) == index.size(i), "Input mismatch");
for (int i = index.dim() + 1; i < src.dim(); i++)
AT_ASSERTM(out.size(i) == src.size(i), "Input mismatch");
} else {
auto sizes = src.sizes().vec();
sizes[gather_dim] = index.size(gather_dim);
out = torch::empty(sizes, src.options());
}
auto E = index.numel();
auto K = out.numel() / E;
auto N = src.size(gather_dim);
auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_coo_kernel", [&] {
auto src_data = src.DATA_PTR<scalar_t>();
auto out_data = out.DATA_PTR<scalar_t>();
if (K == 1) {
gather_coo_kernel<scalar_t><<<BLOCKS(1, E), THREADS, 0, stream>>>(
src_data, index_info, out_data, E, N);
} else {
gather_coo_broadcast_kernel<scalar_t>
<<<BLOCKS(1, E * K), THREADS, 0, stream>>>(src_data, index_info,
out_data, E, K, N);
}
});
return out;
}
#pragma once
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <torch/extension.h>
template <typename scalar1, typename scalar2, int64_t Dims>
struct IndexToScatterOffsets3 {
static __device__ void
compute(int64_t i, const int64_t dim,
const at::cuda::detail::TensorInfo<int64_t, int64_t> &index,
int64_t *indexOffset,
const at::cuda::detail::TensorInfo<scalar1, int64_t> &t1,
int64_t *t1Offset,
const at::cuda::detail::TensorInfo<scalar2, int64_t> &t2,
int64_t *t2Offset) {
for (int64_t d = Dims - 1; d >= 0; d--) {
int64_t curDimIndex = i % index.sizes[d];
*indexOffset += curDimIndex * index.strides[d];
*t1Offset += curDimIndex * t1.strides[d];
if (d != dim) {
*t2Offset += curDimIndex * t2.strides[d];
}
i /= index.sizes[d];
}
int64_t indexValue = index.data[*indexOffset];
*t2Offset += indexValue * t2.strides[dim];
}
};
template <typename scalar1, typename scalar2>
struct IndexToScatterOffsets3<scalar1, scalar2, -1> {
static __device__ void
compute(int64_t i, const int64_t dim,
const at::cuda::detail::TensorInfo<int64_t, int64_t> &index,
int64_t *indexOffset,
const at::cuda::detail::TensorInfo<scalar1, int64_t> &t1,
int64_t *t1Offset,
const at::cuda::detail::TensorInfo<scalar2, int64_t> &t2,
int64_t *t2Offset) {
for (int64_t d = index.dims - 1; d >= 0; d--) {
int64_t curDimIndex = i % index.sizes[d];
*indexOffset += curDimIndex * index.strides[d];
*t1Offset += curDimIndex * t1.strides[d];
if (d != dim) {
*t2Offset += curDimIndex * t2.strides[d];
}
i /= index.sizes[d];
}
int64_t indexValue = index.data[*indexOffset];
*t2Offset += indexValue * t2.strides[dim];
}
};
template <typename scalar1, typename scalar2, typename scalar3, int64_t Dims>
struct IndexToScatterOffsets4 {
static __device__ void
compute(int64_t i, const int64_t dim,
const at::cuda::detail::TensorInfo<int64_t, int64_t> &index,
int64_t *indexOffset,
const at::cuda::detail::TensorInfo<scalar1, int64_t> &t1,
int64_t *t1Offset,
const at::cuda::detail::TensorInfo<scalar2, int64_t> &t2,
int64_t *t2Offset,
const at::cuda::detail::TensorInfo<scalar3, int64_t> &t3,
int64_t *t3Offset) {
for (int64_t d = Dims - 1; d >= 0; d--) {
int64_t curDimIndex = i % index.sizes[d];
*indexOffset += curDimIndex * index.strides[d];
*t1Offset += curDimIndex * t1.strides[d];
if (d != dim) {
*t2Offset += curDimIndex * t2.strides[d];
*t3Offset += curDimIndex * t3.strides[d];
}
i /= index.sizes[d];
}
int64_t indexValue = index.data[*indexOffset];
*t2Offset += indexValue * t2.strides[dim];
*t3Offset += indexValue * t3.strides[dim];
}
};
template <typename scalar1, typename scalar2, typename scalar3>
struct IndexToScatterOffsets4<scalar1, scalar2, scalar3, -1> {
static __device__ void
compute(int64_t i, const int64_t dim,
const at::cuda::detail::TensorInfo<int64_t, int64_t> &index,
int64_t *indexOffset,
const at::cuda::detail::TensorInfo<scalar1, int64_t> &t1,
int64_t *t1Offset,
const at::cuda::detail::TensorInfo<scalar2, int64_t> &t2,
int64_t *t2Offset,
const at::cuda::detail::TensorInfo<scalar3, int64_t> &t3,
int64_t *t3Offset) {
for (int64_t d = index.dims - 1; d >= 0; d--) {
int64_t curDimIndex = i % index.sizes[d];
*indexOffset += curDimIndex * index.strides[d];
*t1Offset += curDimIndex * t1.strides[d];
if (d != dim) {
*t2Offset += curDimIndex * t2.strides[d];
*t3Offset += curDimIndex * t3.strides[d];
}
i /= index.sizes[d];
}
int64_t indexValue = index.data[*indexOffset];
*t2Offset += indexValue * t2.strides[dim];
*t3Offset += indexValue * t3.strides[dim];
}
};
#pragma once
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <torch/extension.h>
// We need our own `IndexToOffset` implementation since we do not want to
// access the last element of the `indexptr`.
template <typename scalar_t> struct IndexPtrToOffset {
static inline __host__ __device__ int
get(int idx, const at::cuda::detail::TensorInfo<scalar_t, int> &info) {
int offset = idx % (info.sizes[info.dims - 1] - 1);
offset *= info.strides[info.dims - 1];
idx /= info.sizes[info.dims - 1] - 1;
for (int i = info.dims - 2; i >= 0; --i) {
offset += (idx % info.sizes[i]) * info.strides[i];
idx /= info.sizes[i];
}
return offset;
}
};
#include <torch/script.h>
#define CHECK_CUDA(x) \
AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
void scatter_mul_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out,
int64_t dim);
void scatter_div_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out,
int64_t dim);
void scatter_max_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out,
torch::Tensor arg, int64_t dim);
void scatter_min_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out,
torch::Tensor arg, int64_t dim);
void index_backward_cuda(torch::Tensor grad, torch::Tensor index,
torch::Tensor arg, torch::Tensor out, int64_t dim);
void scatter_mul(torch::Tensor src, torch::Tensor index, torch::Tensor out,
int64_t dim) {
CHECK_CUDA(src);
CHECK_CUDA(index);
CHECK_CUDA(out);
scatter_mul_cuda(src, index, out, dim);
}
void scatter_div(torch::Tensor src, torch::Tensor index, torch::Tensor out,
int64_t dim) {
CHECK_CUDA(src);
CHECK_CUDA(index);
CHECK_CUDA(out);
scatter_div_cuda(src, index, out, dim);
}
void scatter_max(torch::Tensor src, torch::Tensor index, torch::Tensor out,
torch::Tensor arg, int64_t dim) {
CHECK_CUDA(src);
CHECK_CUDA(index);
CHECK_CUDA(out);
CHECK_CUDA(arg);
scatter_max_cuda(src, index, out, arg, dim);
}
void scatter_min(torch::Tensor src, torch::Tensor index, torch::Tensor out,
torch::Tensor arg, int64_t dim) {
CHECK_CUDA(src);
CHECK_CUDA(index);
CHECK_CUDA(out);
CHECK_CUDA(arg);
scatter_min_cuda(src, index, out, arg, dim);
}
static auto registry =
torch::RegisterOperators("torch_scatter_cuda::scatter_mul", &scatter_mul)
.op("torch_scatter_cuda::scatter_div", &scatter_div)
.op("torch_scatter_cuda::scatter_max", &scatter_max)
.op("torch_scatter_cuda::scatter_min", &scatter_min);
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <torch/extension.h>
#include "atomics.cuh"
#include "index.cuh"
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
#define KERNEL_RUN(NAME, DIMS, N, ...) \
[&] { \
auto stream = at::cuda::getCurrentCUDAStream(); \
switch (DIMS) { \
case 1: \
NAME<scalar_t, 1><<<BLOCKS(N), THREADS, 0, stream>>>(__VA_ARGS__, N); \
break; \
case 2: \
NAME<scalar_t, 2><<<BLOCKS(N), THREADS, 0, stream>>>(__VA_ARGS__, N); \
break; \
case 3: \
NAME<scalar_t, 3><<<BLOCKS(N), THREADS, 0, stream>>>(__VA_ARGS__, N); \
break; \
default: \
NAME<scalar_t, -1><<<BLOCKS(N), THREADS, 0, stream>>>(__VA_ARGS__, N); \
} \
}()
template <typename scalar_t, int64_t Dims>
__global__ void
scatter_mul_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
at::cuda::detail::TensorInfo<int64_t, int64_t> index,
at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
int64_t dim, size_t numel) {
const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = blockDim.x * gridDim.x;
for (ptrdiff_t i = idx; i < numel; i += stride) {
int64_t srcOffset = 0, indexOffset = 0, outOffset = 0;
IndexToScatterOffsets3<scalar_t, scalar_t, Dims>::compute(
i, dim, index, &indexOffset, src, &srcOffset, out, &outOffset);
atomMul(&out.data[outOffset], src.data[srcOffset]);
}
}
void scatter_mul_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out,
int64_t dim) {
cudaSetDevice(src.get_device());
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_mul_kernel", [&] {
KERNEL_RUN(scatter_mul_kernel, index.dim(), index.numel(),
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src),
at::cuda::detail::getTensorInfo<int64_t, int64_t>(index),
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(out), dim);
});
}
template <typename scalar_t, int64_t Dims>
__global__ void
scatter_div_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
at::cuda::detail::TensorInfo<int64_t, int64_t> index,
at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
int64_t dim, size_t numel) {
const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = blockDim.x * gridDim.x;
for (ptrdiff_t i = idx; i < numel; i += stride) {
int64_t srcOffset = 0, indexOffset = 0, outOffset = 0;
IndexToScatterOffsets3<scalar_t, scalar_t, Dims>::compute(
i, dim, index, &indexOffset, src, &srcOffset, out, &outOffset);
atomDiv(&out.data[outOffset], src.data[srcOffset]);
}
}
void scatter_div_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out,
int64_t dim) {
cudaSetDevice(src.get_device());
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_div_kernel", [&] {
KERNEL_RUN(scatter_div_kernel, index.dim(), index.numel(),
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src),
at::cuda::detail::getTensorInfo<int64_t, int64_t>(index),
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(out), dim);
});
}
template <typename scalar_t, int64_t Dims>
__global__ void arg_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
at::cuda::detail::TensorInfo<int64_t, int64_t> index,
at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
at::cuda::detail::TensorInfo<int64_t, int64_t> arg,
int64_t dim, size_t numel) {
const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = blockDim.x * gridDim.x;
for (ptrdiff_t i = idx; i < numel; i += stride) {
int64_t srcOffset = 0, indexOffset = 0, outOffset = 0, argOffset = 0;
IndexToScatterOffsets4<scalar_t, scalar_t, int64_t, Dims>::compute(
i, dim, index, &indexOffset, src, &srcOffset, out, &outOffset, arg,
&argOffset);
if (src.data[srcOffset] == out.data[outOffset]) {
arg.data[argOffset] = (srcOffset / src.strides[dim]) % src.sizes[dim];
}
}
}
template <typename scalar_t, int64_t Dims>
__global__ void
scatter_max_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
at::cuda::detail::TensorInfo<int64_t, int64_t> index,
at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
int64_t dim, size_t numel) {
const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = blockDim.x * gridDim.x;
for (ptrdiff_t i = idx; i < numel; i += stride) {
int64_t srcOffset = 0, indexOffset = 0, outOffset = 0;
IndexToScatterOffsets3<scalar_t, scalar_t, Dims>::compute(
i, dim, index, &indexOffset, src, &srcOffset, out, &outOffset);
atomMax(&out.data[outOffset], src.data[srcOffset]);
}
}
void scatter_max_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out,
torch::Tensor arg, int64_t dim) {
cudaSetDevice(src.get_device());
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_max_kernel", [&] {
auto src_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src);
auto index_info = at::cuda::detail::getTensorInfo<int64_t, int64_t>(index);
auto out_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(out);
KERNEL_RUN(scatter_max_kernel, index.dim(), index.numel(), src_info,
index_info, out_info, dim);
KERNEL_RUN(arg_kernel, index.dim(), index.numel(), src_info, index_info,
out_info, at::cuda::detail::getTensorInfo<int64_t, int64_t>(arg),
dim);
});
}
template <typename scalar_t, int64_t Dims>
__global__ void
scatter_min_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
at::cuda::detail::TensorInfo<int64_t, int64_t> index,
at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
int64_t dim, size_t numel) {
const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = blockDim.x * gridDim.x;
for (ptrdiff_t i = idx; i < numel; i += stride) {
int64_t srcOffset = 0, indexOffset = 0, outOffset = 0;
IndexToScatterOffsets3<scalar_t, scalar_t, Dims>::compute(
i, dim, index, &indexOffset, src, &srcOffset, out, &outOffset);
atomMin(&out.data[outOffset], src.data[srcOffset]);
}
}
void scatter_min_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out,
torch::Tensor arg, int64_t dim) {
cudaSetDevice(src.get_device());
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_min_kernel", [&] {
auto src_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src);
auto index_info = at::cuda::detail::getTensorInfo<int64_t, int64_t>(index);
auto out_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(out);
KERNEL_RUN(scatter_min_kernel, index.dim(), index.numel(), src_info,
index_info, out_info, dim);
KERNEL_RUN(arg_kernel, index.dim(), index.numel(), src_info, index_info,
out_info, at::cuda::detail::getTensorInfo<int64_t, int64_t>(arg),
dim);
});
}
#include <torch/script.h>
#define CHECK_CUDA(x) \
AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_csr_cuda(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> out_opt, std::string reduce);
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_coo_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out,
std::string reduce);
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> out_opt, std::string reduce) {
CHECK_CUDA(src);
CHECK_CUDA(indptr);
if (out_opt.has_value())
CHECK_CUDA(out_opt.value());
return segment_csr_cuda(src, indptr, out_opt, reduce);
}
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_coo(torch::Tensor src, torch::Tensor index, torch::Tensor out,
std::string reduce) {
CHECK_CUDA(src);
CHECK_CUDA(index);
CHECK_CUDA(out);
return segment_coo_cuda(src, index, out, reduce);
}
static auto registry =
torch::RegisterOperators("torch_scatter_cuda::segment_csr", &segment_csr)
.op("torch_scatter_cuda::segment_coo", &segment_coo);
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment