Commit 0c887ffc authored by rusty1s

segment/gather csr done

parent 26a9e988
-Copyright (c) 2019 Matthias Fey <matthias.fey@tu-dortmund.de>
+Copyright (c) 2020 Matthias Fey <matthias.fey@tu-dortmund.de>
 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal
......
#pragma once
#include <torch/extension.h>
#define MAX_TENSORINFO_DIMS 25
template <typename scalar_t> struct TensorInfo {
TensorInfo(scalar_t *p, int dim, int sz[MAX_TENSORINFO_DIMS],
int st[MAX_TENSORINFO_DIMS]) {
data = p;
dims = dim;
AT_ASSERT(dims < MAX_TENSORINFO_DIMS);
for (int i = 0; i < dim; ++i) {
sizes[i] = sz[i];
strides[i] = st[i];
}
}
scalar_t *data;
int dims;
int sizes[MAX_TENSORINFO_DIMS];
int strides[MAX_TENSORINFO_DIMS];
};
template <typename scalar_t>
TensorInfo<scalar_t> getTensorInfo(const torch::Tensor &tensor) {
int sizes[MAX_TENSORINFO_DIMS];
int strides[MAX_TENSORINFO_DIMS];
int dims = tensor.dim();
for (int i = 0; i < dims; ++i) {
sizes[i] = tensor.size(i);
strides[i] = tensor.stride(i);
}
return TensorInfo<scalar_t>(tensor.data_ptr<scalar_t>(), dims, sizes,
strides);
}
template <typename scalar_t> struct IndexToOffset {
static inline int get(int idx, const TensorInfo<scalar_t> &info) {
int offset = 0;
for (int i = info.dims - 1; i >= 0; --i) {
offset += (idx % info.sizes[i]) * info.strides[i];
idx /= info.sizes[i];
}
return offset;
}
};
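// `IndexPtrToOffset` is a variant of `IndexToOffset` for index pointer tensors:
// the innermost dimension of `indptr` holds `size - 1` rows, since its last
// element is only ever read as a row end (see also the CUDA version below).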
template <typename scalar_t> struct IndexPtrToOffset {
static inline int get(int idx, const TensorInfo<scalar_t> &info) {
int offset = idx % (info.sizes[info.dims - 1] - 1);
offset *= info.strides[info.dims - 1];
idx /= info.sizes[info.dims - 1] - 1;
for (int i = info.dims - 2; i >= 0; --i) {
offset += (idx % info.sizes[i]) * info.strides[i];
idx /= info.sizes[i];
}
return offset;
}
};
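For orientation, a small Python mirror of the `IndexPtrToOffset` loop above (illustration only, not part of the extension):

def index_ptr_to_offset(idx, sizes, strides):
    # The innermost dimension contributes sizes[-1] - 1 rows, because the last
    # pointer of each indptr vector only ever serves as a row end.
    offset = (idx % (sizes[-1] - 1)) * strides[-1]
    idx //= sizes[-1] - 1
    for size, stride in zip(reversed(sizes[:-1]), reversed(strides[:-1])):
        offset += (idx % size) * stride
        idx //= size
    return offset

# A contiguous indptr of shape (2, 4) has strides (4, 1) and 3 rows per batch;
# flat row 4 is batch 1, row 1, whose start pointer sits at flat element 5.
assert index_ptr_to_offset(4, sizes=[2, 4], strides=[4, 1]) == 5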
#pragma once
#include <limits>
#include <map>
enum ReductionType { SUM, MEAN, MUL, DIV, MIN, MAX };
const std::map<std::string, ReductionType> reduce2REDUCE = {
{"sum", SUM}, {"mean", MEAN}, {"mul", MUL},
{"div", DIV}, {"min", MIN}, {"max", MAX},
};
#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...) \
[&] { \
switch (reduce2REDUCE.at(reduce)) { \
case SUM: { \
const ReductionType REDUCE = SUM; \
return __VA_ARGS__(); \
} \
case MEAN: { \
const ReductionType REDUCE = MEAN; \
return __VA_ARGS__(); \
} \
case MUL: { \
const ReductionType REDUCE = MUL; \
return __VA_ARGS__(); \
} \
case DIV: { \
const ReductionType REDUCE = DIV; \
return __VA_ARGS__(); \
} \
case MIN: { \
const ReductionType REDUCE = MIN; \
return __VA_ARGS__(); \
} \
case MAX: { \
const ReductionType REDUCE = MAX; \
return __VA_ARGS__(); \
} \
} \
}()
template <typename scalar_t, ReductionType REDUCE> struct Reducer {
static inline scalar_t init() {
if (REDUCE == MUL || REDUCE == DIV)
return (scalar_t)1;
else if (REDUCE == MIN)
return std::numeric_limits<scalar_t>::max();
else if (REDUCE == MAX)
return std::numeric_limits<scalar_t>::lowest();
else
return (scalar_t)0;
}
static inline void update(scalar_t *val, scalar_t new_val, int64_t *arg,
int64_t new_arg) {
if (REDUCE == SUM || REDUCE == MEAN)
*val = *val + new_val;
else if (REDUCE == MUL)
*val = *val * new_val;
else if (REDUCE == DIV)
*val = *val / new_val;
else if ((REDUCE == MIN && new_val < *val) ||
(REDUCE == MAX && new_val > *val)) {
*val = new_val;
*arg = new_arg;
}
}
static inline void write(scalar_t *address, scalar_t val,
int64_t *arg_address, int64_t arg, int count) {
if (REDUCE == SUM || REDUCE == MUL || REDUCE == DIV)
*address = val;
else if (REDUCE == MEAN)
*address = val / (count > 0 ? count : (scalar_t)1);
else if (REDUCE == MIN || REDUCE == MAX) {
if (count > 0) {
*address = val;
*arg_address = arg;
} else
*address = (scalar_t)0;
}
}
};
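As a quick reference, a rough Python analogue (illustration only; 'div' behaves like 'mul' with division and is omitted) of how the Reducer's init/update/write steps combine for a single segment:

import math

def reduce_segment(values, reduce):
    # Returns (reduced value, winning index or None), mimicking Reducer above.
    if reduce in ('min', 'max'):
        if not values:                   # empty segment: write 0, arg keeps its fill value
            return 0, None
        pick = min if reduce == 'min' else max
        best = pick(values)
        return best, values.index(best)
    if reduce == 'mul':
        return math.prod(values), None   # init() is 1 for 'mul'
    total = sum(values)                  # 'sum' and 'mean' accumulate identically
    return (total / (len(values) or 1)) if reduce == 'mean' else total, None

assert reduce_segment([3.0, 1.0, 2.0], 'min') == (1.0, 1)
assert reduce_segment([], 'max') == (0, None)
assert reduce_segment([1.0, 2.0, 3.0], 'mean') == (2.0, None)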
#include "segment_csr_cpu.h"
#include "index_info.h"
#include "reducer.h"
#include "utils.h"
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_csr_cpu(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out,
std::string reduce) {
CHECK_CPU(src);
CHECK_CPU(indptr);
if (optional_out.has_value())
CHECK_CPU(optional_out.value());
CHECK_INPUT(src.dim() >= indptr.dim());
auto sizes = indptr.sizes().vec();
for (auto i = 0; i < indptr.dim() - 1; i++)
sizes[i] = src.size(i);
indptr = indptr.expand(sizes);
auto dim = indptr.dim() - 1;
src = src.contiguous();
torch::Tensor out;
if (optional_out.has_value()) {
out = optional_out.value().contiguous();
for (int i = 0; i < out.dim(); i++)
if (i != dim)
CHECK_INPUT(src.size(i) == out.size(i));
CHECK_INPUT(out.size(dim) == indptr.size(dim) - 1);
} else {
sizes = src.sizes().vec();
sizes[dim] = indptr.size(dim) - 1;
out = torch::empty(sizes, src.options());
}
torch::optional<torch::Tensor> arg_out = torch::nullopt;
int64_t *arg_out_data = nullptr;
if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
arg_out = torch::full(out.sizes(), src.size(dim), indptr.options());
arg_out_data = arg_out.value().data_ptr<int64_t>();
}
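// N: total number of segments (rows of `indptr` over all leading dimensions),
// K: number of inner elements per segment that are reduced independently
//    (the dimensions trailing `dim`), E: length of `src` along `dim`.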
auto N = out.size(dim) * (indptr.numel() / indptr.size(-1));
auto K = out.numel() / N;
auto E = src.size(dim);
auto indptr_info = getTensorInfo<int64_t>(indptr);
auto stride = indptr_info.strides[indptr_info.dims - 1];
std::vector<int64_t> args(K);
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_csr", [&] {
auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
std::vector<scalar_t> vals(K);
int64_t row_start, row_end;
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
for (auto n = 0; n < N; n++) {
auto offset = IndexPtrToOffset<int64_t>::get(n, indptr_info);
row_start = indptr_info.data[offset];
row_end = indptr_info.data[offset + stride];
offset = (n / (indptr.size(-1) - 1)) * E * K;
for (auto k = 0; k < K; k++)
vals[k] = Reducer<scalar_t, REDUCE>::init();
for (auto e = row_start; e < row_end; e++) {
CHECK_INPUT(e < E);
for (auto k = 0; k < K; k++)
Reducer<scalar_t, REDUCE>::update(
&vals[k], src_data[offset + e * K + k], &args[k], e);
}
for (auto k = 0; k < K; k++)
Reducer<scalar_t, REDUCE>::write(out_data + n * K + k, vals[k],
arg_out_data + n * K + k, args[k],
row_end - row_start);
}
});
});
return std::make_tuple(out, arg_out);
}
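A tiny pure-Python reference for the reduction loop above (1-D case, contiguous `indptr`); purely illustrative, handy for sanity-checking the semantics:

def segment_csr_ref(src, indptr, reduce='sum'):
    out = []
    for row_start, row_end in zip(indptr[:-1], indptr[1:]):
        seg = src[row_start:row_end]
        if reduce == 'sum':
            out.append(sum(seg))
        elif reduce == 'mean':
            out.append(sum(seg) / len(seg) if seg else 0)
        elif reduce == 'min':
            out.append(min(seg) if seg else 0)  # empty segments reduce to 0
        elif reduce == 'max':
            out.append(max(seg) if seg else 0)
    return out

assert segment_csr_ref([1, 2, 3, 4, 5, 6], [0, 2, 5, 6]) == [3, 12, 6]
assert segment_csr_ref([1, 2, 3, 4, 5, 6], [0, 2, 5, 6], 'max') == [2, 5, 6]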
torch::Tensor gather_csr_cpu(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out) {
CHECK_CPU(src);
CHECK_CPU(indptr);
if (optional_out.has_value())
CHECK_CPU(optional_out.value());
CHECK_INPUT(src.dim() >= indptr.dim());
auto sizes = indptr.sizes().vec();
for (auto i = 0; i < indptr.dim() - 1; i++)
sizes[i] = src.size(i);
indptr = indptr.expand(sizes);
auto dim = indptr.dim() - 1;
CHECK_INPUT(src.size(dim) == indptr.size(dim) - 1);
src = src.contiguous();
torch::Tensor out;
if (optional_out.has_value()) {
out = optional_out.value().contiguous();
for (auto i = 0; i < out.dim(); i++)
if (i != dim)
CHECK_INPUT(src.size(i) == out.size(i));
} else {
auto sizes = src.sizes().vec();
sizes[dim] = *indptr.flatten()[-1].data_ptr<int64_t>();
out = torch::empty(sizes, src.options());
}
auto N = src.size(dim) * (indptr.numel() / indptr.size(-1));
auto K = src.numel() / N;
auto E = out.size(dim);
auto indptr_info = getTensorInfo<int64_t>(indptr);
auto stride = indptr_info.strides[indptr_info.dims - 1];
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_csr", [&] {
auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
std::vector<scalar_t> vals(K);
int64_t row_start, row_end;
for (int n = 0; n < N; n++) {
auto offset = IndexPtrToOffset<int64_t>::get(n, indptr_info);
row_start = indptr_info.data[offset];
row_end = indptr_info.data[offset + stride];
for (auto k = 0; k < K; k++)
vals[k] = src_data[n * K + k];
offset = (n / (indptr.size(-1) - 1)) * E * K;
for (auto e = row_start; e < row_end; e++)
for (auto k = 0; k < K; k++)
out_data[offset + e * K + k] = vals[k];
}
});
return out;
}
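And the matching reference for gather_csr: element i of src is repeated indptr[i + 1] - indptr[i] times, i.e. the broadcasting inverse of segment_csr (again 1-D and illustrative only):

def gather_csr_ref(src, indptr):
    out = []
    for val, row_start, row_end in zip(src, indptr[:-1], indptr[1:]):
        out.extend([val] * (row_end - row_start))
    return out

assert gather_csr_ref([1, 2, 3], [0, 2, 5, 6]) == [1, 1, 2, 2, 2, 3]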
#pragma once
#include <torch/extension.h>
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_csr_cpu(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out,
std::string reduce);
torch::Tensor gather_csr_cpu(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out);
#pragma once
#include <torch/extension.h>
#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
#pragma once
#define ATOMIC(NAME) \
template <typename scalar, size_t size> struct Atomic##NAME##IntegerImpl; \
\
template <typename scalar> struct Atomic##NAME##IntegerImpl<scalar, 1> { \
inline __device__ void operator()(scalar *address, scalar val) { \
uint32_t *address_as_ui = (uint32_t *)(address - ((size_t)address & 3)); \
uint32_t old = *address_as_ui; \
uint32_t shift = ((size_t)address & 3) * 8; \
uint32_t sum; \
uint32_t assumed; \
\
do { \
assumed = old; \
sum = OP(val, scalar((old >> shift) & 0xff)); \
old = (old & ~(0x000000ff << shift)) | (sum << shift); \
old = atomicCAS(address_as_ui, assumed, old); \
} while (assumed != old); \
} \
}; \
\
template <typename scalar> struct Atomic##NAME##IntegerImpl<scalar, 2> { \
inline __device__ void operator()(scalar *address, scalar val) { \
uint32_t *address_as_ui = \
(uint32_t *)((char *)address - ((size_t)address & 2)); \
uint32_t old = *address_as_ui; \
uint32_t sum; \
uint32_t newval; \
uint32_t assumed; \
\
do { \
assumed = old; \
sum = OP(val, (size_t)address & 2 ? scalar(old >> 16) \
: scalar(old & 0xffff)); \
newval = (size_t)address & 2 ? (old & 0xffff) | (sum << 16) \
: (old & 0xffff0000) | sum; \
old = atomicCAS(address_as_ui, assumed, newval); \
} while (assumed != old); \
} \
}; \
\
template <typename scalar> struct Atomic##NAME##IntegerImpl<scalar, 4> { \
inline __device__ void operator()(scalar *address, scalar val) { \
uint32_t *address_as_ui = (uint32_t *)address; \
uint32_t old = *address_as_ui; \
uint32_t assumed; \
\
do { \
assumed = old; \
old = atomicCAS(address_as_ui, assumed, OP(val, (scalar)old)); \
} while (assumed != old); \
} \
}; \
\
template <typename scalar> struct Atomic##NAME##IntegerImpl<scalar, 8> { \
inline __device__ void operator()(scalar *address, scalar val) { \
unsigned long long *address_as_ull = (unsigned long long *)address; \
unsigned long long old = *address_as_ull; \
unsigned long long assumed; \
\
do { \
assumed = old; \
old = atomicCAS(address_as_ull, assumed, OP(val, (scalar)old)); \
} while (assumed != old); \
} \
}; \
\
template <typename scalar, size_t size> struct Atomic##NAME##DecimalImpl; \
\
template <typename scalar> struct Atomic##NAME##DecimalImpl<scalar, 4> { \
inline __device__ void operator()(scalar *address, scalar val) { \
int *address_as_i = (int *)address; \
int old = *address_as_i; \
int assumed; \
\
do { \
assumed = old; \
old = atomicCAS(address_as_i, assumed, \
__float_as_int(OP(val, __int_as_float(assumed)))); \
} while (assumed != old); \
} \
}; \
\
template <typename scalar> struct Atomic##NAME##DecimalImpl<scalar, 8> { \
inline __device__ void operator()(scalar *address, scalar val) { \
unsigned long long int *address_as_ull = \
(unsigned long long int *)address; \
unsigned long long int old = *address_as_ull; \
unsigned long long int assumed; \
\
do { \
assumed = old; \
old = atomicCAS( \
address_as_ull, assumed, \
__double_as_longlong(OP(val, __longlong_as_double(assumed)))); \
} while (assumed != old); \
} \
};
#define OP(X, Y) Y + X
ATOMIC(Add)
#undef OP
static inline __device__ void atomAdd(uint8_t *address, uint8_t val) {
AtomicAddIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val);
}
static inline __device__ void atomAdd(int8_t *address, int8_t val) {
AtomicAddIntegerImpl<int8_t, sizeof(int8_t)>()(address, val);
}
static inline __device__ void atomAdd(int16_t *address, int16_t val) {
AtomicAddIntegerImpl<int16_t, sizeof(int16_t)>()(address, val);
}
static inline __device__ void atomAdd(int32_t *address, int32_t val) {
atomicAdd(address, val);
}
static inline __device__ void atomAdd(int64_t *address, int64_t val) {
AtomicAddIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
}
static inline __device__ void atomAdd(float *address, float val) {
atomicAdd(address, val);
}
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000)
static inline __device__ void atomAdd(double *address, double val) {
AtomicAddDecimalImpl<double, sizeof(double)>()(address, val);
}
#else
static inline __device__ void atomAdd(double *address, double val) {
atomicAdd(address, val);
}
#endif
#define OP(X, Y) Y *X
ATOMIC(Mul)
#undef OP
static inline __device__ void atomMul(uint8_t *address, uint8_t val) {
AtomicMulIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val);
}
static inline __device__ void atomMul(int8_t *address, int8_t val) {
AtomicMulIntegerImpl<int8_t, sizeof(int8_t)>()(address, val);
}
static inline __device__ void atomMul(int16_t *address, int16_t val) {
AtomicMulIntegerImpl<int16_t, sizeof(int16_t)>()(address, val);
}
static inline __device__ void atomMul(int32_t *address, int32_t val) {
AtomicMulIntegerImpl<int32_t, sizeof(int32_t)>()(address, val);
}
static inline __device__ void atomMul(int64_t *address, int64_t val) {
AtomicMulIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
}
static inline __device__ void atomMul(float *address, float val) {
AtomicMulDecimalImpl<float, sizeof(float)>()(address, val);
}
static inline __device__ void atomMul(double *address, double val) {
AtomicMulDecimalImpl<double, sizeof(double)>()(address, val);
}
#define OP(X, Y) Y / X
ATOMIC(Div)
#undef OP
static inline __device__ void atomDiv(uint8_t *address, uint8_t val) {
AtomicDivIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val);
}
static inline __device__ void atomDiv(int8_t *address, int8_t val) {
AtomicDivIntegerImpl<int8_t, sizeof(int8_t)>()(address, val);
}
static inline __device__ void atomDiv(int16_t *address, int16_t val) {
AtomicDivIntegerImpl<int16_t, sizeof(int16_t)>()(address, val);
}
static inline __device__ void atomDiv(int32_t *address, int32_t val) {
AtomicDivIntegerImpl<int32_t, sizeof(int32_t)>()(address, val);
}
static inline __device__ void atomDiv(int64_t *address, int64_t val) {
AtomicDivIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
}
static inline __device__ void atomDiv(float *address, float val) {
AtomicDivDecimalImpl<float, sizeof(float)>()(address, val);
}
static inline __device__ void atomDiv(double *address, double val) {
AtomicDivDecimalImpl<double, sizeof(double)>()(address, val);
}
#define OP(X, Y) max(Y, X)
ATOMIC(Max)
#undef OP
static inline __device__ void atomMax(uint8_t *address, uint8_t val) {
AtomicMaxIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val);
}
static inline __device__ void atomMax(int8_t *address, int8_t val) {
AtomicMaxIntegerImpl<int8_t, sizeof(int8_t)>()(address, val);
}
static inline __device__ void atomMax(int16_t *address, int16_t val) {
AtomicMaxIntegerImpl<int16_t, sizeof(int16_t)>()(address, val);
}
static inline __device__ void atomMax(int32_t *address, int32_t val) {
atomicMax(address, val);
}
static inline __device__ void atomMax(int64_t *address, int64_t val) {
AtomicMaxIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
}
static inline __device__ void atomMax(float *address, float val) {
AtomicMaxDecimalImpl<float, sizeof(float)>()(address, val);
}
static inline __device__ void atomMax(double *address, double val) {
AtomicMaxDecimalImpl<double, sizeof(double)>()(address, val);
}
#define OP(X, Y) min(Y, X)
ATOMIC(Min)
#undef OP
static inline __device__ void atomMin(uint8_t *address, uint8_t val) {
AtomicMinIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val);
}
static inline __device__ void atomMin(int8_t *address, int8_t val) {
AtomicMinIntegerImpl<int8_t, sizeof(int8_t)>()(address, val);
}
static inline __device__ void atomMin(int16_t *address, int16_t val) {
AtomicMinIntegerImpl<int16_t, sizeof(int16_t)>()(address, val);
}
static inline __device__ void atomMin(int32_t *address, int32_t val) {
atomicMin(address, val);
}
static inline __device__ void atomMin(int64_t *address, int64_t val) {
AtomicMinIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
}
static inline __device__ void atomMin(float *address, float val) {
AtomicMinDecimalImpl<float, sizeof(float)>()(address, val);
}
static inline __device__ void atomMin(double *address, double val) {
AtomicMinDecimalImpl<double, sizeof(double)>()(address, val);
}
#pragma once
#include <ATen/cuda/detail/TensorInfo.cuh>
// We need our own `IndexPtrToOffset` implementation since we do not want to
// access the last element of `indptr`.
template <typename scalar_t> struct IndexPtrToOffset {
static inline __host__ __device__ int
get(int idx, const at::cuda::detail::TensorInfo<scalar_t, int> &info) {
int offset = idx % (info.sizes[info.dims - 1] - 1);
offset *= info.strides[info.dims - 1];
idx /= info.sizes[info.dims - 1] - 1;
for (int i = info.dims - 2; i >= 0; --i) {
offset += (idx % info.sizes[i]) * info.strides[i];
idx /= info.sizes[i];
}
return offset;
}
};
#pragma once
#include <limits>
#include <map>
#include "atomics.cuh"
enum ReductionType { SUM, MEAN, MUL, DIV, MIN, MAX };
const std::map<std::string, ReductionType> reduce2REDUCE = {
{"sum", SUM}, {"mean", MEAN}, {"mul", MUL},
{"div", DIV}, {"min", MIN}, {"max", MAX},
};
#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...) \
[&] { \
switch (reduce2REDUCE.at(reduce)) { \
case SUM: { \
const ReductionType REDUCE = SUM; \
return __VA_ARGS__(); \
} \
case MEAN: { \
const ReductionType REDUCE = MEAN; \
return __VA_ARGS__(); \
} \
case MUL: { \
const ReductionType REDUCE = MUL; \
return __VA_ARGS__(); \
} \
case DIV: { \
const ReductionType REDUCE = DIV; \
return __VA_ARGS__(); \
} \
case MIN: { \
const ReductionType REDUCE = MIN; \
return __VA_ARGS__(); \
} \
case MAX: { \
const ReductionType REDUCE = MAX; \
return __VA_ARGS__(); \
} \
} \
}()
template <typename scalar_t, ReductionType REDUCE> struct Reducer {
static inline __host__ __device__ scalar_t init() {
if (REDUCE == MUL || REDUCE == DIV)
return (scalar_t)1;
else if (REDUCE == MIN)
return std::numeric_limits<scalar_t>::max();
else if (REDUCE == MAX)
return std::numeric_limits<scalar_t>::lowest();
else
return (scalar_t)0;
}
static inline __host__ __device__ void update(scalar_t *val,
scalar_t new_val) {
if (REDUCE == SUM || REDUCE == MEAN)
*val = *val + new_val;
else if (REDUCE == MUL)
*val = *val * new_val;
else if (REDUCE == DIV)
*val = *val / new_val;
else if ((REDUCE == MIN && new_val < *val) ||
(REDUCE == MAX && new_val > *val)) {
*val = new_val;
}
}
static inline __host__ __device__ void update(scalar_t *val, scalar_t new_val,
int64_t *arg, int64_t new_arg) {
if (REDUCE == SUM || REDUCE == MEAN)
*val = *val + new_val;
else if (REDUCE == MUL)
*val = *val * new_val;
else if (REDUCE == DIV)
*val = *val / new_val;
else if ((REDUCE == MIN && new_val < *val) ||
(REDUCE == MAX && new_val > *val)) {
*val = new_val;
*arg = new_arg;
}
}
static inline __host__ __device__ void write(scalar_t *address, scalar_t val,
int64_t *arg_address,
int64_t arg, int count) {
if (REDUCE == SUM || REDUCE == MUL || REDUCE == DIV)
*address = val;
else if (REDUCE == MEAN)
*address = val / (count > 0 ? count : (scalar_t)1);
else if (REDUCE == MIN || REDUCE == MAX) {
if (count > 0) {
*address = val;
*arg_address = arg;
} else
*address = (scalar_t)0;
}
}
static inline __device__ void atomic_write(scalar_t *address, scalar_t val) {
if (REDUCE == SUM || REDUCE == MEAN)
atomAdd(address, val);
else if (REDUCE == MUL)
atomMul(address, val);
else if (REDUCE == DIV)
atomDiv(address, val);
else if (REDUCE == MIN && val < *address)
atomMin(address, val);
else if (REDUCE == MAX && val > *address)
atomMax(address, val);
}
};
#include "segment_csr_cuda.h"
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include "index_info.cuh"
#include "reducer.cuh"
#include "utils.cuh"
#define THREADS 256
#define BLOCKS(TB, N) (TB * N + THREADS - 1) / THREADS
#define FULL_MASK 0xffffffff
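The launch arithmetic below is a plain ceiling division: BLOCKS(TB, N) reserves TB threads for each of the N rows, packed into 256-thread blocks. A quick Python check (illustration only):

THREADS = 256

def blocks(tb, n):                # mirrors BLOCKS(TB, N)
    return (tb * n + THREADS - 1) // THREADS

assert blocks(32, 100) == 13      # 3200 threads fit into 13 blocks of 256
assert blocks(1, 100) == 1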
template <typename scalar_t, ReductionType REDUCE, int TB>
__global__ void
segment_csr_kernel(const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
scalar_t *out_data, int64_t *arg_out_data, size_t N,
size_t E) {
// Each warp processes exactly `32/TB` rows and aggregates all row values
// via a parallel reduction.
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
int row_idx = thread_idx / TB;
int lane_idx = thread_idx & (TB - 1);
if (row_idx < N) {
int offset = IndexPtrToOffset<int64_t>::get(row_idx, indptr_info);
int64_t row_start = __ldg(indptr_info.data + offset);
int64_t row_end = __ldg(indptr_info.data + offset +
indptr_info.strides[indptr_info.dims - 1]);
scalar_t val = Reducer<scalar_t, REDUCE>::init();
int64_t arg, arg_tmp;
offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E;
for (int64_t src_idx = row_start + lane_idx; src_idx < row_end;
src_idx += TB) {
Reducer<scalar_t, REDUCE>::update(&val, src_data[offset + src_idx], &arg,
src_idx);
}
#pragma unroll
for (int i = TB / 2; i > 0; i /= 2) {
// Parallel reduction inside a single warp.
if (REDUCE == MIN || REDUCE == MAX)
arg_tmp = __shfl_down_sync(FULL_MASK, arg, i);
Reducer<scalar_t, REDUCE>::update(
&val, __shfl_down_sync(FULL_MASK, val, i), &arg, arg_tmp);
}
if (lane_idx == 0) {
Reducer<scalar_t, REDUCE>::write(out_data + row_idx, val,
arg_out_data + row_idx, arg,
row_end - row_start);
}
}
}
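For readability, a plain-Python stand-in (no CUDA semantics, illustration only) for the value that lane 0 of a TB-lane group ends up holding after the __shfl_down_sync loop above:

def tree_reduce(vals, op):
    tb = len(vals)                  # TB lanes cooperating on one row
    vals = list(vals)
    offset = tb // 2
    while offset > 0:               # log2(TB) halving steps, as in the unrolled loop
        for lane in range(tb):
            if lane + offset < tb:  # partners outside the group never reach lane 0
                vals[lane] = op(vals[lane], vals[lane + offset])
        offset //= 2
    return vals[0]                  # only lane 0's result is written back

assert tree_reduce([3, 1, 4, 1, 5, 9, 2, 6], max) == 9
assert tree_reduce([3, 1, 4, 1], lambda a, b: a + b) == 9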
template <typename scalar_t, ReductionType REDUCE>
__global__ void segment_csr_broadcast_kernel(
const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
scalar_t *out_data, int64_t *arg_out_data, size_t N, size_t K, size_t E) {
// Each thread processes exactly one row. It turned out that this is more
// efficient than using shared memory due to avoiding synchronization
// barriers.
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
int row_idx = thread_idx / K;
int lane_idx = thread_idx % K;
if (thread_idx < N * K) {
int offset = IndexPtrToOffset<int64_t>::get(row_idx, indptr_info);
int64_t row_start = __ldg(indptr_info.data + offset);
int64_t row_end = __ldg(indptr_info.data + offset +
indptr_info.strides[indptr_info.dims - 1]);
scalar_t val = Reducer<scalar_t, REDUCE>::init();
int64_t arg;
offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E * K;
for (int64_t src_idx = row_start; src_idx < row_end; src_idx++) {
Reducer<scalar_t, REDUCE>::update(
&val, src_data[offset + K * src_idx + lane_idx], &arg, src_idx);
}
Reducer<scalar_t, REDUCE>::write(out_data + thread_idx, val,
arg_out_data + thread_idx, arg,
row_end - row_start);
}
}
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_csr_cuda(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out,
std::string reduce) {
CHECK_CUDA(src);
CHECK_CUDA(indptr);
if (optional_out.has_value())
CHECK_CUDA(optional_out.value());
cudaSetDevice(src.get_device());
CHECK_INPUT(src.dim() >= indptr.dim());
auto sizes = indptr.sizes().vec();
for (auto i = 0; i < indptr.dim() - 1; i++)
sizes[i] = src.size(i);
indptr = indptr.expand(sizes);
auto dim = indptr.dim() - 1;
src = src.contiguous();
torch::Tensor out;
if (optional_out.has_value()) {
out = optional_out.value().contiguous();
for (int i = 0; i < out.dim(); i++)
if (i != dim)
CHECK_INPUT(src.size(i) == out.size(i));
CHECK_INPUT(out.size(dim) == indptr.size(dim) - 1);
} else {
sizes = src.sizes().vec();
sizes[dim] = indptr.size(dim) - 1;
out = torch::empty(sizes, src.options());
}
torch::optional<torch::Tensor> arg_out = torch::nullopt;
int64_t *arg_out_data = nullptr;
if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
arg_out = torch::full(out.sizes(), src.size(dim), indptr.options());
arg_out_data = arg_out.value().data_ptr<int64_t>();
}
auto N = out.size(dim) * (indptr.numel() / indptr.size(-1));
auto K = out.numel() / N;
auto E = src.size(dim);
auto indptr_info = at::cuda::detail::getTensorInfo<int64_t, int>(indptr);
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_csr_kernel", [&] {
auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
if (K == 1) {
segment_csr_kernel<scalar_t, REDUCE, 1>
<<<BLOCKS(32, N), THREADS, 0, stream>>>(
src_data, indptr_info, out_data, arg_out_data, N, E);
} else {
segment_csr_broadcast_kernel<scalar_t, REDUCE>
<<<BLOCKS(1, N * K), THREADS, 0, stream>>>(
src_data, indptr_info, out_data, arg_out_data, N, K, E);
}
});
});
return std::make_tuple(out, arg_out);
}
template <typename scalar_t, int TB>
__global__ void
gather_csr_kernel(const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
scalar_t *out_data, size_t N, size_t E) {
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
int row_idx = thread_idx / TB;
int lane_idx = thread_idx % TB;
if (row_idx < N) {
int offset = IndexPtrToOffset<int64_t>::get(row_idx, indptr_info);
int row_start = __ldg(indptr_info.data + offset);
int row_end = __ldg(indptr_info.data + offset +
indptr_info.strides[indptr_info.dims - 1]);
scalar_t val = __ldg(src_data + row_idx);
offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E;
for (int out_idx = row_start + lane_idx; out_idx < row_end; out_idx += TB) {
out_data[offset + out_idx] = val; // "Mostly" coalesced.
}
}
}
template <typename scalar_t>
__global__ void gather_csr_broadcast_kernel(
const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
scalar_t *out_data, size_t N, size_t K, size_t E) {
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
int row_idx = thread_idx / K;
int lane_idx = thread_idx % K;
if (thread_idx < N * K) {
int offset = IndexPtrToOffset<int64_t>::get(row_idx, indptr_info);
int row_start = __ldg(indptr_info.data + offset);
int row_end = __ldg(indptr_info.data + offset +
indptr_info.strides[indptr_info.dims - 1]);
scalar_t val = src_data[thread_idx]; // Coalesced.
offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E * K;
for (int out_idx = row_start; out_idx < row_end; out_idx++) {
out_data[offset + K * out_idx + lane_idx] = val; // "Mostly" coalesced.
}
}
}
torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out) {
CHECK_CUDA(src);
CHECK_CUDA(indptr);
if (optional_out.has_value())
CHECK_CUDA(optional_out.value());
cudaSetDevice(src.get_device());
CHECK_INPUT(src.dim() >= indptr.dim());
auto sizes = indptr.sizes().vec();
for (auto i = 0; i < indptr.dim() - 1; i++)
sizes[i] = src.size(i);
indptr = indptr.expand(sizes);
auto dim = indptr.dim() - 1;
CHECK_INPUT(src.size(dim) == indptr.size(dim) - 1);
src = src.contiguous();
torch::Tensor out;
if (optional_out.has_value()) {
out = optional_out.value().contiguous();
for (auto i = 0; i < out.dim(); i++)
if (i != dim)
CHECK_INPUT(src.size(i) == out.size(i));
} else {
auto d_gather_size = indptr.flatten()[-1].data_ptr<int64_t>();
auto h_gather_size = (int64_t *)malloc(sizeof(int64_t));
cudaMemcpy(h_gather_size, d_gather_size, sizeof(int64_t),
cudaMemcpyDeviceToHost);
auto sizes = src.sizes().vec();
sizes[dim] = *h_gather_size;
out = torch::empty(sizes, src.options());
}
auto N = src.size(dim) * (indptr.numel() / indptr.size(-1));
auto K = src.numel() / N;
auto E = out.size(dim);
auto indptr_info = at::cuda::detail::getTensorInfo<int64_t, int>(indptr);
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_csr_kernel", [&] {
auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
if (K == 1)
gather_csr_kernel<scalar_t, 4><<<BLOCKS(1, 4 * N), THREADS, 0, stream>>>(
src_data, indptr_info, out_data, N, E);
else
gather_csr_broadcast_kernel<scalar_t>
<<<BLOCKS(1, N * K), THREADS, 0, stream>>>(src_data, indptr_info,
out_data, N, K, E);
});
return out;
}
#pragma once
#include <torch/extension.h>
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_csr_cuda(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out,
std::string reduce);
torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out);
#pragma once
#include <torch/extension.h>
#define CHECK_CUDA(x) \
AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
#include <torch/script.h>
#include "cpu/segment_csr_cpu.h"
#ifdef WITH_CUDA
#include "cuda/segment_csr_cuda.h"
#endif
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_csr_fw(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out,
std::string reduce) {
if (src.device().is_cuda()) {
#ifdef WITH_CUDA
return segment_csr_cuda(src, indptr, optional_out, reduce);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return segment_csr_cpu(src, indptr, optional_out, reduce);
}
}
torch::Tensor gather_csr_fw(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out) {
if (src.device().is_cuda()) {
#ifdef WITH_CUDA
return gather_csr_cuda(src, indptr, optional_out);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return gather_csr_cpu(src, indptr, optional_out);
}
}
using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;
class SegmentSumCSR : public torch::autograd::Function<SegmentSumCSR> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable indptr,
torch::optional<Variable> optional_out) {
ctx->saved_data["src_shape"] = src.sizes();
auto out = std::get<0>(segment_csr_fw(src, indptr, optional_out, "sum"));
ctx->save_for_backward({indptr});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto indptr = saved[0];
auto src_shape = ctx->saved_data["src_shape"].toIntVector();
auto grad_in = torch::empty(src_shape, grad_out.options());
gather_csr_fw(grad_out, indptr, grad_in);
return {grad_in, Variable(), Variable()};
}
};
class SegmentMeanCSR : public torch::autograd::Function<SegmentMeanCSR> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable indptr,
torch::optional<Variable> optional_out) {
ctx->saved_data["src_shape"] = src.sizes();
auto out = std::get<0>(segment_csr_fw(src, indptr, optional_out, "mean"));
ctx->save_for_backward({indptr});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto indptr = saved[0];
auto src_shape = ctx->saved_data["src_shape"].toIntVector();
auto grad_in = torch::empty(src_shape, grad_out.options());
gather_csr_fw(grad_out, indptr, grad_in);
auto indptr1 = indptr.narrow(-1, 0, indptr.size(-1) - 1);
auto indptr2 = indptr.narrow(-1, 1, indptr.size(-1) - 1);
auto count = (indptr2 - indptr1).to(grad_in.options());
count = gather_csr_fw(count, indptr, torch::nullopt);
for (auto i = 0; i < grad_out.dim() - indptr.dim(); i++)
count = count.unsqueeze(-1);
grad_in.div_(count);
return {grad_in, Variable(), Variable()};
}
};
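The count tensor built in SegmentMeanCSR::backward is just the per-segment length obtained from adjacent indptr entries; a minimal sketch of the idea (1-D, illustrative values):

import torch

indptr = torch.tensor([0, 2, 5, 6])
count = indptr[1:] - indptr[:-1]   # indptr2 - indptr1 above -> tensor([2, 3, 1])
# gather_csr then broadcasts these counts back to src's shape ([2, 2, 3, 3, 3, 1]),
# so every upstream gradient entry gets divided by the size of its segment.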
class SegmentMinCSR : public torch::autograd::Function<SegmentMinCSR> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable indptr,
torch::optional<Variable> optional_out) {
ctx->saved_data["src_shape"] = src.sizes();
auto result = segment_csr_fw(src, indptr, optional_out, "min");
auto out = std::get<0>(result);
auto arg_out = std::get<1>(result).value();
ctx->save_for_backward({indptr, arg_out});
ctx->mark_non_differentiable({arg_out});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out, arg_out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto indptr = saved[0];
auto arg_out = saved[1];
auto src_shape = ctx->saved_data["src_shape"].toIntVector();
src_shape[indptr.dim() - 1] += 1;
auto grad_in = torch::zeros(src_shape, grad_out.options());
grad_in.scatter_(indptr.dim() - 1, arg_out, grad_out);
grad_in =
grad_in.narrow(indptr.dim() - 1, 0, src_shape[indptr.dim() - 1] - 1);
return {grad_in, Variable(), Variable()};
}
};
class SegmentMaxCSR : public torch::autograd::Function<SegmentMaxCSR> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable indptr,
torch::optional<Variable> optional_out) {
ctx->saved_data["src_shape"] = src.sizes();
auto result = segment_csr_fw(src, indptr, optional_out, "max");
auto out = std::get<0>(result);
auto arg_out = std::get<1>(result).value();
ctx->save_for_backward({indptr, arg_out});
ctx->mark_non_differentiable({arg_out});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out, arg_out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto indptr = saved[0];
auto arg_out = saved[1];
auto src_shape = ctx->saved_data["src_shape"].toIntVector();
src_shape[indptr.dim() - 1] += 1;
auto grad_in = torch::zeros(src_shape, grad_out.options());
grad_in.scatter_(indptr.dim() - 1, arg_out, grad_out);
grad_in =
grad_in.narrow(indptr.dim() - 1, 0, src_shape[indptr.dim() - 1] - 1);
return {grad_in, Variable(), Variable()};
}
};
class GatherCSR : public torch::autograd::Function<GatherCSR> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable indptr,
torch::optional<Variable> optional_out) {
ctx->saved_data["src_shape"] = src.sizes();
auto out = gather_csr_fw(src, indptr, optional_out);
ctx->save_for_backward({indptr});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto indptr = saved[0];
auto src_shape = ctx->saved_data["src_shape"].toIntVector();
auto grad_in = torch::empty(src_shape, grad_out.options());
segment_csr_fw(grad_out, indptr, grad_in, "sum");
return {grad_in, Variable(), Variable()};
}
};
torch::Tensor segment_sum_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out) {
return SegmentSumCSR::apply(src, indptr, optional_out)[0];
}
torch::Tensor segment_mean_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out) {
return SegmentMeanCSR::apply(src, indptr, optional_out)[0];
}
std::tuple<torch::Tensor, torch::Tensor>
segment_min_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out) {
auto result = SegmentMinCSR::apply(src, indptr, optional_out);
return std::make_tuple(result[0], result[1]);
}
std::tuple<torch::Tensor, torch::Tensor>
segment_max_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out) {
auto result = SegmentMaxCSR::apply(src, indptr, optional_out);
return std::make_tuple(result[0], result[1]);
}
torch::Tensor gather_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out) {
return GatherCSR::apply(src, indptr, optional_out)[0];
}
static auto registry =
torch::RegisterOperators()
.op("torch_scatter::segment_sum_csr", &segment_sum_csr)
.op("torch_scatter::segment_mean_csr", &segment_mean_csr)
.op("torch_scatter::segment_min_csr", &segment_min_csr)
.op("torch_scatter::segment_max_csr", &segment_max_csr)
.op("torch_scatter::gather_csr", &gather_csr);
-import platform
+import os
 import os.path as osp
-from glob import glob
+import sys
+import glob
 from setuptools import setup, find_packages
-from sys import argv
 import torch
 from torch.utils.cpp_extension import BuildExtension
 from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME
-# Windows users: Edit both of these to contain your VS include path, i.e.:
-# cxx_extra_compile_args = ['-I{VISUAL_STUDIO_DIR}\\include']
-# nvcc_extra_compile_args = [..., '-I{VISUAL_STUDIO_DIR}\\include']
-cxx_extra_compile_args = []
-nvcc_extra_compile_args = ['-arch=sm_35', '--expt-relaxed-constexpr']
-# Windows users: Edit both of these to contain your VS library path, i.e.:
-# cxx_extra_link_args = ['/LIBPATH:{VISUAL_STUDIO_DIR}\\lib\\{x86|x64}']
-# nvcc_extra_link_args = ['/LIBPATH:{VISUAL_STUDIO_DIR}\\lib\\{x86|x64}']
-cxx_extra_link_args = []
-nvcc_extra_link_args = []
-if platform.system() != 'Windows':
-    cxx_extra_compile_args += ['-Wno-unused-variable']
-TORCH_MAJOR = int(torch.__version__.split('.')[0])
-TORCH_MINOR = int(torch.__version__.split('.')[1])
-if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
-    cxx_extra_compile_args += ['-DVERSION_GE_1_3']
-    nvcc_extra_compile_args += ['-DVERSION_GE_1_3']
-cmdclass = {
-    'build_ext': BuildExtension.with_options(no_python_abi_suffix=True)
-}
-ext_modules = []
-exts = [e.split(osp.sep)[-1][:-4] for e in glob(osp.join('cpu', '*.cpp'))]
-ext_modules += [
-    CppExtension(f'torch_scatter.{ext}_cpu', [f'cpu/{ext}.cpp'],
-                 extra_compile_args=cxx_extra_compile_args,
-                 extra_link_args=cxx_extra_link_args) for ext in exts
-]
-if CUDA_HOME is not None and '--cpu' not in argv:
-    exts = [e.split(osp.sep)[-1][:-4] for e in glob(osp.join('cuda', '*.cpp'))]
-    ext_modules += [
-        CUDAExtension(
-            f'torch_scatter.{ext}_cuda',
-            [f'cuda/{ext}.cpp', f'cuda/{ext}_kernel.cu'], extra_compile_args={
-                'cxx': cxx_extra_compile_args,
-                'nvcc': nvcc_extra_compile_args,
-            }, extra_link_args=nvcc_extra_link_args) for ext in exts
-    ]
-if '--cpu' in argv:
-    argv.remove('--cpu')
-__version__ = '1.5.0'
-url = 'https://github.com/rusty1s/pytorch_scatter'
+def get_extensions():
+    this_dir = osp.dirname(osp.abspath(__file__))
+    extensions_dir = osp.join(this_dir, 'csrc')
+    main_files = glob.glob(osp.join(extensions_dir, '*.cpp'))
+    cpu_files = glob.glob(osp.join(extensions_dir, 'cpu', '*.cpp'))
+    cuda_files = glob.glob(osp.join(extensions_dir, 'cuda', '*.cu'))
+    Extension = CppExtension
+    sources = main_files + cpu_files
+    define_macros = []
+    extra_compile_args = {'cxx': [], 'nvcc': []}
+    # Windows users: Edit both of these to contain your VS include path, i.e.:
+    # extra_compile_args['cxx'] += ['-I{VISUAL_STUDIO_DIR}\\include']
+    # extra_compile_args['nvcc'] += ['-I{VISUAL_STUDIO_DIR}\\include']
+    if (torch.cuda.is_available() and CUDA_HOME is not None) or os.getenv(
+            'FORCE_CUDA', '0') == '1':
+        Extension = CUDAExtension
+        sources += cuda_files
+        define_macros += [('WITH_CUDA', None)]
+        nvcc_flags = os.getenv('NVCC_FLAGS', '')
+        nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ')
+        nvcc_flags += ['-arch=sm_35', '--expt-relaxed-constexpr']
+        extra_compile_args['cxx'] += ['-O0']
+        extra_compile_args['nvcc'] += nvcc_flags
+    if sys.platform == 'win32':
+        extra_compile_args['cxx'] += ['/MP']
+    return [
+        Extension(
+            'torch_scatter._C',
+            sources,
+            include_dirs=[extensions_dir],
+            define_macros=define_macros,
+            extra_compile_args=extra_compile_args,
+        )
+    ]
 install_requires = []
 setup_requires = ['pytest-runner']
@@ -61,17 +59,19 @@ tests_require = ['pytest', 'pytest-cov']
 setup(
     name='torch_scatter',
-    version=__version__,
-    description='PyTorch Extension Library of Optimized Scatter Operations',
+    version='1.5.0',
     author='Matthias Fey',
     author_email='matthias.fey@tu-dortmund.de',
-    url=url,
-    download_url='{}/archive/{}.tar.gz'.format(url, __version__),
-    keywords=['pytorch', 'scatter', 'segment'],
-    license='MIT',
+    url='https://github.com/rusty1s/pytorch_scatter',
+    description='PyTorch Extension Library of Optimized Scatter Operations',
+    keywords=['pytorch', 'scatter', 'segment', 'gather'],
     install_requires=install_requires,
     setup_requires=setup_requires,
     tests_require=tests_require,
-    ext_modules=ext_modules,
-    cmdclass=cmdclass,
+    ext_modules=get_extensions(),
+    cmdclass={
+        'build_ext': BuildExtension.with_options(no_python_abi_suffix=True)
+    },
     packages=find_packages(),
 )
@@ -3,7 +3,7 @@ from itertools import product
 import pytest
 import torch
 from torch.autograd import gradcheck
-from torch_scatter import segment_coo, segment_csr
+from torch_scatter import segment_csr
 from .utils import tensor, dtypes, devices
@@ -88,12 +88,12 @@ def test_forward(test, reduce, dtype, device):
     indptr = tensor(test['indptr'], torch.long, device)
     expected = tensor(test[reduce], dtype, device)
-    out = segment_coo(src, index, reduce=reduce)
-    if isinstance(out, tuple):
-        out, arg_out = out
-        arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device)
-        assert torch.all(arg_out == arg_expected)
-    assert torch.all(out == expected)
+    # out = segment_coo(src, index, reduce=reduce)
+    # if isinstance(out, tuple):
+    #     out, arg_out = out
+    #     arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device)
+    #     assert torch.all(arg_out == arg_expected)
+    # assert torch.all(out == expected)
     out = segment_csr(src, indptr, reduce=reduce)
     if isinstance(out, tuple):
@@ -111,7 +111,7 @@ def test_backward(test, reduce, device):
     index = tensor(test['index'], torch.long, device)
     indptr = tensor(test['indptr'], torch.long, device)
-    assert gradcheck(segment_coo, (src, index, None, None, reduce))
+    # assert gradcheck(segment_coo, (src, index, None, None, reduce))
     assert gradcheck(segment_csr, (src, indptr, None, reduce))
@@ -130,22 +130,22 @@ def test_segment_out(test, reduce, dtype, device):
     segment_csr(src, indptr, out, reduce=reduce)
     assert torch.all(out == expected)
-    out.fill_(-2)
-    segment_coo(src, index, out, reduce=reduce)
-    if reduce == 'sum':
-        expected = expected - 2
-    elif reduce == 'mean':
-        expected = out  # We can not really test this here.
-    elif reduce == 'min':
-        expected = expected.fill_(-2)
-    elif reduce == 'max':
-        expected[expected == 0] = -2
-    else:
-        raise ValueError
-    assert torch.all(out == expected)
+    # out.fill_(-2)
+    # segment_coo(src, index, out, reduce=reduce)
+    # if reduce == 'sum':
+    #     expected = expected - 2
+    # elif reduce == 'mean':
+    #     expected = out  # We can not really test this here.
+    # elif reduce == 'min':
+    #     expected = expected.fill_(-2)
+    # elif reduce == 'max':
+    #     expected[expected == 0] = -2
+    # else:
+    #     raise ValueError
+    # assert torch.all(out == expected)
 @pytest.mark.parametrize('test,reduce,dtype,device',
@@ -163,12 +163,12 @@ def test_non_contiguous_segment(test, reduce, dtype, device):
     if indptr.dim() > 1:
         indptr = indptr.transpose(0, 1).contiguous().transpose(0, 1)
-    out = segment_coo(src, index, reduce=reduce)
-    if isinstance(out, tuple):
-        out, arg_out = out
-        arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device)
-        assert torch.all(arg_out == arg_expected)
-    assert torch.all(out == expected)
+    # out = segment_coo(src, index, reduce=reduce)
+    # if isinstance(out, tuple):
+    #     out, arg_out = out
+    #     arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device)
+    #     assert torch.all(arg_out == arg_expected)
+    # assert torch.all(out == expected)
     out = segment_csr(src, indptr, reduce=reduce)
     if isinstance(out, tuple):
......
-import torch
-from .add import scatter_add
-from .sub import scatter_sub
-from .mul import scatter_mul
-from .div import scatter_div
-from .mean import scatter_mean
-from .std import scatter_std
-from .max import scatter_max
-from .min import scatter_min
-from .logsumexp import scatter_logsumexp
-from .segment import segment_coo, segment_csr
-from .gather import gather_coo, gather_csr
-import torch_scatter.composite
-torch.ops.load_library('torch_scatter/scatter_cpu.so')
-torch.ops.load_library('torch_scatter/segment_csr_cpu.so')
-torch.ops.load_library('torch_scatter/segment_coo_cpu.so')
-try:
-    torch.ops.load_library('torch_scatter/scatter_cuda.so')
-    # torch.ops.load_library('torch_scatter/segment_csr_cuda.so')
-    # torch.ops.load_library('torch_scatter/segment_coo_cuda.so')
-except OSError as e:
-    if torch.cuda.is_available():
-        raise e
-__version__ = '1.4.0'
+from .segment_csr import (segment_sum_csr, segment_add_csr, segment_mean_csr,
+                          segment_min_csr, segment_max_csr, segment_csr,
+                          gather_csr)
+__version__ = '1.5.0'
 __all__ = [
-    'scatter_add',
-    'scatter_sub',
-    'scatter_mul',
-    'scatter_div',
-    'scatter_mean',
-    'scatter_std',
-    'scatter_max',
-    'scatter_min',
-    'scatter_logsumexp',
-    'segment_coo',
+    'segment_sum_csr',
+    'segment_add_csr',
+    'segment_mean_csr',
+    'segment_min_csr',
+    'segment_max_csr',
+    'segment_max_csr',
     'segment_csr',
-    'gather_coo',
     'gather_csr',
     'torch_scatter',
     '__version__',
......
from typing import Optional, Tuple
import torch
torch.ops.load_library('torch_scatter/_C.so')
@torch.jit.script
def segment_sum_csr(src: torch.Tensor, indptr: torch.Tensor,
out: Optional[torch.Tensor] = None) -> torch.Tensor:
return torch.ops.torch_scatter.segment_sum_csr(src, indptr, out)
@torch.jit.script
def segment_add_csr(src: torch.Tensor, indptr: torch.Tensor,
out: Optional[torch.Tensor] = None) -> torch.Tensor:
return torch.ops.torch_scatter.segment_sum_csr(src, indptr, out)
@torch.jit.script
def segment_mean_csr(src: torch.Tensor, indptr: torch.Tensor,
out: Optional[torch.Tensor] = None) -> torch.Tensor:
return torch.ops.torch_scatter.segment_mean_csr(src, indptr, out)
@torch.jit.script
def segment_min_csr(src: torch.Tensor, indptr: torch.Tensor,
out: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
return torch.ops.torch_scatter.segment_min_csr(src, indptr, out)
@torch.jit.script
def segment_max_csr(src: torch.Tensor, indptr: torch.Tensor,
out: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
return torch.ops.torch_scatter.segment_max_csr(src, indptr, out)
@torch.jit.script
def segment_csr(src: torch.Tensor, indptr: torch.Tensor,
out: Optional[torch.Tensor] = None,
reduce: str = "sum") -> torch.Tensor:
if reduce == 'sum' or reduce == 'add':
return segment_sum_csr(src, indptr, out)
elif reduce == 'mean':
return segment_mean_csr(src, indptr, out)
elif reduce == 'min':
return segment_min_csr(src, indptr, out)[0]
elif reduce == 'max':
return segment_max_csr(src, indptr, out)[0]
else:
raise ValueError
@torch.jit.script
def gather_csr(src: torch.Tensor, indptr: torch.Tensor,
out: Optional[torch.Tensor] = None) -> torch.Tensor:
return torch.ops.torch_scatter.gather_csr(src, indptr, out)
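Finally, a short end-to-end sketch at the Python API level (assuming the package is installed and the extension built); values shown are illustrative:

import torch
from torch_scatter import segment_csr, segment_min_csr, gather_csr

src = torch.tensor([1., 2., 3., 4., 5., 6.])
indptr = torch.tensor([0, 2, 5, 6])

segment_csr(src, indptr, reduce='mean')     # tensor([1.5000, 4.0000, 6.0000])
out, argmin = segment_min_csr(src, indptr)  # per-segment minima and their positions
gather_csr(out, indptr)                     # broadcasts each minimum back over its segment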