"docs/modelfile.mdx" did not exist on "294b6f5a220e8678c2b08fd2ab783a99e25c5215"
Commit 9725b043 authored by rusty1s

clean up reduction type

parent 9a91c42d
@@ -10,17 +10,71 @@
#define BLOCKS(TB, N) (TB * N + THREADS - 1) / THREADS
#define FULL_MASK 0xffffffff
enum ReductionType { ADD, MEAN, MIN, MAX };
#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...) \
[&] { \
if (reduce == "add") { \
const ReductionType REDUCE = ADD; \
return __VA_ARGS__(); \
} else if (reduce == "mean") { \
const ReductionType REDUCE = MEAN; \
return __VA_ARGS__(); \
} else if (reduce == "min") { \
const ReductionType REDUCE = MIN; \
return __VA_ARGS__(); \
} else if (reduce == "max") { \
const ReductionType REDUCE = MAX; \
return __VA_ARGS__(); \
} \
}()
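// Editor's sketch (not part of the commit): minimal host-side use of the
// dispatch macro above, assuming <string> is available via the ATen headers.
// The runtime string `reduce` is mapped onto the compile-time constant
// `REDUCE`, and the macro returns whatever the lambda returns. The helper
// name `needs_arg_out` is hypothetical and for illustration only; it mirrors
// the "min"/"max" check done later in `segment_csr_cuda`.
static inline bool needs_arg_out(const std::string &reduce) {
  return AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
    // `REDUCE` is a compile-time ReductionType inside this lambda.
    return REDUCE == MIN || REDUCE == MAX;
  });
}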
template <typename scalar_t, ReductionType REDUCE> struct Reducer {
static inline __host__ __device__ scalar_t init() {
if (REDUCE == MIN) {
return std::numeric_limits<scalar_t>::max();
} else if (REDUCE == MAX) {
return std::numeric_limits<scalar_t>::min();
} else {
return (scalar_t)0;
}
}
static inline __host__ __device__ void update(scalar_t *val, scalar_t new_val,
int64_t *arg, int64_t new_arg) {
// Accumulate for ADD/MEAN; for MIN/MAX only take over the new value (and
// remember its index) when it improves the running extremum.
if (REDUCE == ADD || REDUCE == MEAN) {
*val = *val + new_val;
} else if ((REDUCE == MIN && new_val < *val) ||
(REDUCE == MAX && new_val > *val)) {
*val = new_val;
*arg = new_arg;
}
}
static inline __host__ __device__ void write(scalar_t *address, scalar_t val,
int64_t *arg_address,
int64_t arg, int count) {
if (REDUCE == ADD) {
*address = val;
} else if (REDUCE == MEAN) {
*address = val / (scalar_t)max(count, 1);
} else if (REDUCE == MIN || REDUCE == MAX) {
if (count > 0) {
*address = val;
*arg_address = arg;
} else {
*address = (scalar_t)0;
}
}
}
};
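// Editor's sketch (not part of the commit): a host-side walk-through of the
// Reducer contract for REDUCE == MAX. `update` only takes over the value (and
// its index) when it improves the running maximum, and `write` emits the
// argmax only for non-empty segments. The function name `reducer_demo` is
// hypothetical and for illustration only.
static inline void reducer_demo() {
  float src[3] = {2.f, 5.f, 1.f};
  float val = Reducer<float, MAX>::init();
  int64_t arg = -1;
  for (int64_t i = 0; i < 3; i++)
    Reducer<float, MAX>::update(&val, src[i], &arg, i);
  float out;
  int64_t arg_out;
  Reducer<float, MAX>::write(&out, val, &arg_out, arg, /*count=*/3);
  // Now out == 5.f and arg_out == 1.
}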
// We need our own `IndexToOffset` implementation since we do not want to access
// the last element of the `indexptr`.
template <typename scalar_t> struct IndexPtrToOffset {
static inline __host__ __device__ int
get(int idx, const at::cuda::detail::TensorInfo<scalar_t, int> &info) {
int offset = idx % (info.sizes[info.dims - 1] - 1);
offset *= info.strides[info.dims - 1];
idx /= info.sizes[info.dims - 1] - 1;
for (int i = info.dims - 2; i >= 0; --i) {
@@ -31,170 +85,85 @@
}
};
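// Editor's sketch (not part of the commit): the same offset computation
// written out for a contiguous two-dimensional `indptr` of shape [B, S],
// which holds S - 1 segments per batch entry. `flat_indptr_offset` is a
// hypothetical helper for illustration only; the struct above generalizes
// this to arbitrary strides via TensorInfo.
static inline __host__ __device__ int flat_indptr_offset(int idx, int S) {
  int col = idx % (S - 1); // segment position inside one indptr row
  int row = idx / (S - 1); // which indptr row (batch entry)
  return row * S + col;    // last dim has stride 1, rows have stride S
}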
template <typename scalar_t, ReductionType REDUCE, int TB>
__global__ void
segment_csr_kernel(const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
scalar_t *out_data, int64_t *arg_out_data, size_t N,
size_t E) {
// Each warp processes exactly `32/TB` rows and aggregates all row values via
// a parallel reduction.
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
int row_idx = thread_idx / TB;
int lane_idx = thread_idx & (TB - 1);
if (row_idx < N) {
int offset = IndexPtrToOffset<int64_t>::get(row_idx, indptr_info);
int row_start = __ldg(indptr_info.data + offset);
int row_end = __ldg(indptr_info.data + offset +
indptr_info.strides[indptr_info.dims - 1]);
scalar_t val = Reducer<scalar_t, REDUCE>::init();
int64_t arg, tmp;
offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E;
for (int src_idx = row_start + lane_idx; src_idx < row_end; src_idx += TB) {
// "Mostly" coalesced read.
Reducer<scalar_t, REDUCE>::update(&val, src_data[offset + src_idx], &arg,
src_idx);
}
#pragma unroll
for (int i = TB / 2; i > 0; i /= 2) {
// Parallel reduction inside a single warp.
if (REDUCE == MIN || REDUCE == MAX) {
tmp = __shfl_down_sync(FULL_MASK, arg, i);
}
Reducer<scalar_t, REDUCE>::update(
&val, __shfl_down_sync(FULL_MASK, val, i), &arg, tmp);
}
if (lane_idx == 0) {
// "Mostly" coalesced write.
Reducer<scalar_t, REDUCE>::write(out_data + row_idx, val,
arg_out_data + row_idx, arg,
row_end - row_start);
}
}
}
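// Editor's sketch (not part of the commit): the warp-level reduction pattern
// used above, shown in isolation as a plain sum over one warp. Launch as
// warp_sum_demo<<<1, 32>>>(in, out); the kernel name is hypothetical and for
// illustration only.
__global__ void warp_sum_demo(const float *in, float *out) {
  float val = in[threadIdx.x];
#pragma unroll
  for (int i = 16; i > 0; i /= 2)
    // Each step folds the upper half of the active lanes onto the lower half.
    val += __shfl_down_sync(FULL_MASK, val, i);
  if (threadIdx.x == 0)
    *out = val; // lane 0 now holds the sum of all 32 inputs
}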
template <typename scalar_t, ReductionType REDUCE>
__global__ void segment_csr_broadcast_kernel(
const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
scalar_t *out_data, int64_t *arg_out_data, size_t N, size_t K, size_t E) {
// Each thread processes exactly one row. It turned out that this is more
// efficient than using shared memory due to avoiding synchronization
// barriers.
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
int row_idx = thread_idx / K;
int lane_idx = thread_idx % K;
if (thread_idx < N * K) {
int offset = IndexPtrToOffset<int64_t>::get(row_idx, indptr_info);
int row_start = __ldg(indptr_info.data + offset);
int row_end = __ldg(indptr_info.data + offset +
indptr_info.strides[indptr_info.dims - 1]);
scalar_t val = Reducer<scalar_t, REDUCE>::init();
int64_t arg;
offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E * K;
for (int src_idx = row_start; src_idx < row_end; src_idx++) {
// Coalesced read.
Reducer<scalar_t, REDUCE>::update(
&val, src_data[offset + K * src_idx + lane_idx], &arg, src_idx);
}
// Coalesced write.
Reducer<scalar_t, REDUCE>::write(out_data + thread_idx, val,
arg_out_data + thread_idx, arg,
row_end - row_start);
}
}
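// Editor's sketch (not part of the commit): a serial CPU reference of the
// "add" reduction computed by the broadcast kernel above for a single batch
// entry, with E source rows of K features and N segments described by an
// indptr of length N + 1. `segment_csr_add_reference` is a hypothetical
// helper for illustration only.
static void segment_csr_add_reference(const float *src, const int64_t *indptr,
                                      float *out, int N, int K) {
  for (int n = 0; n < N; n++) {
    for (int k = 0; k < K; k++) {
      float val = 0.f;
      for (int64_t e = indptr[n]; e < indptr[n + 1]; e++)
        val += src[e * K + k]; // the element the kernel reads, coalesced in k
      out[n * K + k] = val;
    }
  }
}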
@@ -223,14 +192,15 @@ segment_csr_cuda(at::Tensor src, at::Tensor indptr,
}
at::optional<at::Tensor> arg_out = at::nullopt;
int64_t *arg_out_data = nullptr;
if (reduce == "min" || reduce == "max") {
arg_out = at::full_like(out, src.size(reduce_dim), indptr.options());
arg_out_data = arg_out.value().DATA_PTR<int64_t>();
}
auto N = out.size(reduce_dim) * (indptr.numel() / indptr.size(-1));
auto K = out.numel() / N;
auto E = src.size(reduce_dim);
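// Editor's note (not part of the commit), a worked example of the sizes
// above: for src of shape [2, 6, 4] reduced over dim 1 with indptr of shape
// [2, 4] (i.e. three segments per batch entry), out has shape [2, 3, 4], so
// N = 3 * (8 / 4) = 6, K = 24 / 6 = 4 and E = 6.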
// auto avg_len = (float)src.size(reduce_dim) / (float)out.size(reduce_dim);
auto indptr_info = at::cuda::detail::getTensorInfo<int64_t, int>(indptr);
auto stream = at::cuda::getCurrentCUDAStream();
@@ -238,54 +208,25 @@ segment_csr_cuda(at::Tensor src, at::Tensor indptr,
auto src_data = src.DATA_PTR<scalar_t>();
auto out_data = out.DATA_PTR<scalar_t>();
// Select the right kernel based on the reduce operation and whether we need
// broadcasting capabilities (K > 1):
if (K == 1 && reduce == "add") {
segment_add_csr_kernel<scalar_t, ADD, 1>
<<<BLOCKS(32, N), THREADS, 0, stream>>>(src_data, indptr_info,
out_data, nullptr, N, E);
} else if (K == 1 && reduce == "mean") {
segment_add_csr_kernel<scalar_t, MEAN, 1>
<<<BLOCKS(32, N), THREADS, 0, stream>>>(src_data, indptr_info,
out_data, nullptr, N, E);
} else if (K == 1 && reduce == "min") {
auto arg_out_data = arg_out.value().DATA_PTR<int64_t>();
segment_add_csr_kernel<scalar_t, MIN, 1>
<<<BLOCKS(32, N), THREADS, 0, stream>>>(src_data, indptr_info,
out_data, arg_out_data, N, E);
} else if (K == 1 && reduce == "max") {
auto arg_out_data = arg_out.value().DATA_PTR<int64_t>();
segment_add_csr_kernel<scalar_t, MAX, 1>
<<<BLOCKS(32, N), THREADS, 0, stream>>>(src_data, indptr_info,
out_data, arg_out_data, N, E);
} else if (reduce == "add") {
segment_add_csr_broadcast_kernel<scalar_t, ADD>
<<<BLOCKS(1, N * K), THREADS, 0, stream>>>(
src_data, indptr_info, out_data, nullptr, N, K, E);
} else if (reduce == "mean") {
segment_add_csr_broadcast_kernel<scalar_t, MEAN>
<<<BLOCKS(1, N * K), THREADS, 0, stream>>>(
src_data, indptr_info, out_data, nullptr, N, K, E);
} else if (reduce == "min") {
auto arg_out_data = arg_out.value().DATA_PTR<int64_t>();
segment_add_csr_broadcast_kernel<scalar_t, MIN>
<<<BLOCKS(1, N * K), THREADS, 0, stream>>>(
src_data, indptr_info, out_data, arg_out_data, N, K, E);
} else if (reduce == "max") {
auto arg_out_data = arg_out.value().DATA_PTR<int64_t>();
segment_add_csr_broadcast_kernel<scalar_t, MAX>
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
if (K == 1) {
segment_csr_kernel<scalar_t, REDUCE, 1>
<<<BLOCKS(32, N), THREADS, 0, stream>>>(
src_data, indptr_info, out_data, arg_out_data, N, E);
} else {
segment_csr_broadcast_kernel<scalar_t, REDUCE>
<<<BLOCKS(1, N * K), THREADS, 0, stream>>>(
src_data, indptr_info, out_data, arg_out_data, N, K, E);
}
});
});
return std::make_tuple(out, arg_out);
}
template <typename scalar_t, ReductionType REDUCE>
__global__ void
segment_coo_kernel(const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> index_info,
scalar_t *out_data, int64_t *arg_out_data, size_t E) {
@@ -318,15 +259,15 @@
}
}
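// Editor's sketch (not part of the commit): a serial CPU reference of the
// COO "add" reduction computed by the kernels in this hunk. Every source row
// e is accumulated into out[index[e]] (the kernels do this with atomics);
// `out` is assumed to be zero-initialized. The helper name
// `segment_coo_add_reference` is hypothetical and for illustration only.
static void segment_coo_add_reference(const float *src, const int64_t *index,
                                      float *out, int E, int K) {
  for (int e = 0; e < E; e++)
    for (int k = 0; k < K; k++)
      out[index[e] * K + k] += src[e * K + k];
}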
template <typename scalar_t, ReductionType REDUCE, int TB>
__global__ void segment_coo_broadcast_kernel(
const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> index_info,
scalar_t *out_data, int64_t *arg_out_data, size_t E, size_t K) {
// Each thread processes a single column and `TB` index entries. Coalesced
// read and write is performed in column-major order. The intermediate
// results are written via atomics.
int row_start = (blockIdx.x * blockDim.y + threadIdx.y) * TB;
int col_idx = blockIdx.y * blockDim.x + threadIdx.x;
@@ -392,24 +333,34 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
// Select the right kernel based on average row length (purely heuristic)
// and whether we need broadcasting capabilities (K > 1):
if (K == 1 && reduce == "add") {
segment_coo_kernel<scalar_t, ADD><<<BLOCKS(1, E), THREADS, 0, stream>>>(
src_data, index_info, out_data, nullptr, E);
} else if (K == 1 && reduce == "mean") {
segment_coo_kernel<scalar_t, MEAN><<<BLOCKS(1, E), THREADS, 0, stream>>>(
src_data, index_info, out_data, nullptr, E);
} else if (K == 1 && reduce == "min") {
auto arg_out_data = arg_out.value().DATA_PTR<int64_t>();
segment_coo_kernel<scalar_t, MIN><<<BLOCKS(1, E), THREADS, 0, stream>>>(
src_data, index_info, out_data, arg_out_data, E);
} else if (K == 1 && reduce == "max") {
auto arg_out_data = arg_out.value().DATA_PTR<int64_t>();
segment_coo_kernel<scalar_t, MAX><<<BLOCKS(1, E), THREADS, 0, stream>>>(
src_data, index_info, out_data, arg_out_data, E);
} else if (avg_len <= 8)
segment_coo_broadcast_kernel<scalar_t, ADD, 4>
<<<dim3(((E + (8 * 4) - 1) / (8 * 4)), (K + 31) / 32), dim3(32, 8), 0,
stream>>>(src_data, index_info, out_data, nullptr, E, K);
else if (avg_len <= 16)
segment_coo_broadcast_kernel<scalar_t, ADD, 8>
<<<dim3(((E + (8 * 8) - 1) / (8 * 8)), (K + 31) / 32), dim3(32, 8), 0,
stream>>>(src_data, index_info, out_data, nullptr, E, K);
else if (avg_len <= 32)
segment_coo_broadcast_kernel<scalar_t, ADD, 16>
<<<dim3(((E + (8 * 16) - 1) / (8 * 16)), (K + 31) / 32), dim3(32, 8),
0, stream>>>(src_data, index_info, out_data, nullptr, E, K);
else
segment_coo_broadcast_kernel<scalar_t, ADD, 32>
<<<dim3(((E + (8 * 32) - 1) / (8 * 32)), (K + 31) / 32), dim3(32, 8),
0, stream>>>(src_data, index_info, out_data, nullptr, E, K);
});
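// Editor's note (not part of the commit): launch geometry of the broadcast
// kernels above. Each block is dim3(32, 8) (32 columns by 8 thread rows) and
// every thread covers TB index entries, so the grid is
//   dim3((E + 8 * TB - 1) / (8 * TB), (K + 31) / 32).
// For example, E = 1000, K = 64, TB = 4 gives dim3(32, 2) blocks of 256
// threads each.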