Unverified commit c01f9bae authored by Matthias Fey, committed by GitHub

Merge pull request #105 from rusty1s/traceable

[WIP] traceable functions
parents 2520670a 02a47c46
#pragma once

#include <ATen/cuda/detail/TensorInfo.cuh>
#include <torch/extension.h>

// We need our own `IndexToOffset` implementation since we do not want to
// access the last element of the `indexptr`.
......
#pragma once
#include <limits>
#include <map>
#include "atomics.cuh"
enum ReductionType { SUM, MEAN, MUL, DIV, MIN, MAX };
const std::map<std::string, ReductionType> reduce2REDUCE = {
{"sum", SUM}, {"mean", MEAN}, {"mul", MUL},
{"div", DIV}, {"min", MIN}, {"max", MAX},
};
#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...) \
[&] { \
switch (reduce2REDUCE.at(reduce)) { \
case SUM: { \
const ReductionType REDUCE = SUM; \
return __VA_ARGS__(); \
} \
case MEAN: { \
const ReductionType REDUCE = MEAN; \
return __VA_ARGS__(); \
} \
case MUL: { \
const ReductionType REDUCE = MUL; \
return __VA_ARGS__(); \
} \
case DIV: { \
const ReductionType REDUCE = DIV; \
return __VA_ARGS__(); \
} \
case MIN: { \
const ReductionType REDUCE = MIN; \
return __VA_ARGS__(); \
} \
case MAX: { \
const ReductionType REDUCE = MAX; \
return __VA_ARGS__(); \
} \
} \
}()
template <typename scalar_t, ReductionType REDUCE> struct Reducer {
static inline __host__ __device__ scalar_t init() {
if (REDUCE == MUL || REDUCE == DIV)
return (scalar_t)1;
else if (REDUCE == MIN)
return std::numeric_limits<scalar_t>::max();
else if (REDUCE == MAX)
return std::numeric_limits<scalar_t>::lowest();
else
return (scalar_t)0;
}
static inline __host__ __device__ void update(scalar_t *val,
scalar_t new_val) {
if (REDUCE == SUM || REDUCE == MEAN)
*val = *val + new_val;
else if (REDUCE == MUL)
*val = *val * new_val;
else if (REDUCE == DIV)
*val = *val / new_val;
else if ((REDUCE == MIN && new_val < *val) ||
(REDUCE == MAX && new_val > *val)) {
*val = new_val;
}
}
static inline __host__ __device__ void update(scalar_t *val, scalar_t new_val,
int64_t *arg, int64_t new_arg) {
if (REDUCE == SUM || REDUCE == MEAN)
*val = *val + new_val;
else if (REDUCE == MUL)
*val = *val * new_val;
else if (REDUCE == DIV)
*val = *val / new_val;
else if ((REDUCE == MIN && new_val < *val) ||
(REDUCE == MAX && new_val > *val)) {
*val = new_val;
*arg = new_arg;
}
}
static inline __host__ __device__ void write(scalar_t *address, scalar_t val,
int64_t *arg_address,
int64_t arg, int count) {
if (REDUCE == SUM || REDUCE == MUL || REDUCE == DIV)
*address = val;
else if (REDUCE == MEAN)
*address = val / (count > 0 ? count : (scalar_t)1);
else if (REDUCE == MIN || REDUCE == MAX) {
if (count > 0) {
*address = val;
*arg_address = arg;
} else
*address = (scalar_t)0;
}
}
static inline __device__ void atomic_write(scalar_t *address, scalar_t val) {
if (REDUCE == SUM || REDUCE == MEAN)
atomAdd(address, val);
else if (REDUCE == MUL)
atomMul(address, val);
else if (REDUCE == DIV)
atomDiv(address, val);
else if (REDUCE == MIN)
atomMin(address, val);
else if (REDUCE == MAX)
atomMax(address, val);
}
};
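For reference, a minimal host-only sketch (not part of this patch) of the contract that `Reducer` and `AT_DISPATCH_REDUCTION_TYPES` implement together: a runtime `reduce` string selects a reduction, every reduction starts from `init()`, folds values in with `update()`, and finalizes with `write()` using the element count. The helper name `plain_reduce` is invented for this illustration.

// Hypothetical host-only mirror of the Reducer contract for sum/mean/min/max;
// illustration only, the real struct lives in reducer.cuh.
#include <cassert>
#include <limits>
#include <string>
#include <vector>

float plain_reduce(const std::vector<float> &vals, const std::string &reduce) {
  float acc;
  if (reduce == "min")
    acc = std::numeric_limits<float>::max();    // Reducer::init() for MIN
  else if (reduce == "max")
    acc = std::numeric_limits<float>::lowest(); // Reducer::init() for MAX
  else
    acc = 0.0f;                                 // SUM and MEAN start at zero

  for (float v : vals) {                        // Reducer::update()
    if (reduce == "min")
      acc = v < acc ? v : acc;
    else if (reduce == "max")
      acc = v > acc ? v : acc;
    else
      acc += v;
  }

  int count = (int)vals.size();                 // Reducer::write()
  if (reduce == "mean")
    acc /= count > 0 ? count : 1;
  if ((reduce == "min" || reduce == "max") && count == 0)
    acc = 0.0f;                                 // empty segments become 0
  return acc;
}

int main() {
  assert(plain_reduce({1.0f, 3.0f, 2.0f}, "max") == 3.0f);
  assert(plain_reduce({1.0f, 3.0f, 2.0f}, "mean") == 2.0f);
  return 0;
}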
#include "scatter_cuda.h"
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include "reducer.cuh"
#include "utils.cuh"
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
template <typename scalar_t, ReductionType REDUCE>
__global__ void
scatter_kernel(const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> index_info,
scalar_t *out_data, int E, int K, int N, int numel) {
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
int b = thread_idx / (E * K);
int k = thread_idx % K;
if (thread_idx < numel) {
int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
thread_idx, index_info);
int64_t idx = index_info.data[offset];
Reducer<scalar_t, REDUCE>::atomic_write(out_data + b * N * K + idx * K + k,
src_data[thread_idx]);
}
}
template <typename scalar_t>
__global__ void
scatter_arg_kernel(const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> index_info,
const scalar_t *out_data, int64_t *arg_out_data, int E,
int K, int N, int numel) {
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
int b = thread_idx / (E * K);
int e = (thread_idx / K) % E;
int k = thread_idx % K;
if (thread_idx < numel) {
int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
thread_idx, index_info);
int64_t idx = index_info.data[offset];
if (src_data[thread_idx] == out_data[b * N * K + idx * K + k]) {
arg_out_data[b * N * K + idx * K + k] = e;
}
}
}
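The arg kernel above is the second half of a two-pass argmin/argmax: `scatter_kernel` first reduces the values with atomics, then `scatter_arg_kernel` records the index of any source element that equals the reduced value (ties resolve arbitrarily). A hedged CPU mirror of that second pass, with invented names:

// Hypothetical CPU mirror of the second pass: pass 1 (scatter_kernel) has
// already min/max-reduced the values into `out`; pass 2 stores the position
// of any element that matches the reduced value, like the kernel above.
#include <cstdint>
#include <vector>

void argmax_second_pass(const std::vector<float> &src,
                        const std::vector<int64_t> &index,
                        const std::vector<float> &out, // already reduced
                        std::vector<int64_t> &arg_out) {
  for (int64_t e = 0; e < (int64_t)src.size(); e++)
    if (src[e] == out[index[e]])
      arg_out[index[e]] = e;
}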
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size, std::string reduce) {
CHECK_CUDA(src);
CHECK_CUDA(index);
if (optional_out.has_value())
CHECK_CUDA(optional_out.value());
cudaSetDevice(src.get_device());
CHECK_INPUT(src.dim() == index.dim());
for (auto i = 0; i < index.dim() - 1; i++)
CHECK_INPUT(src.size(i) >= index.size(i));
if (dim < 0)
dim = src.dim() + dim;
src = src.contiguous();
torch::Tensor out;
if (optional_out.has_value()) {
out = optional_out.value().contiguous();
for (auto i = 0; i < out.dim(); i++)
if (i != dim)
CHECK_INPUT(src.size(i) == out.size(i));
} else {
auto sizes = src.sizes().vec();
if (dim_size.has_value())
sizes[dim] = dim_size.value();
else {
auto d_size = index.max().data_ptr<int64_t>();
auto h_size = (int64_t *)malloc(sizeof(int64_t));
cudaMemcpy(h_size, d_size, sizeof(int64_t), cudaMemcpyDeviceToHost);
sizes[dim] = 1 + *h_size;
}
out = torch::empty(sizes, src.options());
}
torch::optional<torch::Tensor> arg_out = torch::nullopt;
int64_t *arg_out_data = nullptr;
if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
arg_out = torch::full_like(out, src.size(dim), index.options());
arg_out_data = arg_out.value().data_ptr<int64_t>();
}
auto B = 1;
for (auto i = 0; i < dim; i++)
B *= src.size(i);
auto E = src.size(dim);
auto K = src.numel() / (B * E);
auto N = out.size(dim);
auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter", [&] {
auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
if (!optional_out.has_value())
out.fill_(Reducer<scalar_t, REDUCE>::init());
scatter_kernel<scalar_t, REDUCE>
<<<BLOCKS(src.numel()), THREADS, 0, stream>>>(
src_data, index_info, out_data, E, K, N, src.numel());
if (!optional_out.has_value() && (REDUCE == MIN || REDUCE == MAX))
out.masked_fill_(out == Reducer<scalar_t, REDUCE>::init(), (scalar_t)0);
if (REDUCE == MIN || REDUCE == MAX)
scatter_arg_kernel<scalar_t>
<<<BLOCKS(src.numel()), THREADS, 0, stream>>>(
src_data, index_info, out_data, arg_out_data, E, K, N,
src.numel());
});
});
return std::make_tuple(out, arg_out);
}
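A single-threaded reference for the index arithmetic used by `scatter_kernel`, assuming a contiguous, non-broadcast `index` of shape `[B, E]` and contiguous `src`/`out` layouts; the helper name is invented for illustration.

// Hypothetical single-threaded reference for the index math in scatter_kernel:
// the flat thread index decomposes into (b, e, k) and the value is accumulated
// at out[b][index[b][e]][k] (the CUDA kernel uses an atomic for this add).
#include <cstdint>
#include <vector>

void scatter_sum_reference(const std::vector<float> &src,     // B * E * K values
                           const std::vector<int64_t> &index, // B * E values
                           std::vector<float> &out,           // B * N * K values
                           int B, int E, int K, int N) {
  for (int i = 0; i < B * E * K; i++) {
    int b = i / (E * K);
    int e = (i / K) % E;
    int k = i % K;
    int64_t idx = index[b * E + e];
    out[b * N * K + idx * K + k] += src[i];
  }
}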
#pragma once
#include <torch/extension.h>
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size, std::string reduce);
#include "segment_coo_cuda.h"
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>

#include "reducer.cuh"
#include "utils.cuh"

#define THREADS 256
#define BLOCKS(TB, N) (TB * N + THREADS - 1) / THREADS
#define FULL_MASK 0xffffffff
enum ReductionType { SUM, MEAN, MIN, MAX };
const std::map<std::string, ReductionType> reduce2REDUCE = {
{"sum", SUM}, {"add", SUM}, {"mean", MEAN}, {"min", MIN}, {"max", MAX},
};
#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...) \
[&] { \
switch (reduce2REDUCE.at(reduce)) { \
case SUM: { \
const ReductionType REDUCE = SUM; \
return __VA_ARGS__(); \
} \
case MEAN: { \
const ReductionType REDUCE = MEAN; \
return __VA_ARGS__(); \
} \
case MIN: { \
const ReductionType REDUCE = MIN; \
return __VA_ARGS__(); \
} \
case MAX: { \
const ReductionType REDUCE = MAX; \
return __VA_ARGS__(); \
} \
} \
}()
template <typename scalar_t, ReductionType REDUCE> struct Reducer {
static inline __host__ __device__ scalar_t init() {
if (REDUCE == MIN) {
return std::numeric_limits<scalar_t>::max();
} else if (REDUCE == MAX) {
return std::numeric_limits<scalar_t>::lowest();
} else {
return (scalar_t)0;
}
}
static inline __host__ __device__ void update(scalar_t *val,
scalar_t new_val) {
if (REDUCE == SUM || REDUCE == MEAN) {
*val = *val + new_val;
} else if ((REDUCE == MIN && new_val < *val) ||
(REDUCE == MAX && new_val > *val)) {
*val = new_val;
}
}
static inline __host__ __device__ void update(scalar_t *val, scalar_t new_val,
int64_t *arg, int64_t new_arg) {
if (REDUCE == SUM || REDUCE == MEAN) {
*val = *val + new_val;
} else if ((REDUCE == MIN && new_val < *val) ||
(REDUCE == MAX && new_val > *val)) {
*val = new_val;
*arg = new_arg;
}
}
static inline __host__ __device__ void write(scalar_t *address, scalar_t val,
int64_t *arg_address,
int64_t arg, int count) {
if (REDUCE == SUM) {
*address = val;
} else if (REDUCE == MEAN) {
*address = val / (scalar_t)max(count, 1);
} else if (REDUCE == MIN || REDUCE == MAX) {
if (count > 0) {
*address = val;
*arg_address = arg;
} else {
*address = (scalar_t)0;
}
}
}
static inline __device__ void atomic_write(scalar_t *address, scalar_t val) {
if (REDUCE == SUM || REDUCE == MEAN) {
atomAdd(address, val);
} else if (REDUCE == MIN && val < *address) {
atomMin(address, val);
} else if (REDUCE == MAX && val > *address) {
atomMax(address, val);
}
}
};
template <typename scalar_t, ReductionType REDUCE, int TB>
__global__ void
segment_csr_kernel(const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
scalar_t *out_data, int64_t *arg_out_data, size_t N,
size_t E) {
// Each warp processes exactly `32/TB` rows and aggregates all row values
// via a parallel reduction.
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
int row_idx = thread_idx / TB;
int lane_idx = thread_idx & (TB - 1);
if (row_idx < N) {
int offset = IndexPtrToOffset<int64_t>::get(row_idx, indptr_info);
int64_t row_start = __ldg(indptr_info.data + offset);
int64_t row_end = __ldg(indptr_info.data + offset +
indptr_info.strides[indptr_info.dims - 1]);
scalar_t val = Reducer<scalar_t, REDUCE>::init();
int64_t arg, arg_tmp;
offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E;
for (int64_t src_idx = row_start + lane_idx; src_idx < row_end;
src_idx += TB) {
Reducer<scalar_t, REDUCE>::update(&val, src_data[offset + src_idx], &arg,
src_idx);
}
#pragma unroll
for (int i = TB / 2; i > 0; i /= 2) {
// Parallel reduction inside a single warp.
if (REDUCE == MIN || REDUCE == MAX)
arg_tmp = __shfl_down_sync(FULL_MASK, arg, i);
Reducer<scalar_t, REDUCE>::update(
&val, __shfl_down_sync(FULL_MASK, val, i), &arg, arg_tmp);
}
if (lane_idx == 0) {
Reducer<scalar_t, REDUCE>::write(out_data + row_idx, val,
arg_out_data + row_idx, arg,
row_end - row_start);
}
}
}
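The `__shfl_down_sync` loop above is a standard warp-level tree reduction. A standalone sketch (not part of this patch) of the same pattern for a plain 32-lane sum, to make the folding explicit:

// Standalone sketch of the warp-shuffle reduction used above: 32 lanes each
// hold one value and lane 0 ends up with the total. Illustration only.
#include <cstdio>

__global__ void warp_sum_kernel(const float *in, float *out) {
  float val = in[threadIdx.x];
#pragma unroll
  for (int i = 16; i > 0; i /= 2)
    val += __shfl_down_sync(0xffffffff, val, i); // fold the upper half down
  if (threadIdx.x == 0)
    *out = val;
}

int main() {
  float h_in[32], h_out, *d_in, *d_out;
  for (int i = 0; i < 32; i++)
    h_in[i] = 1.0f;
  cudaMalloc(&d_in, 32 * sizeof(float));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h_in, 32 * sizeof(float), cudaMemcpyHostToDevice);
  warp_sum_kernel<<<1, 32>>>(d_in, d_out);
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("%f\n", h_out); // 32.000000
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}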
template <typename scalar_t, ReductionType REDUCE>
__global__ void segment_csr_broadcast_kernel(
const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
scalar_t *out_data, int64_t *arg_out_data, size_t N, size_t K, size_t E) {
// Each thread processes exactly one row. It turned out that is more
// efficient than using shared memory due to avoiding synchronization
// barriers.
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
int row_idx = thread_idx / K;
int lane_idx = thread_idx % K;
if (thread_idx < N * K) {
int offset = IndexPtrToOffset<int64_t>::get(row_idx, indptr_info);
int64_t row_start = __ldg(indptr_info.data + offset);
int64_t row_end = __ldg(indptr_info.data + offset +
indptr_info.strides[indptr_info.dims - 1]);
scalar_t val = Reducer<scalar_t, REDUCE>::init();
int64_t arg;
offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E * K;
for (int64_t src_idx = row_start; src_idx < row_end; src_idx++) {
Reducer<scalar_t, REDUCE>::update(
&val, src_data[offset + K * src_idx + lane_idx], &arg, src_idx);
}
Reducer<scalar_t, REDUCE>::write(out_data + thread_idx, val,
arg_out_data + thread_idx, arg,
row_end - row_start);
}
}
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_csr_cuda(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> out_opt, std::string reduce) {
cudaSetDevice(src.get_device());
AT_ASSERTM(src.dim() >= indptr.dim(), "Input mismatch");
// Broadcasting `indptr` via `expand`.
auto sizes = indptr.sizes().vec();
for (int i = 0; i < indptr.dim() - 1; i++) {
sizes[i] = src.size(i);
}
indptr = indptr.expand(sizes);
src = src.contiguous();
auto reduce_dim = indptr.dim() - 1;
torch::Tensor out;
if (out_opt.has_value()) {
out = out_opt.value().contiguous();
for (int i = 0; i < out.dim(); i++)
if (i != reduce_dim)
AT_ASSERTM(src.size(i) == out.size(i), "Input mismatch");
AT_ASSERTM(out.size(reduce_dim) == indptr.size(reduce_dim) - 1,
"Input mismatch");
} else {
sizes = src.sizes().vec();
sizes[reduce_dim] = indptr.size(reduce_dim) - 1;
out = torch::empty(sizes, src.options());
}
torch::optional<torch::Tensor> arg_out = torch::nullopt;
int64_t *arg_out_data = nullptr;
if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
arg_out = torch::full_like(out, src.size(reduce_dim), indptr.options());
arg_out_data = arg_out.value().DATA_PTR<int64_t>();
}
auto N = out.size(reduce_dim) * (indptr.numel() / indptr.size(-1));
auto K = out.numel() / N;
auto E = src.size(reduce_dim);
auto indptr_info = at::cuda::detail::getTensorInfo<int64_t, int>(indptr);
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_csr_kernel", [&] {
auto src_data = src.DATA_PTR<scalar_t>();
auto out_data = out.DATA_PTR<scalar_t>();
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
if (K == 1) {
segment_csr_kernel<scalar_t, REDUCE, 1>
<<<BLOCKS(32, N), THREADS, 0, stream>>>(
src_data, indptr_info, out_data, arg_out_data, N, E);
} else {
segment_csr_broadcast_kernel<scalar_t, REDUCE>
<<<BLOCKS(1, N * K), THREADS, 0, stream>>>(
src_data, indptr_info, out_data, arg_out_data, N, K, E);
}
});
});
return std::make_tuple(out, arg_out);
}
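For readers less familiar with the CSR layout: `indptr` holds N + 1 monotonically increasing offsets, and segment `i` covers `src[indptr[i]:indptr[i+1]]` along the reduced dimension. A hedged 1-D CPU reference for the "sum" case (function name invented):

// Hypothetical 1-D CPU reference for segment_csr with "sum": indptr holds
// N + 1 offsets and out[i] reduces src[indptr[i]..indptr[i+1]).
#include <cstdint>
#include <vector>

std::vector<float> segment_sum_csr_reference(const std::vector<float> &src,
                                             const std::vector<int64_t> &indptr) {
  std::vector<float> out(indptr.size() - 1, 0.0f);
  for (size_t i = 0; i + 1 < indptr.size(); i++)
    for (int64_t j = indptr[i]; j < indptr[i + 1]; j++)
      out[i] += src[j];
  return out; // e.g. indptr = {0, 3, 5} sums src[0..2] and src[3..4]
}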
template <typename scalar_t, ReductionType REDUCE, bool HAS_VAL>
__global__ void
segment_coo_kernel(const scalar_t *src_data,
...@@ -385,110 +150,220 @@ __global__ void segment_coo_arg_broadcast_kernel(
}

std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_coo_cuda(torch::Tensor src, torch::Tensor index,
                 torch::optional<torch::Tensor> optional_out,
                 torch::optional<int64_t> dim_size, std::string reduce) {
  CHECK_CUDA(src);
  CHECK_CUDA(index);
  if (optional_out.has_value())
    CHECK_CUDA(optional_out.value());
  cudaSetDevice(src.get_device());

  CHECK_INPUT(src.dim() >= index.dim());

  // Broadcasting `index` via `expand`.
  auto sizes = index.sizes().vec();
  for (int i = 0; i < index.dim(); i++) {
    sizes[i] = src.size(i);
  }
  index = index.expand(sizes);

  auto dim = index.dim() - 1;

  src = src.contiguous();

  torch::Tensor out;
  if (optional_out.has_value()) {
    out = optional_out.value().contiguous();
    for (int i = 0; i < out.dim(); i++)
      if (i != dim)
        CHECK_INPUT(src.size(i) == out.size(i));
  } else {
    sizes = src.sizes().vec();
    if (dim_size.has_value())
      sizes[dim] = dim_size.value();
    else {
      auto d_size = index.max().data_ptr<int64_t>();
      auto h_size = (int64_t *)malloc(sizeof(int64_t));
      cudaMemcpy(h_size, d_size, sizeof(int64_t), cudaMemcpyDeviceToHost);
      sizes[dim] = 1 + *h_size;
    }
    out = torch::zeros(sizes, src.options());
  }

  torch::optional<torch::Tensor> arg_out = torch::nullopt;
  int64_t *arg_out_data = nullptr;
  if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
    arg_out = torch::full_like(out, src.size(dim), index.options());
    arg_out_data = arg_out.value().data_ptr<int64_t>();
  }

  auto E = index.numel();
  auto E_2 = index.size(dim);
  auto E_1 = index.numel() / E_2;
  auto K = src.numel() / E;
  auto N = out.size(dim);
  auto avg_len = (float)E_2 / (float)N;

  auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_coo_kernel", [&] {
    auto src_data = src.data_ptr<scalar_t>();
    auto out_data = out.data_ptr<scalar_t>();

    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
      if (!optional_out.has_value())
        out.fill_(Reducer<scalar_t, REDUCE>::init());

      if (K == 1)
        segment_coo_kernel<scalar_t, REDUCE, true>
            <<<BLOCKS(1, E), THREADS, 0, stream>>>(src_data, index_info,
                                                   out_data, E, N);
      else if (avg_len <= 8)
        segment_coo_broadcast_kernel<scalar_t, REDUCE, 4>
            <<<dim3((E_1 * ((E_2 + 3) / 4) + 7) / 8, (K + 31) / 32),
               dim3(32, 8), 0, stream>>>(src_data, index_info, out_data, E, K,
                                         N);
      else if (avg_len <= 16)
        segment_coo_broadcast_kernel<scalar_t, REDUCE, 8>
            <<<dim3((E_1 * ((E_2 + 7) / 8) + 7) / 8, (K + 31) / 32),
               dim3(32, 8), 0, stream>>>(src_data, index_info, out_data, E, K,
                                         N);
      else if (avg_len <= 32)
        segment_coo_broadcast_kernel<scalar_t, REDUCE, 16>
            <<<dim3((E_1 * ((E_2 + 15) / 16) + 7) / 8, (K + 31) / 32),
               dim3(32, 8), 0, stream>>>(src_data, index_info, out_data, E, K,
                                         N);
      else
        segment_coo_broadcast_kernel<scalar_t, REDUCE, 32>
            <<<dim3((E_1 * ((E_2 + 31) / 32) + 7) / 8, (K + 31) / 32),
               dim3(32, 8), 0, stream>>>(src_data, index_info, out_data, E, K,
                                         N);

      if (!optional_out.has_value() && (REDUCE == MIN || REDUCE == MAX))
        out.masked_fill_(out == Reducer<scalar_t, REDUCE>::init(), (scalar_t)0);

      if (REDUCE == MIN || REDUCE == MAX) {
        if (K == 1)
          segment_coo_arg_kernel<scalar_t>
              <<<BLOCKS(1, E), THREADS, 0, stream>>>(
                  src_data, index_info, out_data, arg_out_data, E, N);
        else
          segment_coo_arg_broadcast_kernel<scalar_t>
              <<<BLOCKS(1, E * K), THREADS, 0, stream>>>(
                  src_data, index_info, out_data, arg_out_data, E, K, N);
      }

      if (REDUCE == MEAN) {
        auto sizes = index.sizes().vec();
        sizes[dim] = out.size(dim);
        auto count = torch::zeros(sizes, out.options());
        auto count_data = count.data_ptr<scalar_t>();
        segment_coo_kernel<scalar_t, SUM, false>
            <<<BLOCKS(1, E), THREADS, 0, stream>>>(nullptr, index_info,
                                                   count_data, E, N);
        arg_out = count;
        for (int i = dim + 1; i < out.dim(); i++)
          count = count.unsqueeze(-1);
        out.div_(count.clamp_(1));
      }
    });
  });

  return std::make_tuple(out, arg_out);
}

template <typename scalar_t>
__global__ void
gather_coo_kernel(const scalar_t *src_data,
                  const at::cuda::detail::TensorInfo<int64_t, int> index_info,
                  scalar_t *out_data, size_t E, size_t N) {

  int row_idx = blockIdx.x * blockDim.x + threadIdx.x;

  if (row_idx < E) {
    int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
        row_idx, index_info);
    int row = index_info.data[offset];

    offset = (row_idx / index_info.sizes[index_info.dims - 1]) * N;
    scalar_t val = __ldg(src_data + offset + row);

    out_data[row_idx] = val;
  }
}

template <typename scalar_t>
__global__ void gather_coo_broadcast_kernel(
    const scalar_t *src_data,
    const at::cuda::detail::TensorInfo<int64_t, int> index_info,
    scalar_t *out_data, size_t E, size_t K, size_t N) {

  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int row_idx = thread_idx / K;
  int col_idx = thread_idx % K;

  if (thread_idx < E * K) {
    int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
        row_idx, index_info);
    int row = index_info.data[offset];

    offset = (row_idx / index_info.sizes[index_info.dims - 1]) * N * K;
    scalar_t val = __ldg(src_data + offset + K * row + col_idx);

    out_data[thread_idx] = val;
  }
}
torch::Tensor gather_coo_cuda(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out) {
CHECK_CUDA(src);
CHECK_CUDA(index);
if (optional_out.has_value())
CHECK_CUDA(optional_out.value());
cudaSetDevice(src.get_device());
CHECK_INPUT(src.dim() >= index.dim());
auto sizes = index.sizes().vec();
for (auto i = 0; i < index.dim() - 1; i++)
sizes[i] = src.size(i);
index = index.expand(sizes);
auto dim = index.dim() - 1;
src = src.contiguous();
torch::Tensor out;
if (optional_out.has_value()) {
out = optional_out.value().contiguous();
for (auto i = 0; i < src.dim(); i++)
if (i != dim)
CHECK_INPUT(src.size(i) == out.size(i));
CHECK_INPUT(index.size(dim) == out.size(dim));
} else {
auto sizes = src.sizes().vec();
sizes[dim] = index.size(dim);
out = torch::empty(sizes, src.options());
}
auto E = index.numel();
auto K = out.numel() / E;
auto N = src.size(dim);
auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_coo_kernel", [&] {
auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
if (K == 1)
gather_coo_kernel<scalar_t><<<BLOCKS(1, E), THREADS, 0, stream>>>(
src_data, index_info, out_data, E, N);
else
gather_coo_broadcast_kernel<scalar_t>
<<<BLOCKS(1, E * K), THREADS, 0, stream>>>(src_data, index_info,
out_data, E, K, N);
});
return out;
} }
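`gather_coo` is the read-side counterpart of `segment_coo`: every output position pulls the value of the segment it is assigned to, which is also why it appears as the backward of the COO segment sum further below. A hedged 1-D CPU reference (name invented):

// Hypothetical 1-D CPU reference for gather_coo: out[e] = src[index[e]].
#include <cstdint>
#include <vector>

std::vector<float> gather_coo_reference(const std::vector<float> &src,
                                        const std::vector<int64_t> &index) {
  std::vector<float> out(index.size());
  for (size_t e = 0; e < index.size(); e++)
    out[e] = src[index[e]];
  return out;
}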
#pragma once
#include <torch/extension.h>
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_coo_cuda(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size, std::string reduce);
torch::Tensor gather_coo_cuda(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out);
#include "segment_csr_cuda.h"
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>

#include "index_info.cuh"
#include "reducer.cuh"
#include "utils.cuh"

#define THREADS 256
#define BLOCKS(TB, N) (TB * N + THREADS - 1) / THREADS
#define FULL_MASK 0xffffffff

template <typename scalar_t, ReductionType REDUCE, int TB>
__global__ void
segment_csr_kernel(const scalar_t *src_data,
                   const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
                   scalar_t *out_data, int64_t *arg_out_data, size_t N,
                   size_t E) {

  // Each warp processes exactly `32/TB` rows and aggregates all row values
  // via a parallel reduction.
  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int row_idx = thread_idx / TB;
  int lane_idx = thread_idx & (TB - 1);

  if (row_idx < N) {
    int offset = IndexPtrToOffset<int64_t>::get(row_idx, indptr_info);
    int64_t row_start = __ldg(indptr_info.data + offset);
    int64_t row_end = __ldg(indptr_info.data + offset +
                            indptr_info.strides[indptr_info.dims - 1]);

    scalar_t val = Reducer<scalar_t, REDUCE>::init();
    int64_t arg, arg_tmp;

    offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E;
    for (int64_t src_idx = row_start + lane_idx; src_idx < row_end;
         src_idx += TB) {
      Reducer<scalar_t, REDUCE>::update(&val, src_data[offset + src_idx], &arg,
                                        src_idx);
    }

#pragma unroll
    for (int i = TB / 2; i > 0; i /= 2) {
      // Parallel reduction inside a single warp.
      if (REDUCE == MIN || REDUCE == MAX)
        arg_tmp = __shfl_down_sync(FULL_MASK, arg, i);
      Reducer<scalar_t, REDUCE>::update(
          &val, __shfl_down_sync(FULL_MASK, val, i), &arg, arg_tmp);
    }

    if (lane_idx == 0) {
      Reducer<scalar_t, REDUCE>::write(out_data + row_idx, val,
                                       arg_out_data + row_idx, arg,
                                       row_end - row_start);
    }
  }
}

template <typename scalar_t, ReductionType REDUCE>
__global__ void segment_csr_broadcast_kernel(
    const scalar_t *src_data,
    const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
    scalar_t *out_data, int64_t *arg_out_data, size_t N, size_t K, size_t E) {

  // Each thread processes exactly one row. It turned out that is more
  // efficient than using shared memory due to avoiding synchronization
  // barriers.
  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int row_idx = thread_idx / K;
...@@ -45,157 +75,193 @@ __global__ void gather_csr_broadcast_kernel(
  if (thread_idx < N * K) {
    int offset = IndexPtrToOffset<int64_t>::get(row_idx, indptr_info);
    int64_t row_start = __ldg(indptr_info.data + offset);
    int64_t row_end = __ldg(indptr_info.data + offset +
                            indptr_info.strides[indptr_info.dims - 1]);

    scalar_t val = Reducer<scalar_t, REDUCE>::init();
    int64_t arg;

    offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E * K;
    for (int64_t src_idx = row_start; src_idx < row_end; src_idx++) {
      Reducer<scalar_t, REDUCE>::update(
          &val, src_data[offset + K * src_idx + lane_idx], &arg, src_idx);
    }

    Reducer<scalar_t, REDUCE>::write(out_data + thread_idx, val,
                                     arg_out_data + thread_idx, arg,
                                     row_end - row_start);
  }
}

std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_csr_cuda(torch::Tensor src, torch::Tensor indptr,
                 torch::optional<torch::Tensor> optional_out,
                 std::string reduce) {
  CHECK_CUDA(src);
  CHECK_CUDA(indptr);
  if (optional_out.has_value())
    CHECK_CUDA(optional_out.value());
  cudaSetDevice(src.get_device());

  CHECK_INPUT(src.dim() >= indptr.dim());

  auto sizes = indptr.sizes().vec();
  for (auto i = 0; i < indptr.dim() - 1; i++)
    sizes[i] = src.size(i);
  indptr = indptr.expand(sizes);

  auto dim = indptr.dim() - 1;

  src = src.contiguous();

  torch::Tensor out;
  if (optional_out.has_value()) {
    out = optional_out.value().contiguous();
    for (int i = 0; i < out.dim(); i++)
      if (i != dim)
        CHECK_INPUT(src.size(i) == out.size(i));
    CHECK_INPUT(out.size(dim) == indptr.size(dim) - 1);
  } else {
    sizes = src.sizes().vec();
    sizes[dim] = indptr.size(dim) - 1;
    out = torch::empty(sizes, src.options());
  }

  torch::optional<torch::Tensor> arg_out = torch::nullopt;
  int64_t *arg_out_data = nullptr;
  if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
    arg_out = torch::full(out.sizes(), src.size(dim), indptr.options());
    arg_out_data = arg_out.value().data_ptr<int64_t>();
  }

  auto N = out.size(dim) * (indptr.numel() / indptr.size(-1));
  auto K = out.numel() / N;
  auto E = src.size(dim);

  auto indptr_info = at::cuda::detail::getTensorInfo<int64_t, int>(indptr);
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_csr_kernel", [&] {
    auto src_data = src.data_ptr<scalar_t>();
    auto out_data = out.data_ptr<scalar_t>();

    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
      if (K == 1) {
        segment_csr_kernel<scalar_t, REDUCE, 1>
            <<<BLOCKS(32, N), THREADS, 0, stream>>>(
                src_data, indptr_info, out_data, arg_out_data, N, E);
      } else {
        segment_csr_broadcast_kernel<scalar_t, REDUCE>
            <<<BLOCKS(1, N * K), THREADS, 0, stream>>>(
                src_data, indptr_info, out_data, arg_out_data, N, K, E);
      }
    });
  });

  return std::make_tuple(out, arg_out);
}

template <typename scalar_t, int TB>
__global__ void
gather_csr_kernel(const scalar_t *src_data,
                  const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
                  scalar_t *out_data, size_t N, size_t E) {

  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int row_idx = thread_idx / TB;
  int lane_idx = thread_idx % TB;

  if (row_idx < N) {
    int offset = IndexPtrToOffset<int64_t>::get(row_idx, indptr_info);
    int row_start = __ldg(indptr_info.data + offset);
    int row_end = __ldg(indptr_info.data + offset +
                        indptr_info.strides[indptr_info.dims - 1]);
    scalar_t val = __ldg(src_data + row_idx);

    offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E;
    for (int out_idx = row_start + lane_idx; out_idx < row_end; out_idx += TB) {
      out_data[offset + out_idx] = val; // "Mostly" coalesced.
    }
  }
}

template <typename scalar_t>
__global__ void gather_csr_broadcast_kernel(
    const scalar_t *src_data,
    const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
    scalar_t *out_data, size_t N, size_t K, size_t E) {

  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int row_idx = thread_idx / K;
  int lane_idx = thread_idx % K;

  if (thread_idx < N * K) {
    int offset = IndexPtrToOffset<int64_t>::get(row_idx, indptr_info);
    int row_start = __ldg(indptr_info.data + offset);
    int row_end = __ldg(indptr_info.data + offset +
                        indptr_info.strides[indptr_info.dims - 1]);

    scalar_t val = src_data[thread_idx]; // Coalesced.

    offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E * K;
    for (int out_idx = row_start; out_idx < row_end; out_idx++) {
      out_data[offset + K * out_idx + lane_idx] = val; // "Mostly" coalesced.
    }
  }
}

torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr,
                              torch::optional<torch::Tensor> optional_out) {
  CHECK_CUDA(src);
  CHECK_CUDA(indptr);
  if (optional_out.has_value())
    CHECK_CUDA(optional_out.value());
  cudaSetDevice(src.get_device());

  CHECK_INPUT(src.dim() >= indptr.dim());

  auto sizes = indptr.sizes().vec();
  for (auto i = 0; i < indptr.dim() - 1; i++)
    sizes[i] = src.size(i);
  indptr = indptr.expand(sizes);

  auto dim = indptr.dim() - 1;
  CHECK_INPUT(src.size(dim) == indptr.size(dim) - 1);

  src = src.contiguous();

  torch::Tensor out;
  if (optional_out.has_value()) {
    out = optional_out.value().contiguous();
    for (auto i = 0; i < out.dim(); i++)
      if (i != dim)
        CHECK_INPUT(src.size(i) == out.size(i));
  } else {
    auto d_size = indptr.flatten()[-1].data_ptr<int64_t>();
    auto h_size = (int64_t *)malloc(sizeof(int64_t));
    cudaMemcpy(h_size, d_size, sizeof(int64_t), cudaMemcpyDeviceToHost);
    auto sizes = src.sizes().vec();
    sizes[dim] = *h_size;
    out = torch::empty(sizes, src.options());
  }

  auto N = src.size(dim) * (indptr.numel() / indptr.size(-1));
  auto K = src.numel() / N;
  auto E = out.size(dim);

  auto indptr_info = at::cuda::detail::getTensorInfo<int64_t, int>(indptr);
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_csr_kernel", [&] {
    auto src_data = src.data_ptr<scalar_t>();
    auto out_data = out.data_ptr<scalar_t>();

    if (K == 1)
      gather_csr_kernel<scalar_t, 4><<<BLOCKS(1, 4 * N), THREADS, 0, stream>>>(
          src_data, indptr_info, out_data, N, E);
    else
      gather_csr_broadcast_kernel<scalar_t>
          <<<BLOCKS(1, N * K), THREADS, 0, stream>>>(src_data, indptr_info,
                                                     out_data, N, K, E);
  });

  return out;
......
#pragma once
#include <torch/extension.h>
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_csr_cuda(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out,
std::string reduce);
torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out);
#pragma once
#include <torch/extension.h>
#define CHECK_CUDA(x) \
AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
#include <torch/script.h>
#include "cpu/scatter_cpu.h"
#ifdef WITH_CUDA
#include "cuda/scatter_cuda.h"
#endif
torch::Tensor broadcast(torch::Tensor src, torch::Tensor other, int64_t dim) {
if (dim < 0)
dim = other.dim() + dim;
if (src.dim() == 1)
for (auto i = 0; i < dim; i++)
src = src.unsqueeze(0);
for (auto i = src.dim(); i < other.dim(); i++)
src = src.unsqueeze(-1);
src = src.expand(other.sizes().vec());
return src;
}
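A hedged usage sketch of the `broadcast` helper above (shapes invented; assumes compiling together with this file and libtorch): a 1-D `index` is unsqueezed until it is aligned with `dim` and then expanded to the full shape of `other`.

// Hypothetical usage of the broadcast helper above (compile together with this
// file and libtorch); an index of shape [E] is expanded to the shape of src.
#include <torch/torch.h>
#include <iostream>

int main() {
  auto src = torch::rand({2, 5, 3});                // [B, E, K]
  auto index = torch::randint(0, 4, {5});           // [E]
  auto expanded = broadcast(index, src, /*dim=*/1); // unsqueeze -> [1, 5, 1]
  std::cout << expanded.sizes() << std::endl;       // expand    -> [2, 5, 3]
  return 0;
}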
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
scatter_fw(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size, std::string reduce) {
if (src.device().is_cuda()) {
#ifdef WITH_CUDA
return scatter_cuda(src, index, dim, optional_out, dim_size, reduce);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return scatter_cpu(src, index, dim, optional_out, dim_size, reduce);
}
}
using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;
class ScatterSum : public torch::autograd::Function<ScatterSum> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable index, int64_t dim,
torch::optional<Variable> optional_out,
torch::optional<int64_t> dim_size) {
ctx->saved_data["dim"] = dim;
ctx->saved_data["src_shape"] = src.sizes();
index = broadcast(index, src, dim);
auto result = scatter_fw(src, index, dim, optional_out, dim_size, "sum");
auto out = std::get<0>(result);
ctx->save_for_backward({index});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto index = saved[0];
auto dim = ctx->saved_data["dim"].toInt();
auto src_shape = ctx->saved_data["src_shape"].toIntVector();
auto grad_in = torch::gather(grad_out, dim, index, false);
return {grad_in, Variable(), Variable(), Variable(), Variable()};
}
};
class ScatterMean : public torch::autograd::Function<ScatterMean> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable index, int64_t dim,
torch::optional<Variable> optional_out,
torch::optional<int64_t> dim_size) {
ctx->saved_data["dim"] = dim;
ctx->saved_data["src_shape"] = src.sizes();
auto old_index = index;
index = broadcast(index, src, dim);
auto result = scatter_fw(src, index, dim, optional_out, dim_size, "sum");
auto out = std::get<0>(result);
auto ones = torch::ones(old_index.sizes(), src.options());
result = scatter_fw(ones, old_index,
old_index.dim() <= dim ? old_index.dim() - 1 : dim,
torch::nullopt, out.size(dim), "sum");
auto count = std::get<0>(result);
count.clamp_(1);
count = broadcast(count, out, dim);
out.div_(count);
ctx->save_for_backward({index, count});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto index = saved[0];
auto count = saved[1];
auto dim = ctx->saved_data["dim"].toInt();
auto src_shape = ctx->saved_data["src_shape"].toIntVector();
count = torch::gather(count, dim, index, false);
auto grad_in = torch::gather(grad_out, dim, index, false);
grad_in.div_(count);
return {grad_in, Variable(), Variable(), Variable(), Variable()};
}
};
class ScatterMin : public torch::autograd::Function<ScatterMin> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable index, int64_t dim,
torch::optional<Variable> optional_out,
torch::optional<int64_t> dim_size) {
ctx->saved_data["dim"] = dim;
ctx->saved_data["src_shape"] = src.sizes();
index = broadcast(index, src, dim);
auto result = scatter_fw(src, index, dim, optional_out, dim_size, "min");
auto out = std::get<0>(result);
auto arg_out = std::get<1>(result).value();
ctx->save_for_backward({index, arg_out});
ctx->mark_non_differentiable({arg_out});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out, arg_out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto index = saved[0];
auto arg_out = saved[1];
auto dim = ctx->saved_data["dim"].toInt();
auto src_shape = ctx->saved_data["src_shape"].toIntVector();
src_shape[dim] += 1;
auto grad_in = torch::zeros(src_shape, grad_out.options());
grad_in.scatter_(dim, arg_out, grad_out);
grad_in = grad_in.narrow(dim, 0, src_shape[dim] - 1);
return {grad_in, Variable(), Variable(), Variable(), Variable()};
}
};
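The min/max backward above relies on a padding trick: empty segments have `arg_out == src.size(dim)`, so the gradient buffer is allocated with one extra slot along `dim`, scattered into, and the padding slot is dropped again with `narrow`. A hedged standalone helper (requires libtorch, name invented) isolating just that step:

// Hypothetical helper (requires libtorch) isolating the padding trick used in
// the min/max backward: empty segments point at the extra slot, which is
// dropped again by narrow().
#include <torch/torch.h>

torch::Tensor route_grad_to_arg(torch::Tensor grad_out, torch::Tensor arg_out,
                                int64_t dim, int64_t src_size_dim) {
  auto sizes = grad_out.sizes().vec();
  sizes[dim] = src_size_dim + 1;               // one padding slot along dim
  auto grad_in = torch::zeros(sizes, grad_out.options());
  grad_in.scatter_(dim, arg_out, grad_out);    // empty segments hit the pad
  return grad_in.narrow(dim, 0, src_size_dim); // cut the padding slot off
}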
class ScatterMax : public torch::autograd::Function<ScatterMax> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable index, int64_t dim,
torch::optional<Variable> optional_out,
torch::optional<int64_t> dim_size) {
ctx->saved_data["dim"] = dim;
ctx->saved_data["src_shape"] = src.sizes();
index = broadcast(index, src, dim);
auto result = scatter_fw(src, index, dim, optional_out, dim_size, "max");
auto out = std::get<0>(result);
auto arg_out = std::get<1>(result).value();
ctx->save_for_backward({index, arg_out});
ctx->mark_non_differentiable({arg_out});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out, arg_out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto index = saved[0];
auto arg_out = saved[1];
auto dim = ctx->saved_data["dim"].toInt();
auto src_shape = ctx->saved_data["src_shape"].toIntVector();
src_shape[dim] += 1;
auto grad_in = torch::zeros(src_shape, grad_out.options());
grad_in.scatter_(dim, arg_out, grad_out);
grad_in = grad_in.narrow(dim, 0, src_shape[dim] - 1);
return {grad_in, Variable(), Variable(), Variable(), Variable()};
}
};
torch::Tensor scatter_sum(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
return ScatterSum::apply(src, index, dim, optional_out, dim_size)[0];
}
torch::Tensor scatter_mean(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
return ScatterMean::apply(src, index, dim, optional_out, dim_size)[0];
}
std::tuple<torch::Tensor, torch::Tensor>
scatter_min(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
auto result = ScatterMin::apply(src, index, dim, optional_out, dim_size);
return std::make_tuple(result[0], result[1]);
}
std::tuple<torch::Tensor, torch::Tensor>
scatter_max(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
auto result = ScatterMax::apply(src, index, dim, optional_out, dim_size);
return std::make_tuple(result[0], result[1]);
}
static auto registry = torch::RegisterOperators()
.op("torch_scatter::scatter_sum", &scatter_sum)
.op("torch_scatter::scatter_mean", &scatter_mean)
.op("torch_scatter::scatter_min", &scatter_min)
.op("torch_scatter::scatter_max", &scatter_max);
#include <torch/script.h>
#include "cpu/segment_coo_cpu.h"
#ifdef WITH_CUDA
#include "cuda/segment_coo_cuda.h"
#endif
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_coo_fw(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size, std::string reduce) {
if (src.device().is_cuda()) {
#ifdef WITH_CUDA
return segment_coo_cuda(src, index, optional_out, dim_size, reduce);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return segment_coo_cpu(src, index, optional_out, dim_size, reduce);
}
}
torch::Tensor gather_coo_fw(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out) {
if (src.device().is_cuda()) {
#ifdef WITH_CUDA
return gather_coo_cuda(src, index, optional_out);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return gather_coo_cpu(src, index, optional_out);
}
}
using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;
class SegmentSumCOO : public torch::autograd::Function<SegmentSumCOO> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable index,
torch::optional<Variable> optional_out,
torch::optional<int64_t> dim_size) {
ctx->saved_data["src_shape"] = src.sizes();
auto result = segment_coo_fw(src, index, optional_out, dim_size, "sum");
auto out = std::get<0>(result);
ctx->save_for_backward({index});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto index = saved[0];
auto src_shape = ctx->saved_data["src_shape"].toIntVector();
auto grad_in = torch::empty(src_shape, grad_out.options());
gather_coo_fw(grad_out, index, grad_in);
return {grad_in, Variable(), Variable(), Variable()};
}
};
class SegmentMeanCOO : public torch::autograd::Function<SegmentMeanCOO> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable index,
torch::optional<Variable> optional_out,
torch::optional<int64_t> dim_size) {
ctx->saved_data["src_shape"] = src.sizes();
auto result = segment_coo_fw(src, index, optional_out, dim_size, "mean");
auto out = std::get<0>(result);
auto count = std::get<1>(result).value();
ctx->save_for_backward({index, count});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto index = saved[0];
auto count = saved[1];
auto src_shape = ctx->saved_data["src_shape"].toIntVector();
auto grad_in = torch::empty(src_shape, grad_out.options());
gather_coo_fw(grad_out, index, grad_in);
count = gather_coo_fw(count, index, torch::nullopt);
for (auto i = 0; i < grad_out.dim() - index.dim(); i++)
count = count.unsqueeze(-1);
grad_in.div_(count);
return {grad_in, Variable(), Variable(), Variable()};
}
};
class SegmentMinCOO : public torch::autograd::Function<SegmentMinCOO> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable index,
torch::optional<Variable> optional_out,
torch::optional<int64_t> dim_size) {
ctx->saved_data["src_shape"] = src.sizes();
auto result = segment_coo_fw(src, index, optional_out, dim_size, "min");
auto out = std::get<0>(result);
auto arg_out = std::get<1>(result).value();
ctx->save_for_backward({index, arg_out});
ctx->mark_non_differentiable({arg_out});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out, arg_out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto index = saved[0];
auto arg_out = saved[1];
auto src_shape = ctx->saved_data["src_shape"].toIntVector();
src_shape[index.dim() - 1] += 1;
auto grad_in = torch::zeros(src_shape, grad_out.options());
grad_in.scatter_(index.dim() - 1, arg_out, grad_out);
grad_in =
grad_in.narrow(index.dim() - 1, 0, src_shape[index.dim() - 1] - 1);
return {grad_in, Variable(), Variable(), Variable()};
}
};
class SegmentMaxCOO : public torch::autograd::Function<SegmentMaxCOO> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable index,
torch::optional<Variable> optional_out,
torch::optional<int64_t> dim_size) {
ctx->saved_data["src_shape"] = src.sizes();
auto result = segment_coo_fw(src, index, optional_out, dim_size, "max");
auto out = std::get<0>(result);
auto arg_out = std::get<1>(result).value();
ctx->save_for_backward({index, arg_out});
ctx->mark_non_differentiable({arg_out});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out, arg_out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto index = saved[0];
auto arg_out = saved[1];
auto src_shape = ctx->saved_data["src_shape"].toIntVector();
src_shape[index.dim() - 1] += 1;
auto grad_in = torch::zeros(src_shape, grad_out.options());
grad_in.scatter_(index.dim() - 1, arg_out, grad_out);
grad_in =
grad_in.narrow(index.dim() - 1, 0, src_shape[index.dim() - 1] - 1);
return {grad_in, Variable(), Variable(), Variable()};
}
};
class GatherCOO : public torch::autograd::Function<GatherCOO> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable index,
torch::optional<Variable> optional_out) {
ctx->saved_data["src_shape"] = src.sizes();
auto out = gather_coo_fw(src, index, optional_out);
ctx->save_for_backward({index});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto index = saved[0];
auto src_shape = ctx->saved_data["src_shape"].toIntVector();
auto grad_in = torch::zeros(src_shape, grad_out.options());
segment_coo_fw(grad_out, index, grad_in, torch::nullopt, "sum");
return {grad_in, Variable(), Variable()};
}
};
torch::Tensor segment_sum_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
return SegmentSumCOO::apply(src, index, optional_out, dim_size)[0];
}
torch::Tensor segment_mean_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
return SegmentMeanCOO::apply(src, index, optional_out, dim_size)[0];
}
std::tuple<torch::Tensor, torch::Tensor>
segment_min_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
auto result = SegmentMinCOO::apply(src, index, optional_out, dim_size);
return std::make_tuple(result[0], result[1]);
}
std::tuple<torch::Tensor, torch::Tensor>
segment_max_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
auto result = SegmentMaxCOO::apply(src, index, optional_out, dim_size);
return std::make_tuple(result[0], result[1]);
}
torch::Tensor gather_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out) {
return GatherCOO::apply(src, index, optional_out)[0];
}
static auto registry =
torch::RegisterOperators()
.op("torch_scatter::segment_sum_coo", &segment_sum_coo)
.op("torch_scatter::segment_mean_coo", &segment_mean_coo)
.op("torch_scatter::segment_min_coo", &segment_min_coo)
.op("torch_scatter::segment_max_coo", &segment_max_coo)
.op("torch_scatter::gather_coo", &gather_coo);
#include <torch/script.h>
#include "cpu/segment_csr_cpu.h"
#ifdef WITH_CUDA
#include "cuda/segment_csr_cuda.h"
#endif
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_csr_fw(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out,
std::string reduce) {
if (src.device().is_cuda()) {
#ifdef WITH_CUDA
return segment_csr_cuda(src, indptr, optional_out, reduce);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return segment_csr_cpu(src, indptr, optional_out, reduce);
}
}
torch::Tensor gather_csr_fw(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out) {
if (src.device().is_cuda()) {
#ifdef WITH_CUDA
return gather_csr_cuda(src, indptr, optional_out);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return gather_csr_cpu(src, indptr, optional_out);
}
}
using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;
class SegmentSumCSR : public torch::autograd::Function<SegmentSumCSR> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable indptr,
torch::optional<Variable> optional_out) {
ctx->saved_data["src_shape"] = src.sizes();
auto out = std::get<0>(segment_csr_fw(src, indptr, optional_out, "sum"));
ctx->save_for_backward({indptr});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto indptr = saved[0];
auto src_shape = ctx->saved_data["src_shape"].toIntVector();
auto grad_in = torch::empty(src_shape, grad_out.options());
gather_csr_fw(grad_out, indptr, grad_in);
return {grad_in, Variable(), Variable()};
}
};
class SegmentMeanCSR : public torch::autograd::Function<SegmentMeanCSR> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable indptr,
torch::optional<Variable> optional_out) {
ctx->saved_data["src_shape"] = src.sizes();
auto out = std::get<0>(segment_csr_fw(src, indptr, optional_out, "mean"));
ctx->save_for_backward({indptr});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto indptr = saved[0];
auto src_shape = ctx->saved_data["src_shape"].toIntVector();
auto grad_in = torch::empty(src_shape, grad_out.options());
gather_csr_fw(grad_out, indptr, grad_in);
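// Segment sizes are the differences of consecutive `indptr` entries; gather
// them back to element resolution and broadcast over trailing dimensions
// before dividing, so each element receives grad / |segment|.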
auto indptr1 = indptr.narrow(-1, 0, indptr.size(-1) - 1);
auto indptr2 = indptr.narrow(-1, 1, indptr.size(-1) - 1);
auto count = (indptr2 - indptr1).to(grad_in.options());
count = gather_csr_fw(count, indptr, torch::nullopt);
for (auto i = 0; i < grad_out.dim() - indptr.dim(); i++)
count = count.unsqueeze(-1);
grad_in.div_(count);
return {grad_in, Variable(), Variable()};
}
};
class SegmentMinCSR : public torch::autograd::Function<SegmentMinCSR> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable indptr,
torch::optional<Variable> optional_out) {
ctx->saved_data["src_shape"] = src.sizes();
auto result = segment_csr_fw(src, indptr, optional_out, "min");
auto out = std::get<0>(result);
auto arg_out = std::get<1>(result).value();
ctx->save_for_backward({indptr, arg_out});
ctx->mark_non_differentiable({arg_out});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out, arg_out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto indptr = saved[0];
auto arg_out = saved[1];
auto src_shape = ctx->saved_data["src_shape"].toIntVector();
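// Pad the gather dimension by one slot so that `arg_out` entries pointing one
// past the end (the sentinel used for empty segments) scatter into a dummy
// column, which is cut off again by the narrow() below.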
src_shape[indptr.dim() - 1] += 1;
auto grad_in = torch::zeros(src_shape, grad_out.options());
grad_in.scatter_(indptr.dim() - 1, arg_out, grad_out);
grad_in =
grad_in.narrow(indptr.dim() - 1, 0, src_shape[indptr.dim() - 1] - 1);
return {grad_in, Variable(), Variable()};
}
};
class SegmentMaxCSR : public torch::autograd::Function<SegmentMaxCSR> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable indptr,
torch::optional<Variable> optional_out) {
ctx->saved_data["src_shape"] = src.sizes();
auto result = segment_csr_fw(src, indptr, optional_out, "max");
auto out = std::get<0>(result);
auto arg_out = std::get<1>(result).value();
ctx->save_for_backward({indptr, arg_out});
ctx->mark_non_differentiable({arg_out});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out, arg_out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto indptr = saved[0];
auto arg_out = saved[1];
auto src_shape = ctx->saved_data["src_shape"].toIntVector();
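// Same out-of-range padding trick as in SegmentMinCSR::backward.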
src_shape[indptr.dim() - 1] += 1;
auto grad_in = torch::zeros(src_shape, grad_out.options());
grad_in.scatter_(indptr.dim() - 1, arg_out, grad_out);
grad_in =
grad_in.narrow(indptr.dim() - 1, 0, src_shape[indptr.dim() - 1] - 1);
return {grad_in, Variable(), Variable()};
}
};
class GatherCSR : public torch::autograd::Function<GatherCSR> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable indptr,
torch::optional<Variable> optional_out) {
ctx->saved_data["src_shape"] = src.sizes();
auto out = gather_csr_fw(src, indptr, optional_out);
ctx->save_for_backward({indptr});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto indptr = saved[0];
auto src_shape = ctx->saved_data["src_shape"].toIntVector();
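// gather_csr replicates each segment value to all of its elements, so its
// adjoint sums the gradient back per segment.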
auto grad_in = torch::empty(src_shape, grad_out.options());
segment_csr_fw(grad_out, indptr, grad_in, "sum");
return {grad_in, Variable(), Variable()};
}
};
torch::Tensor segment_sum_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out) {
return SegmentSumCSR::apply(src, indptr, optional_out)[0];
}
torch::Tensor segment_mean_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out) {
return SegmentMeanCSR::apply(src, indptr, optional_out)[0];
}
std::tuple<torch::Tensor, torch::Tensor>
segment_min_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out) {
auto result = SegmentMinCSR::apply(src, indptr, optional_out);
return std::make_tuple(result[0], result[1]);
}
std::tuple<torch::Tensor, torch::Tensor>
segment_max_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out) {
auto result = SegmentMaxCSR::apply(src, indptr, optional_out);
return std::make_tuple(result[0], result[1]);
}
torch::Tensor gather_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out) {
return GatherCSR::apply(src, indptr, optional_out)[0];
}
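// Registers the traceable CSR operators with TorchScript; from Python they
// are reachable as, e.g.,
// torch.ops.torch_scatter.segment_sum_csr(src, indptr, None).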
static auto registry =
torch::RegisterOperators()
.op("torch_scatter::segment_sum_csr", &segment_sum_csr)
.op("torch_scatter::segment_mean_csr", &segment_mean_csr)
.op("torch_scatter::segment_min_csr", &segment_min_csr)
.op("torch_scatter::segment_max_csr", &segment_max_csr)
.op("torch_scatter::gather_csr", &gather_csr);
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
#include <torch/script.h>
#define CHECK_CUDA(x) \
AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> out_opt);
torch::Tensor gather_coo_cuda(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> out_opt);
torch::Tensor gather_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> out_opt) {
CHECK_CUDA(src);
CHECK_CUDA(indptr);
if (out_opt.has_value())
CHECK_CUDA(out_opt.value());
return gather_csr_cuda(src, indptr, out_opt);
}
torch::Tensor gather_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> out_opt) {
CHECK_CUDA(src);
CHECK_CUDA(index);
if (out_opt.has_value())
CHECK_CUDA(out_opt.value());
return gather_coo_cuda(src, index, out_opt);
}
static auto registry =
torch::RegisterOperators("torch_scatter_cuda::gather_csr", &gather_csr)
.op("torch_scatter_cuda::gather_coo", &gather_coo);
#pragma once
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <torch/extension.h>
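// Maps a linear element id `i` to memory offsets into `index`, `t1` (same
// shape as `index`) and `t2`. Along `dim`, the coordinate used for `t2` is the
// value read from `index` rather than the element's own position (the usual
// scatter addressing scheme).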
template <typename scalar1, typename scalar2, int64_t Dims>
struct IndexToScatterOffsets3 {
static __device__ void
compute(int64_t i, const int64_t dim,
const at::cuda::detail::TensorInfo<int64_t, int64_t> &index,
int64_t *indexOffset,
const at::cuda::detail::TensorInfo<scalar1, int64_t> &t1,
int64_t *t1Offset,
const at::cuda::detail::TensorInfo<scalar2, int64_t> &t2,
int64_t *t2Offset) {
for (int64_t d = Dims - 1; d >= 0; d--) {
int64_t curDimIndex = i % index.sizes[d];
*indexOffset += curDimIndex * index.strides[d];
*t1Offset += curDimIndex * t1.strides[d];
if (d != dim) {
*t2Offset += curDimIndex * t2.strides[d];
}
i /= index.sizes[d];
}
int64_t indexValue = index.data[*indexOffset];
*t2Offset += indexValue * t2.strides[dim];
}
};
template <typename scalar1, typename scalar2>
struct IndexToScatterOffsets3<scalar1, scalar2, -1> {
static __device__ void
compute(int64_t i, const int64_t dim,
const at::cuda::detail::TensorInfo<int64_t, int64_t> &index,
int64_t *indexOffset,
const at::cuda::detail::TensorInfo<scalar1, int64_t> &t1,
int64_t *t1Offset,
const at::cuda::detail::TensorInfo<scalar2, int64_t> &t2,
int64_t *t2Offset) {
for (int64_t d = index.dims - 1; d >= 0; d--) {
int64_t curDimIndex = i % index.sizes[d];
*indexOffset += curDimIndex * index.strides[d];
*t1Offset += curDimIndex * t1.strides[d];
if (d != dim) {
*t2Offset += curDimIndex * t2.strides[d];
}
i /= index.sizes[d];
}
int64_t indexValue = index.data[*indexOffset];
*t2Offset += indexValue * t2.strides[dim];
}
};
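// Same as above, but with an additional tensor `t3` that is scattered
// alongside `t2` (used for the argmin/argmax output).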
template <typename scalar1, typename scalar2, typename scalar3, int64_t Dims>
struct IndexToScatterOffsets4 {
static __device__ void
compute(int64_t i, const int64_t dim,
const at::cuda::detail::TensorInfo<int64_t, int64_t> &index,
int64_t *indexOffset,
const at::cuda::detail::TensorInfo<scalar1, int64_t> &t1,
int64_t *t1Offset,
const at::cuda::detail::TensorInfo<scalar2, int64_t> &t2,
int64_t *t2Offset,
const at::cuda::detail::TensorInfo<scalar3, int64_t> &t3,
int64_t *t3Offset) {
for (int64_t d = Dims - 1; d >= 0; d--) {
int64_t curDimIndex = i % index.sizes[d];
*indexOffset += curDimIndex * index.strides[d];
*t1Offset += curDimIndex * t1.strides[d];
if (d != dim) {
*t2Offset += curDimIndex * t2.strides[d];
*t3Offset += curDimIndex * t3.strides[d];
}
i /= index.sizes[d];
}
int64_t indexValue = index.data[*indexOffset];
*t2Offset += indexValue * t2.strides[dim];
*t3Offset += indexValue * t3.strides[dim];
}
};
template <typename scalar1, typename scalar2, typename scalar3>
struct IndexToScatterOffsets4<scalar1, scalar2, scalar3, -1> {
static __device__ void
compute(int64_t i, const int64_t dim,
const at::cuda::detail::TensorInfo<int64_t, int64_t> &index,
int64_t *indexOffset,
const at::cuda::detail::TensorInfo<scalar1, int64_t> &t1,
int64_t *t1Offset,
const at::cuda::detail::TensorInfo<scalar2, int64_t> &t2,
int64_t *t2Offset,
const at::cuda::detail::TensorInfo<scalar3, int64_t> &t3,
int64_t *t3Offset) {
for (int64_t d = index.dims - 1; d >= 0; d--) {
int64_t curDimIndex = i % index.sizes[d];
*indexOffset += curDimIndex * index.strides[d];
*t1Offset += curDimIndex * t1.strides[d];
if (d != dim) {
*t2Offset += curDimIndex * t2.strides[d];
*t3Offset += curDimIndex * t3.strides[d];
}
i /= index.sizes[d];
}
int64_t indexValue = index.data[*indexOffset];
*t2Offset += indexValue * t2.strides[dim];
*t3Offset += indexValue * t3.strides[dim];
}
};
#include <torch/script.h>
#define CHECK_CUDA(x) \
AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
void scatter_mul_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out,
int64_t dim);
void scatter_div_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out,
int64_t dim);
void scatter_max_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out,
torch::Tensor arg, int64_t dim);
void scatter_min_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out,
torch::Tensor arg, int64_t dim);
void index_backward_cuda(torch::Tensor grad, torch::Tensor index,
torch::Tensor arg, torch::Tensor out, int64_t dim);
void scatter_mul(torch::Tensor src, torch::Tensor index, torch::Tensor out,
int64_t dim) {
CHECK_CUDA(src);
CHECK_CUDA(index);
CHECK_CUDA(out);
scatter_mul_cuda(src, index, out, dim);
}
void scatter_div(torch::Tensor src, torch::Tensor index, torch::Tensor out,
int64_t dim) {
CHECK_CUDA(src);
CHECK_CUDA(index);
CHECK_CUDA(out);
scatter_div_cuda(src, index, out, dim);
}
void scatter_max(torch::Tensor src, torch::Tensor index, torch::Tensor out,
torch::Tensor arg, int64_t dim) {
CHECK_CUDA(src);
CHECK_CUDA(index);
CHECK_CUDA(out);
CHECK_CUDA(arg);
scatter_max_cuda(src, index, out, arg, dim);
}
void scatter_min(torch::Tensor src, torch::Tensor index, torch::Tensor out,
torch::Tensor arg, int64_t dim) {
CHECK_CUDA(src);
CHECK_CUDA(index);
CHECK_CUDA(out);
CHECK_CUDA(arg);
scatter_min_cuda(src, index, out, arg, dim);
}
static auto registry =
torch::RegisterOperators("torch_scatter_cuda::scatter_mul", &scatter_mul)
.op("torch_scatter_cuda::scatter_div", &scatter_div)
.op("torch_scatter_cuda::scatter_max", &scatter_max)
.op("torch_scatter_cuda::scatter_min", &scatter_min);
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <torch/extension.h>
#include "atomics.cuh"
#include "index.cuh"
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
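// Instantiates NAME<scalar_t, Dims> with the kernel specialized for 1-3 index
// dimensions; higher ranks fall back to the generic Dims = -1 instantiation,
// which reads index.dims at runtime.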
#define KERNEL_RUN(NAME, DIMS, N, ...) \
[&] { \
auto stream = at::cuda::getCurrentCUDAStream(); \
switch (DIMS) { \
case 1: \
NAME<scalar_t, 1><<<BLOCKS(N), THREADS, 0, stream>>>(__VA_ARGS__, N); \
break; \
case 2: \
NAME<scalar_t, 2><<<BLOCKS(N), THREADS, 0, stream>>>(__VA_ARGS__, N); \
break; \
case 3: \
NAME<scalar_t, 3><<<BLOCKS(N), THREADS, 0, stream>>>(__VA_ARGS__, N); \
break; \
default: \
NAME<scalar_t, -1><<<BLOCKS(N), THREADS, 0, stream>>>(__VA_ARGS__, N); \
} \
}()
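// Each kernel below walks the flattened index tensor in a grid-stride loop and
// applies its atomic update to the addressed output position.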
template <typename scalar_t, int64_t Dims>
__global__ void
scatter_mul_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
at::cuda::detail::TensorInfo<int64_t, int64_t> index,
at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
int64_t dim, size_t numel) {
const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = blockDim.x * gridDim.x;
for (ptrdiff_t i = idx; i < numel; i += stride) {
int64_t srcOffset = 0, indexOffset = 0, outOffset = 0;
IndexToScatterOffsets3<scalar_t, scalar_t, Dims>::compute(
i, dim, index, &indexOffset, src, &srcOffset, out, &outOffset);
atomMul(&out.data[outOffset], src.data[srcOffset]);
}
}
void scatter_mul_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out,
int64_t dim) {
cudaSetDevice(src.get_device());
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_mul_kernel", [&] {
KERNEL_RUN(scatter_mul_kernel, index.dim(), index.numel(),
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src),
at::cuda::detail::getTensorInfo<int64_t, int64_t>(index),
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(out), dim);
});
}
template <typename scalar_t, int64_t Dims>
__global__ void
scatter_div_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
at::cuda::detail::TensorInfo<int64_t, int64_t> index,
at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
int64_t dim, size_t numel) {
const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = blockDim.x * gridDim.x;
for (ptrdiff_t i = idx; i < numel; i += stride) {
int64_t srcOffset = 0, indexOffset = 0, outOffset = 0;
IndexToScatterOffsets3<scalar_t, scalar_t, Dims>::compute(
i, dim, index, &indexOffset, src, &srcOffset, out, &outOffset);
atomDiv(&out.data[outOffset], src.data[srcOffset]);
}
}
void scatter_div_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out,
int64_t dim) {
cudaSetDevice(src.get_device());
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_div_kernel", [&] {
KERNEL_RUN(scatter_div_kernel, index.dim(), index.numel(),
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src),
at::cuda::detail::getTensorInfo<int64_t, int64_t>(index),
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(out), dim);
});
}
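// Second pass for min/max: after the atomic reduction has filled `out`, every
// element whose value equals the reduced result records its position along
// `dim` in `arg` (ties are resolved by whichever thread writes last).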
template <typename scalar_t, int64_t Dims>
__global__ void arg_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
at::cuda::detail::TensorInfo<int64_t, int64_t> index,
at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
at::cuda::detail::TensorInfo<int64_t, int64_t> arg,
int64_t dim, size_t numel) {
const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = blockDim.x * gridDim.x;
for (ptrdiff_t i = idx; i < numel; i += stride) {
int64_t srcOffset = 0, indexOffset = 0, outOffset = 0, argOffset = 0;
IndexToScatterOffsets4<scalar_t, scalar_t, int64_t, Dims>::compute(
i, dim, index, &indexOffset, src, &srcOffset, out, &outOffset, arg,
&argOffset);
if (src.data[srcOffset] == out.data[outOffset]) {
arg.data[argOffset] = (srcOffset / src.strides[dim]) % src.sizes[dim];
}
}
}
template <typename scalar_t, int64_t Dims>
__global__ void
scatter_max_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
at::cuda::detail::TensorInfo<int64_t, int64_t> index,
at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
int64_t dim, size_t numel) {
const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = blockDim.x * gridDim.x;
for (ptrdiff_t i = idx; i < numel; i += stride) {
int64_t srcOffset = 0, indexOffset = 0, outOffset = 0;
IndexToScatterOffsets3<scalar_t, scalar_t, Dims>::compute(
i, dim, index, &indexOffset, src, &srcOffset, out, &outOffset);
atomMax(&out.data[outOffset], src.data[srcOffset]);
}
}
void scatter_max_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out,
torch::Tensor arg, int64_t dim) {
cudaSetDevice(src.get_device());
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_max_kernel", [&] {
auto src_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src);
auto index_info = at::cuda::detail::getTensorInfo<int64_t, int64_t>(index);
auto out_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(out);
KERNEL_RUN(scatter_max_kernel, index.dim(), index.numel(), src_info,
index_info, out_info, dim);
KERNEL_RUN(arg_kernel, index.dim(), index.numel(), src_info, index_info,
out_info, at::cuda::detail::getTensorInfo<int64_t, int64_t>(arg),
dim);
});
}
template <typename scalar_t, int64_t Dims>
__global__ void
scatter_min_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
at::cuda::detail::TensorInfo<int64_t, int64_t> index,
at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
int64_t dim, size_t numel) {
const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = blockDim.x * gridDim.x;
for (ptrdiff_t i = idx; i < numel; i += stride) {
int64_t srcOffset = 0, indexOffset = 0, outOffset = 0;
IndexToScatterOffsets3<scalar_t, scalar_t, Dims>::compute(
i, dim, index, &indexOffset, src, &srcOffset, out, &outOffset);
atomMin(&out.data[outOffset], src.data[srcOffset]);
}
}
void scatter_min_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out,
torch::Tensor arg, int64_t dim) {
cudaSetDevice(src.get_device());
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_min_kernel", [&] {
auto src_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src);
auto index_info = at::cuda::detail::getTensorInfo<int64_t, int64_t>(index);
auto out_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(out);
KERNEL_RUN(scatter_min_kernel, index.dim(), index.numel(), src_info,
index_info, out_info, dim);
KERNEL_RUN(arg_kernel, index.dim(), index.numel(), src_info, index_info,
out_info, at::cuda::detail::getTensorInfo<int64_t, int64_t>(arg),
dim);
});
}
#include <torch/script.h>
#define CHECK_CUDA(x) \
AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_csr_cuda(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> out_opt, std::string reduce);
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_coo_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out,
std::string reduce);
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> out_opt, std::string reduce) {
CHECK_CUDA(src);
CHECK_CUDA(indptr);
if (out_opt.has_value())
CHECK_CUDA(out_opt.value());
return segment_csr_cuda(src, indptr, out_opt, reduce);
}
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_coo(torch::Tensor src, torch::Tensor index, torch::Tensor out,
std::string reduce) {
CHECK_CUDA(src);
CHECK_CUDA(index);
CHECK_CUDA(out);
return segment_coo_cuda(src, index, out, reduce);
}
static auto registry =
torch::RegisterOperators("torch_scatter_cuda::segment_csr", &segment_csr)
.op("torch_scatter_cuda::segment_coo", &segment_coo);
@@ -3,3 +3,4 @@ numpy
torch_nightly
sphinx
sphinx_rtd_theme
sphinx-autodoc-typehints
Scatter Softmax
===============
.. automodule:: torch_scatter.composite
:noindex:
.. autofunction:: scatter_softmax
.. autofunction:: scatter_log_softmax
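A minimal usage sketch (assuming the ``scatter_softmax(src, index, dim)``
signature documented above; the softmax is taken group-wise over all entries
that share an index):

.. code-block:: python

    import torch
    from torch_scatter.composite import scatter_softmax

    src = torch.tensor([0.5, 1.0, 2.0, 3.0])
    index = torch.tensor([0, 0, 1, 1])

    # Softmax is computed independently within each index group, so the
    # entries of group 0 and of group 1 each sum to one.
    out = scatter_softmax(src, index, dim=0)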