Commit 6a2a503e authored by quyuanhao123's avatar quyuanhao123
Browse files

Initial commit

parents
Pipeline #191 failed with stages
#include "hip/hip_runtime.h"
#include "segment_coo_hip.h"
#include <ATen/hip/HIPContext.h>
#include <ATen/hip/detail/IndexUtils.cuh>
#include <ATen/hip/detail/TensorInfo.cuh>
#include "reducer.cuh"
#include "utils.cuh"
#define THREADS 256
#define BLOCKS(TB, N) (TB * N + THREADS - 1) / THREADS
#define FULL_MASK 0xffffffff
template <typename scalar_t, ReductionType REDUCE, bool HAS_VAL>
__global__ void
segment_coo_kernel(const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> index_info,
scalar_t *out_data, size_t E, size_t N) {
// Each thread processes exactly one entry. Within a warp, we perform a
// parallel reduction across equal indices, and write the intermediate
// result via atomics.
int row_idx = blockIdx.x * blockDim.x + threadIdx.x;
int lane_idx = row_idx & (32 - 1);
int D = index_info.sizes[index_info.dims - 1];
if (row_idx < E) {
int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
row_idx, index_info);
int64_t idx = index_info.data[offset], next_idx;
int out_idx = (row_idx / D) * N + idx;
scalar_t val = HAS_VAL ? src_data[row_idx] : (scalar_t)1, tmp;
#pragma unroll
for (int i = 1; i < 32; i *= 2) {
// Parallel reduction inside a single warp.
tmp = __shfl_up_sync(FULL_MASK, val, i);
next_idx = __shfl_up_sync(FULL_MASK, idx, i);
if (lane_idx >= i && row_idx / D == (row_idx - i) / D) {
assert(idx >= next_idx);
if (idx == next_idx)
Reducer<scalar_t, REDUCE>::update(&val, tmp);
}
}
next_idx = __shfl_down_sync(FULL_MASK, idx, 1);
if (lane_idx == 32 - 1 || row_idx / D != (row_idx + 1) / D ||
idx != next_idx)
Reducer<scalar_t, REDUCE>::atomic_write(out_data + out_idx, val);
}
}
template <typename scalar_t>
__global__ void segment_coo_arg_kernel(
const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> index_info,
scalar_t *out_data, int64_t *arg_out_data, size_t E, size_t N) {
int row_idx = blockIdx.x * blockDim.x + threadIdx.x;
int D = index_info.sizes[index_info.dims - 1];
if (row_idx < E) {
int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
row_idx, index_info);
int64_t idx = index_info.data[offset];
int out_idx = (row_idx / D) * N + idx;
scalar_t val = __ldg(out_data + out_idx);
if (src_data[row_idx] == val)
arg_out_data[out_idx] = row_idx % D;
}
}
template <typename scalar_t, ReductionType REDUCE, int TB>
__global__ void segment_coo_broadcast_kernel(
const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> index_info,
scalar_t *out_data, size_t E, size_t K, size_t N) {
// Each thread processes a single column and `TB` index entries. Coalesced
// read and write is performed in column-major order. The intermediate
// results are written via atomics.
int D = index_info.sizes[index_info.dims - 1];
int E_1 = E / D;
int E_2 = (D - 1) + TB - ((D - 1) % TB);
int row_idx = blockIdx.x * blockDim.y + threadIdx.y;
int col_idx = blockIdx.y * blockDim.x + threadIdx.x;
int dim_start = (row_idx * TB) / E_2;
int row_start = (row_idx * TB) % E_2;
if (dim_start < E_1 && col_idx < K) {
int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
dim_start * D + row_start, index_info);
int idx1 = __ldg(index_info.data + offset), idx2;
scalar_t val = src_data[K * (dim_start * D + row_start) + col_idx];
#pragma unroll
for (int i = 1; i < TB; i++) {
if (row_start + i >= D)
break;
idx2 = __ldg(index_info.data + offset +
i * index_info.strides[index_info.dims - 1]);
assert(idx1 <= idx2);
if (idx1 == idx2) {
Reducer<scalar_t, REDUCE>::update(
&val, src_data[K * (dim_start * D + row_start + i) + col_idx]);
} else {
Reducer<scalar_t, REDUCE>::atomic_write(
out_data + (dim_start * N + idx1) * K + col_idx, val);
val = src_data[K * (dim_start * D + row_start + i) + col_idx];
}
idx1 = idx2;
}
Reducer<scalar_t, REDUCE>::atomic_write(
out_data + (dim_start * N + idx1) * K + col_idx, val);
}
}
template <typename scalar_t>
__global__ void segment_coo_arg_broadcast_kernel(
const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> index_info,
scalar_t *out_data, int64_t *arg_out_data, size_t E, size_t K, size_t N) {
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
int row_idx = thread_idx / K;
int col_idx = thread_idx % K;
int D = index_info.sizes[index_info.dims - 1];
if (row_idx < E && col_idx < K) {
int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
row_idx, index_info);
int idx = __ldg(index_info.data + offset);
int out_idx = ((row_idx / D) * N + idx) * K + col_idx;
scalar_t val = __ldg(out_data + out_idx);
if (src_data[thread_idx] == val)
arg_out_data[out_idx] = row_idx % D;
}
}
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_coo_cuda(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size, std::string reduce) {
CHECK_CUDA(src);
CHECK_CUDA(index);
if (optional_out.has_value())
CHECK_CUDA(optional_out.value());
hipSetDevice(src.get_device());
CHECK_INPUT(src.dim() >= index.dim());
auto sizes = index.sizes().vec();
for (int i = 0; i < index.dim(); i++) {
sizes[i] = src.size(i);
}
index = index.expand(sizes);
auto dim = index.dim() - 1;
src = src.contiguous();
torch::Tensor out;
if (optional_out.has_value()) {
out = optional_out.value().contiguous();
for (int i = 0; i < out.dim(); i++)
if (i != dim)
CHECK_INPUT(src.size(i) == out.size(i));
} else {
sizes = src.sizes().vec();
if (dim_size.has_value())
sizes[dim] = dim_size.value();
else if (index.numel() == 0)
sizes[dim] = 0;
else {
auto tmp = index.select(dim, index.size(dim) - 1);
tmp = tmp.numel() > 1 ? tmp.max() : tmp;
sizes[dim] = 1 + tmp.cpu().data_ptr<int64_t>()[0];
}
out = torch::zeros(sizes, src.options());
}
torch::optional<torch::Tensor> arg_out = torch::nullopt;
int64_t *arg_out_data = nullptr;
if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
arg_out = torch::full_like(out, src.size(dim), index.options());
arg_out_data = arg_out.value().data_ptr<int64_t>();
} else if (reduce2REDUCE.at(reduce) == MEAN) {
auto sizes = index.sizes().vec();
sizes[dim] = out.size(dim);
arg_out = torch::zeros(sizes, out.options());
}
if (index.numel() == 0)
return std::make_tuple(out, arg_out);
auto E = index.numel();
auto E_2 = index.size(dim);
auto E_1 = index.numel() / E_2;
auto K = src.numel() / E;
auto N = out.size(dim);
auto avg_len = (float)E_2 / (float)N;
auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
if (!optional_out.has_value())
out.fill_(Reducer<scalar_t, REDUCE>::init());
if (K == 1)
segment_coo_kernel<scalar_t, REDUCE, true>
<<<BLOCKS(1, E), THREADS, 0, stream>>>(src_data, index_info,
out_data, E, N);
else if (avg_len <= 8)
segment_coo_broadcast_kernel<scalar_t, REDUCE, 4>
<<<dim3((E_1 * ((E_2 + 3) / 4) + 7) / 8, (K + 31) / 32),
dim3(32, 8), 0, stream>>>(src_data, index_info, out_data, E, K,
N);
else if (avg_len <= 16)
segment_coo_broadcast_kernel<scalar_t, REDUCE, 8>
<<<dim3((E_1 * ((E_2 + 7) / 8) + 7) / 8, (K + 31) / 32),
dim3(32, 8), 0, stream>>>(src_data, index_info, out_data, E, K,
N);
else if (avg_len <= 32)
segment_coo_broadcast_kernel<scalar_t, REDUCE, 16>
<<<dim3((E_1 * ((E_2 + 15) / 16) + 7) / 8, (K + 31) / 32),
dim3(32, 8), 0, stream>>>(src_data, index_info, out_data, E, K,
N);
else
segment_coo_broadcast_kernel<scalar_t, REDUCE, 32>
<<<dim3((E_1 * ((E_2 + 31) / 32) + 7) / 8, (K + 31) / 32),
dim3(32, 8), 0, stream>>>(src_data, index_info, out_data, E, K,
N);
if (!optional_out.has_value() && (REDUCE == MIN || REDUCE == MAX))
out.masked_fill_(out == Reducer<scalar_t, REDUCE>::init(), (scalar_t)0);
if (REDUCE == MIN || REDUCE == MAX) {
if (K == 1)
segment_coo_arg_kernel<scalar_t>
<<<BLOCKS(1, E), THREADS, 0, stream>>>(
src_data, index_info, out_data, arg_out_data, E, N);
else
segment_coo_arg_broadcast_kernel<scalar_t>
<<<BLOCKS(1, E * K), THREADS, 0, stream>>>(
src_data, index_info, out_data, arg_out_data, E, K, N);
}
if (REDUCE == MEAN) {
auto count_data = arg_out.value().data_ptr<scalar_t>();
segment_coo_kernel<scalar_t, SUM, false>
<<<BLOCKS(1, E), THREADS, 0, stream>>>(nullptr, index_info,
count_data, E, N);
arg_out.value().masked_fill_(arg_out.value() < (scalar_t)1,
(scalar_t)1);
auto count = arg_out.value();
for (int i = dim + 1; i < out.dim(); i++)
count = count.unsqueeze(-1);
if (out.is_floating_point())
out.true_divide_(count);
else
out.div_(count, "floor");
}
});
});
return std::make_tuple(out, arg_out);
}
template <typename scalar_t>
__global__ void
gather_coo_kernel(const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> index_info,
scalar_t *out_data, size_t E, size_t N) {
int row_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (row_idx < E) {
int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
row_idx, index_info);
int row = index_info.data[offset];
offset = (row_idx / index_info.sizes[index_info.dims - 1]) * N;
scalar_t val = __ldg(src_data + offset + row);
out_data[row_idx] = val;
}
}
template <typename scalar_t>
__global__ void gather_coo_broadcast_kernel(
const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> index_info,
scalar_t *out_data, size_t E, size_t K, size_t N) {
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
int row_idx = thread_idx / K;
int col_idx = thread_idx % K;
if (thread_idx < E * K) {
int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
row_idx, index_info);
int row = index_info.data[offset];
offset = (row_idx / index_info.sizes[index_info.dims - 1]) * N * K;
scalar_t val = __ldg(src_data + offset + K * row + col_idx);
out_data[thread_idx] = val;
}
}
torch::Tensor gather_coo_cuda(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out) {
CHECK_CUDA(src);
CHECK_CUDA(index);
if (optional_out.has_value())
CHECK_CUDA(optional_out.value());
hipSetDevice(src.get_device());
CHECK_INPUT(src.dim() >= index.dim());
auto sizes = index.sizes().vec();
for (auto i = 0; i < index.dim() - 1; i++)
sizes[i] = src.size(i);
index = index.expand(sizes);
auto dim = index.dim() - 1;
src = src.contiguous();
torch::Tensor out;
if (optional_out.has_value()) {
out = optional_out.value().contiguous();
for (auto i = 0; i < src.dim(); i++)
if (i != dim)
CHECK_INPUT(src.size(i) == out.size(i));
CHECK_INPUT(index.size(dim) == out.size(dim));
} else {
auto sizes = src.sizes().vec();
sizes[dim] = index.size(dim);
out = torch::empty(sizes, src.options());
}
if (index.numel() == 0)
return out;
auto E = index.numel();
auto K = out.numel() / E;
auto N = src.size(dim);
auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
if (K == 1)
gather_coo_kernel<scalar_t><<<BLOCKS(1, E), THREADS, 0, stream>>>(
src_data, index_info, out_data, E, N);
else
gather_coo_broadcast_kernel<scalar_t>
<<<BLOCKS(1, E * K), THREADS, 0, stream>>>(src_data, index_info,
out_data, E, K, N);
});
return out;
}
#include "hip/hip_runtime.h"
#include "segment_coo_hip.h"
#include <ATen/hip/HIPContext.h>
#include <ATen/hip/detail/IndexUtils.cuh>
#include <ATen/hip/detail/TensorInfo.cuh>
#include "reducer.cuh"
#include "utils.cuh"
#define THREADS 256
#define BLOCKS(TB, N) (TB * N + THREADS - 1) / THREADS
#define FULL_MASK 0xffffffff
template <typename scalar_t, ReductionType REDUCE, bool HAS_VAL>
__global__ void
segment_coo_kernel(const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> index_info,
scalar_t *out_data, size_t E, size_t N) {
// Each thread processes exactly one entry. Within a warp, we perform a
// parallel reduction across equal indices, and write the intermediate
// result via atomics.
int row_idx = blockIdx.x * blockDim.x + threadIdx.x;
int lane_idx = row_idx & (32 - 1);
int D = index_info.sizes[index_info.dims - 1];
if (row_idx < E) {
int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
row_idx, index_info);
int64_t idx = index_info.data[offset], next_idx;
int out_idx = (row_idx / D) * N + idx;
scalar_t val = HAS_VAL ? src_data[row_idx] : (scalar_t)1, tmp;
#pragma unroll
for (int i = 1; i < 32; i *= 2) {
// Parallel reduction inside a single warp.
tmp = __shfl_up_sync(FULL_MASK, val, i);
next_idx = __shfl_up_sync(FULL_MASK, idx, i);
if (lane_idx >= i && row_idx / D == (row_idx - i) / D) {
assert(idx >= next_idx);
if (idx == next_idx)
Reducer<scalar_t, REDUCE>::update(&val, tmp);
}
}
next_idx = __shfl_down_sync(FULL_MASK, idx, 1);
if (lane_idx == 32 - 1 || row_idx / D != (row_idx + 1) / D ||
idx != next_idx)
Reducer<scalar_t, REDUCE>::atomic_write(out_data + out_idx, val);
}
}
template <typename scalar_t>
__global__ void segment_coo_arg_kernel(
const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> index_info,
scalar_t *out_data, int64_t *arg_out_data, size_t E, size_t N) {
int row_idx = blockIdx.x * blockDim.x + threadIdx.x;
int D = index_info.sizes[index_info.dims - 1];
if (row_idx < E) {
int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
row_idx, index_info);
int64_t idx = index_info.data[offset];
int out_idx = (row_idx / D) * N + idx;
scalar_t val = __ldg(out_data + out_idx);
if (src_data[row_idx] == val)
arg_out_data[out_idx] = row_idx % D;
}
}
template <typename scalar_t, ReductionType REDUCE, int TB>
__global__ void segment_coo_broadcast_kernel(
const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> index_info,
scalar_t *out_data, size_t E, size_t K, size_t N) {
// Each thread processes a single column and `TB` index entries. Coalesced
// read and write is performed in column-major order. The intermediate
// results are written via atomics.
int D = index_info.sizes[index_info.dims - 1];
int E_1 = E / D;
int E_2 = (D - 1) + TB - ((D - 1) % TB);
int row_idx = blockIdx.x * blockDim.y + threadIdx.y;
int col_idx = blockIdx.y * blockDim.x + threadIdx.x;
int dim_start = (row_idx * TB) / E_2;
int row_start = (row_idx * TB) % E_2;
if (dim_start < E_1 && col_idx < K) {
int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
dim_start * D + row_start, index_info);
int idx1 = __ldg(index_info.data + offset), idx2;
scalar_t val = src_data[K * (dim_start * D + row_start) + col_idx];
#pragma unroll
for (int i = 1; i < TB; i++) {
if (row_start + i >= D)
break;
idx2 = __ldg(index_info.data + offset +
i * index_info.strides[index_info.dims - 1]);
assert(idx1 <= idx2);
if (idx1 == idx2) {
Reducer<scalar_t, REDUCE>::update(
&val, src_data[K * (dim_start * D + row_start + i) + col_idx]);
} else {
Reducer<scalar_t, REDUCE>::atomic_write(
out_data + (dim_start * N + idx1) * K + col_idx, val);
val = src_data[K * (dim_start * D + row_start + i) + col_idx];
}
idx1 = idx2;
}
Reducer<scalar_t, REDUCE>::atomic_write(
out_data + (dim_start * N + idx1) * K + col_idx, val);
}
}
template <typename scalar_t>
__global__ void segment_coo_arg_broadcast_kernel(
const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> index_info,
scalar_t *out_data, int64_t *arg_out_data, size_t E, size_t K, size_t N) {
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
int row_idx = thread_idx / K;
int col_idx = thread_idx % K;
int D = index_info.sizes[index_info.dims - 1];
if (row_idx < E && col_idx < K) {
int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
row_idx, index_info);
int idx = __ldg(index_info.data + offset);
int out_idx = ((row_idx / D) * N + idx) * K + col_idx;
scalar_t val = __ldg(out_data + out_idx);
if (src_data[thread_idx] == val)
arg_out_data[out_idx] = row_idx % D;
}
}
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_coo_cuda(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size, std::string reduce) {
CHECK_CUDA(src);
CHECK_CUDA(index);
if (optional_out.has_value())
CHECK_CUDA(optional_out.value());
hipSetDevice(src.get_device());
CHECK_INPUT(src.dim() >= index.dim());
auto sizes = index.sizes().vec();
for (int i = 0; i < index.dim(); i++) {
sizes[i] = src.size(i);
}
index = index.expand(sizes);
auto dim = index.dim() - 1;
src = src.contiguous();
torch::Tensor out;
if (optional_out.has_value()) {
out = optional_out.value().contiguous();
for (int i = 0; i < out.dim(); i++)
if (i != dim)
CHECK_INPUT(src.size(i) == out.size(i));
} else {
sizes = src.sizes().vec();
if (dim_size.has_value())
sizes[dim] = dim_size.value();
else if (index.numel() == 0)
sizes[dim] = 0;
else {
auto tmp = index.select(dim, index.size(dim) - 1);
tmp = tmp.numel() > 1 ? tmp.max() : tmp;
sizes[dim] = 1 + tmp.cpu().data_ptr<int64_t>()[0];
}
out = torch::zeros(sizes, src.options());
}
torch::optional<torch::Tensor> arg_out = torch::nullopt;
int64_t *arg_out_data = nullptr;
if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
arg_out = torch::full_like(out, src.size(dim), index.options());
arg_out_data = arg_out.value().data_ptr<int64_t>();
} else if (reduce2REDUCE.at(reduce) == MEAN) {
auto sizes = index.sizes().vec();
sizes[dim] = out.size(dim);
arg_out = torch::zeros(sizes, out.options());
}
if (index.numel() == 0)
return std::make_tuple(out, arg_out);
auto E = index.numel();
auto E_2 = index.size(dim);
auto E_1 = index.numel() / E_2;
auto K = src.numel() / E;
auto N = out.size(dim);
auto avg_len = (float)E_2 / (float)N;
auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
if (!optional_out.has_value())
out.fill_(Reducer<scalar_t, REDUCE>::init());
if (K == 1)
hipLaunchKernelGGL(( segment_coo_kernel<scalar_t, REDUCE, true>)
, dim3(BLOCKS(1, E)), dim3(THREADS), 0, stream, src_data, index_info,
out_data, E, N);
else if (avg_len <= 8)
hipLaunchKernelGGL(( segment_coo_broadcast_kernel<scalar_t, REDUCE, 4>)
, dim3(dim3((E_1 * ((E_2 + 3) / 4) + 7) / 8, (K + 31) / 32)),
dim3(dim3(32, 8)), 0, stream, src_data, index_info, out_data, E, K,
N);
else if (avg_len <= 16)
hipLaunchKernelGGL(( segment_coo_broadcast_kernel<scalar_t, REDUCE, 8>)
, dim3(dim3((E_1 * ((E_2 + 7) / 8) + 7) / 8, (K + 31) / 32)),
dim3(dim3(32, 8)), 0, stream, src_data, index_info, out_data, E, K,
N);
else if (avg_len <= 32)
hipLaunchKernelGGL(( segment_coo_broadcast_kernel<scalar_t, REDUCE, 16>)
, dim3(dim3((E_1 * ((E_2 + 15) / 16) + 7) / 8, (K + 31) / 32)),
dim3(dim3(32, 8)), 0, stream, src_data, index_info, out_data, E, K,
N);
else
hipLaunchKernelGGL(( segment_coo_broadcast_kernel<scalar_t, REDUCE, 32>)
, dim3(dim3((E_1 * ((E_2 + 31) / 32) + 7) / 8, (K + 31) / 32)),
dim3(dim3(32, 8)), 0, stream, src_data, index_info, out_data, E, K,
N);
if (!optional_out.has_value() && (REDUCE == MIN || REDUCE == MAX))
out.masked_fill_(out == Reducer<scalar_t, REDUCE>::init(), (scalar_t)0);
if (REDUCE == MIN || REDUCE == MAX) {
if (K == 1)
hipLaunchKernelGGL(( segment_coo_arg_kernel<scalar_t>)
, dim3(BLOCKS(1, E)), dim3(THREADS), 0, stream,
src_data, index_info, out_data, arg_out_data, E, N);
else
hipLaunchKernelGGL(( segment_coo_arg_broadcast_kernel<scalar_t>)
, dim3(BLOCKS(1, E * K)), dim3(THREADS), 0, stream,
src_data, index_info, out_data, arg_out_data, E, K, N);
}
if (REDUCE == MEAN) {
auto count_data = arg_out.value().data_ptr<scalar_t>();
hipLaunchKernelGGL(( segment_coo_kernel<scalar_t, SUM, false>)
, dim3(BLOCKS(1, E)), dim3(THREADS), 0, stream, nullptr, index_info,
count_data, E, N);
arg_out.value().masked_fill_(arg_out.value() < (scalar_t)1,
(scalar_t)1);
auto count = arg_out.value();
for (int i = dim + 1; i < out.dim(); i++)
count = count.unsqueeze(-1);
if (out.is_floating_point())
out.true_divide_(count);
else
out.div_(count, "floor");
}
});
});
return std::make_tuple(out, arg_out);
}
template <typename scalar_t>
__global__ void
gather_coo_kernel(const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> index_info,
scalar_t *out_data, size_t E, size_t N) {
int row_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (row_idx < E) {
int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
row_idx, index_info);
int row = index_info.data[offset];
offset = (row_idx / index_info.sizes[index_info.dims - 1]) * N;
scalar_t val = __ldg(src_data + offset + row);
out_data[row_idx] = val;
}
}
template <typename scalar_t>
__global__ void gather_coo_broadcast_kernel(
const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> index_info,
scalar_t *out_data, size_t E, size_t K, size_t N) {
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
int row_idx = thread_idx / K;
int col_idx = thread_idx % K;
if (thread_idx < E * K) {
int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
row_idx, index_info);
int row = index_info.data[offset];
offset = (row_idx / index_info.sizes[index_info.dims - 1]) * N * K;
scalar_t val = __ldg(src_data + offset + K * row + col_idx);
out_data[thread_idx] = val;
}
}
torch::Tensor gather_coo_cuda(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out) {
CHECK_CUDA(src);
CHECK_CUDA(index);
if (optional_out.has_value())
CHECK_CUDA(optional_out.value());
hipSetDevice(src.get_device());
CHECK_INPUT(src.dim() >= index.dim());
auto sizes = index.sizes().vec();
for (auto i = 0; i < index.dim() - 1; i++)
sizes[i] = src.size(i);
index = index.expand(sizes);
auto dim = index.dim() - 1;
src = src.contiguous();
torch::Tensor out;
if (optional_out.has_value()) {
out = optional_out.value().contiguous();
for (auto i = 0; i < src.dim(); i++)
if (i != dim)
CHECK_INPUT(src.size(i) == out.size(i));
CHECK_INPUT(index.size(dim) == out.size(dim));
} else {
auto sizes = src.sizes().vec();
sizes[dim] = index.size(dim);
out = torch::empty(sizes, src.options());
}
if (index.numel() == 0)
return out;
auto E = index.numel();
auto K = out.numel() / E;
auto N = src.size(dim);
auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
if (K == 1)
hipLaunchKernelGGL(( gather_coo_kernel<scalar_t>), dim3(BLOCKS(1, E)), dim3(THREADS), 0, stream,
src_data, index_info, out_data, E, N);
else
hipLaunchKernelGGL(( gather_coo_broadcast_kernel<scalar_t>)
, dim3(BLOCKS(1, E * K)), dim3(THREADS), 0, stream, src_data, index_info,
out_data, E, K, N);
});
return out;
}
#pragma once
#include <torch/extension.h>
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_csr_cuda(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out,
std::string reduce);
torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out);
template<typename T>
__device__ T __ldg(const T* ptr) {
return *ptr;
}
#include "hip/hip_runtime.h"
#include "segment_csr_hip.h"
#include <ATen/hip/HIPContext.h>
#include <ATen/hip/detail/IndexUtils.cuh>
#include <ATen/hip/detail/TensorInfo.cuh>
#include "index_info.cuh"
#include "reducer.cuh"
#include "utils.cuh"
#define THREADS 256
#define BLOCKS(TB, N) (TB * N + THREADS - 1) / THREADS
#define FULL_MASK 0xffffffff
template <typename scalar_t, ReductionType REDUCE, int TB>
__global__ void
segment_csr_kernel(const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
scalar_t *out_data, int64_t *arg_out_data, size_t N,
size_t E) {
// Each warp processes exactly `32/TB` rows and aggregates all row values
// via a parallel reduction.
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
int row_idx = thread_idx / TB;
int lane_idx = thread_idx & (TB - 1);
if (row_idx < N) {
int offset = IndexPtrToOffset<int64_t>::get(row_idx, indptr_info);
int64_t row_start = __ldg(indptr_info.data + offset);
int64_t row_end = __ldg(indptr_info.data + offset +
indptr_info.strides[indptr_info.dims - 1]);
scalar_t val = Reducer<scalar_t, REDUCE>::init();
int64_t arg, arg_tmp;
offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E;
for (int64_t src_idx = row_start + lane_idx; src_idx < row_end;
src_idx += TB) {
Reducer<scalar_t, REDUCE>::update(&val, src_data[offset + src_idx], &arg,
src_idx);
}
#pragma unroll
for (int i = TB / 2; i > 0; i /= 2) {
// Parallel reduction inside a single warp.
if (REDUCE == MIN || REDUCE == MAX)
arg_tmp = __shfl_down_sync(FULL_MASK, arg, i);
Reducer<scalar_t, REDUCE>::update(
&val, __shfl_down_sync(FULL_MASK, val, i), &arg, arg_tmp);
}
if (lane_idx == 0) {
Reducer<scalar_t, REDUCE>::write(out_data + row_idx, val,
arg_out_data + row_idx, arg,
row_end - row_start);
}
}
}
template <typename scalar_t, ReductionType REDUCE>
__global__ void segment_csr_broadcast_kernel(
const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
scalar_t *out_data, int64_t *arg_out_data, size_t N, size_t K, size_t E) {
// Each thread processes exactly one row. It turned out that is more
// efficient than using shared memory due to avoiding synchronization
// barriers.
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
int row_idx = thread_idx / K;
int lane_idx = thread_idx % K;
if (thread_idx < N * K) {
int offset = IndexPtrToOffset<int64_t>::get(row_idx, indptr_info);
int64_t row_start = __ldg(indptr_info.data + offset);
int64_t row_end = __ldg(indptr_info.data + offset +
indptr_info.strides[indptr_info.dims - 1]);
scalar_t val = Reducer<scalar_t, REDUCE>::init();
int64_t arg;
offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E * K;
for (int64_t src_idx = row_start; src_idx < row_end; src_idx++) {
Reducer<scalar_t, REDUCE>::update(
&val, src_data[offset + K * src_idx + lane_idx], &arg, src_idx);
}
Reducer<scalar_t, REDUCE>::write(out_data + thread_idx, val,
arg_out_data + thread_idx, arg,
row_end - row_start);
}
}
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_csr_cuda(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out,
std::string reduce) {
CHECK_CUDA(src);
CHECK_CUDA(indptr);
if (optional_out.has_value())
CHECK_CUDA(optional_out.value());
hipSetDevice(src.get_device());
CHECK_INPUT(src.dim() >= indptr.dim());
auto sizes = indptr.sizes().vec();
for (auto i = 0; i < indptr.dim() - 1; i++)
sizes[i] = src.size(i);
indptr = indptr.expand(sizes);
auto dim = indptr.dim() - 1;
src = src.contiguous();
torch::Tensor out;
if (optional_out.has_value()) {
out = optional_out.value().contiguous();
for (int i = 0; i < out.dim(); i++)
if (i != dim)
CHECK_INPUT(src.size(i) == out.size(i));
CHECK_INPUT(src.numel() == 0 || out.size(dim) == indptr.size(dim) - 1);
} else {
sizes = src.sizes().vec();
sizes[dim] = std::max<int64_t>(indptr.size(dim) - 1, 0);
out = torch::empty(sizes, src.options());
}
torch::optional<torch::Tensor> arg_out = torch::nullopt;
int64_t *arg_out_data = nullptr;
if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
arg_out = torch::full(out.sizes(), src.size(dim), indptr.options());
arg_out_data = arg_out.value().data_ptr<int64_t>();
}
if (src.numel() == 0) {
if (!optional_out.has_value())
out.fill_(0);
return std::make_tuple(out, arg_out);
}
auto N = out.size(dim) * (indptr.numel() / indptr.size(-1));
auto K = out.numel() / N;
auto E = src.size(dim);
auto indptr_info = at::cuda::detail::getTensorInfo<int64_t, int>(indptr);
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
if (K == 1) {
segment_csr_kernel<scalar_t, REDUCE, 1>
<<<BLOCKS(32, N), THREADS, 0, stream>>>(
src_data, indptr_info, out_data, arg_out_data, N, E);
} else {
segment_csr_broadcast_kernel<scalar_t, REDUCE>
<<<BLOCKS(1, N * K), THREADS, 0, stream>>>(
src_data, indptr_info, out_data, arg_out_data, N, K, E);
}
});
});
return std::make_tuple(out, arg_out);
}
template <typename scalar_t, int TB>
__global__ void
gather_csr_kernel(const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
scalar_t *out_data, size_t N, size_t E) {
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
int row_idx = thread_idx / TB;
int lane_idx = thread_idx % TB;
if (row_idx < N) {
int offset = IndexPtrToOffset<int64_t>::get(row_idx, indptr_info);
int row_start = __ldg(indptr_info.data + offset);
int row_end = __ldg(indptr_info.data + offset +
indptr_info.strides[indptr_info.dims - 1]);
scalar_t val = __ldg(src_data + row_idx);
offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E;
for (int out_idx = row_start + lane_idx; out_idx < row_end; out_idx += TB) {
out_data[offset + out_idx] = val; // "Mostly" coalesced.
}
}
}
template <typename scalar_t>
__global__ void gather_csr_broadcast_kernel(
const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
scalar_t *out_data, size_t N, size_t K, size_t E) {
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
int row_idx = thread_idx / K;
int lane_idx = thread_idx % K;
if (thread_idx < N * K) {
int offset = IndexPtrToOffset<int64_t>::get(row_idx, indptr_info);
int row_start = __ldg(indptr_info.data + offset);
int row_end = __ldg(indptr_info.data + offset +
indptr_info.strides[indptr_info.dims - 1]);
scalar_t val = src_data[thread_idx]; // Coalesced.
offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E * K;
for (int out_idx = row_start; out_idx < row_end; out_idx++) {
out_data[offset + K * out_idx + lane_idx] = val; // "Mostly" coalesced.
}
}
}
torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out) {
CHECK_CUDA(src);
CHECK_CUDA(indptr);
if (optional_out.has_value())
CHECK_CUDA(optional_out.value());
hipSetDevice(src.get_device());
CHECK_INPUT(src.dim() >= indptr.dim());
auto sizes = indptr.sizes().vec();
for (auto i = 0; i < indptr.dim() - 1; i++)
sizes[i] = src.size(i);
indptr = indptr.expand(sizes);
auto dim = indptr.dim() - 1;
CHECK_INPUT(src.size(dim) == 0 || src.size(dim) == indptr.size(dim) - 1);
src = src.contiguous();
torch::Tensor out;
if (optional_out.has_value()) {
out = optional_out.value().contiguous();
for (auto i = 0; i < out.dim(); i++)
if (i != dim)
CHECK_INPUT(src.size(i) == out.size(i));
} else {
auto sizes = src.sizes().vec();
if (src.numel() > 0) {
sizes[dim] = indptr.flatten()[-1].cpu().data_ptr<int64_t>()[0];
} else {
sizes[dim] = 0;
}
out = torch::empty(sizes, src.options());
}
if (src.numel() == 0) {
if (!optional_out.has_value())
out.fill_(0);
return out;
}
auto N = src.size(dim) * (indptr.numel() / indptr.size(-1));
auto K = src.numel() / N;
auto E = out.size(dim);
auto indptr_info = at::cuda::detail::getTensorInfo<int64_t, int>(indptr);
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
if (K == 1)
gather_csr_kernel<scalar_t, 4><<<BLOCKS(1, 4 * N), THREADS, 0, stream>>>(
src_data, indptr_info, out_data, N, E);
else
gather_csr_broadcast_kernel<scalar_t>
<<<BLOCKS(1, N * K), THREADS, 0, stream>>>(src_data, indptr_info,
out_data, N, K, E);
});
return out;
}
#include "hip/hip_runtime.h"
#include "segment_csr_hip.h"
#include <ATen/hip/HIPContext.h>
#include <ATen/hip/detail/IndexUtils.cuh>
#include <ATen/hip/detail/TensorInfo.cuh>
#include "index_info.cuh"
#include "reducer.cuh"
#include "utils.cuh"
#define THREADS 256
#define BLOCKS(TB, N) (TB * N + THREADS - 1) / THREADS
#define FULL_MASK 0xffffffff
template <typename scalar_t, ReductionType REDUCE, int TB>
__global__ void
segment_csr_kernel(const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
scalar_t *out_data, int64_t *arg_out_data, size_t N,
size_t E) {
// Each warp processes exactly `32/TB` rows and aggregates all row values
// via a parallel reduction.
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
int row_idx = thread_idx / TB;
int lane_idx = thread_idx & (TB - 1);
if (row_idx < N) {
int offset = IndexPtrToOffset<int64_t>::get(row_idx, indptr_info);
int64_t row_start = __ldg(indptr_info.data + offset);
int64_t row_end = __ldg(indptr_info.data + offset +
indptr_info.strides[indptr_info.dims - 1]);
scalar_t val = Reducer<scalar_t, REDUCE>::init();
int64_t arg, arg_tmp;
offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E;
for (int64_t src_idx = row_start + lane_idx; src_idx < row_end;
src_idx += TB) {
Reducer<scalar_t, REDUCE>::update(&val, src_data[offset + src_idx], &arg,
src_idx);
}
#pragma unroll
for (int i = TB / 2; i > 0; i /= 2) {
// Parallel reduction inside a single warp.
if (REDUCE == MIN || REDUCE == MAX)
arg_tmp = __shfl_down_sync(FULL_MASK, arg, i);
Reducer<scalar_t, REDUCE>::update(
&val, __shfl_down_sync(FULL_MASK, val, i), &arg, arg_tmp);
}
if (lane_idx == 0) {
Reducer<scalar_t, REDUCE>::write(out_data + row_idx, val,
arg_out_data + row_idx, arg,
row_end - row_start);
}
}
}
template <typename scalar_t, ReductionType REDUCE>
__global__ void segment_csr_broadcast_kernel(
const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
scalar_t *out_data, int64_t *arg_out_data, size_t N, size_t K, size_t E) {
// Each thread processes exactly one row. It turned out that is more
// efficient than using shared memory due to avoiding synchronization
// barriers.
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
int row_idx = thread_idx / K;
int lane_idx = thread_idx % K;
if (thread_idx < N * K) {
int offset = IndexPtrToOffset<int64_t>::get(row_idx, indptr_info);
int64_t row_start = __ldg(indptr_info.data + offset);
int64_t row_end = __ldg(indptr_info.data + offset +
indptr_info.strides[indptr_info.dims - 1]);
scalar_t val = Reducer<scalar_t, REDUCE>::init();
int64_t arg;
offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E * K;
for (int64_t src_idx = row_start; src_idx < row_end; src_idx++) {
Reducer<scalar_t, REDUCE>::update(
&val, src_data[offset + K * src_idx + lane_idx], &arg, src_idx);
}
Reducer<scalar_t, REDUCE>::write(out_data + thread_idx, val,
arg_out_data + thread_idx, arg,
row_end - row_start);
}
}
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_csr_cuda(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out,
std::string reduce) {
CHECK_CUDA(src);
CHECK_CUDA(indptr);
if (optional_out.has_value())
CHECK_CUDA(optional_out.value());
hipSetDevice(src.get_device());
CHECK_INPUT(src.dim() >= indptr.dim());
auto sizes = indptr.sizes().vec();
for (auto i = 0; i < indptr.dim() - 1; i++)
sizes[i] = src.size(i);
indptr = indptr.expand(sizes);
auto dim = indptr.dim() - 1;
src = src.contiguous();
torch::Tensor out;
if (optional_out.has_value()) {
out = optional_out.value().contiguous();
for (int i = 0; i < out.dim(); i++)
if (i != dim)
CHECK_INPUT(src.size(i) == out.size(i));
CHECK_INPUT(src.numel() == 0 || out.size(dim) == indptr.size(dim) - 1);
} else {
sizes = src.sizes().vec();
sizes[dim] = std::max<int64_t>(indptr.size(dim) - 1, 0);
out = torch::empty(sizes, src.options());
}
torch::optional<torch::Tensor> arg_out = torch::nullopt;
int64_t *arg_out_data = nullptr;
if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
arg_out = torch::full(out.sizes(), src.size(dim), indptr.options());
arg_out_data = arg_out.value().data_ptr<int64_t>();
}
if (src.numel() == 0) {
if (!optional_out.has_value())
out.fill_(0);
return std::make_tuple(out, arg_out);
}
auto N = out.size(dim) * (indptr.numel() / indptr.size(-1));
auto K = out.numel() / N;
auto E = src.size(dim);
auto indptr_info = at::cuda::detail::getTensorInfo<int64_t, int>(indptr);
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
if (K == 1) {
hipLaunchKernelGGL(( segment_csr_kernel<scalar_t, REDUCE, 1>)
, dim3(BLOCKS(32, N)), dim3(THREADS), 0, stream,
src_data, indptr_info, out_data, arg_out_data, N, E);
} else {
hipLaunchKernelGGL(( segment_csr_broadcast_kernel<scalar_t, REDUCE>)
, dim3(BLOCKS(1, N * K)), dim3(THREADS), 0, stream,
src_data, indptr_info, out_data, arg_out_data, N, K, E);
}
});
});
return std::make_tuple(out, arg_out);
}
template <typename scalar_t, int TB>
__global__ void
gather_csr_kernel(const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
scalar_t *out_data, size_t N, size_t E) {
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
int row_idx = thread_idx / TB;
int lane_idx = thread_idx % TB;
if (row_idx < N) {
int offset = IndexPtrToOffset<int64_t>::get(row_idx, indptr_info);
int row_start = __ldg(indptr_info.data + offset);
int row_end = __ldg(indptr_info.data + offset +
indptr_info.strides[indptr_info.dims - 1]);
scalar_t val = __ldg(src_data + row_idx);
offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E;
for (int out_idx = row_start + lane_idx; out_idx < row_end; out_idx += TB) {
out_data[offset + out_idx] = val; // "Mostly" coalesced.
}
}
}
template <typename scalar_t>
__global__ void gather_csr_broadcast_kernel(
const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
scalar_t *out_data, size_t N, size_t K, size_t E) {
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
int row_idx = thread_idx / K;
int lane_idx = thread_idx % K;
if (thread_idx < N * K) {
int offset = IndexPtrToOffset<int64_t>::get(row_idx, indptr_info);
int row_start = __ldg(indptr_info.data + offset);
int row_end = __ldg(indptr_info.data + offset +
indptr_info.strides[indptr_info.dims - 1]);
scalar_t val = src_data[thread_idx]; // Coalesced.
offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E * K;
for (int out_idx = row_start; out_idx < row_end; out_idx++) {
out_data[offset + K * out_idx + lane_idx] = val; // "Mostly" coalesced.
}
}
}
torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out) {
CHECK_CUDA(src);
CHECK_CUDA(indptr);
if (optional_out.has_value())
CHECK_CUDA(optional_out.value());
hipSetDevice(src.get_device());
CHECK_INPUT(src.dim() >= indptr.dim());
auto sizes = indptr.sizes().vec();
for (auto i = 0; i < indptr.dim() - 1; i++)
sizes[i] = src.size(i);
indptr = indptr.expand(sizes);
auto dim = indptr.dim() - 1;
CHECK_INPUT(src.size(dim) == 0 || src.size(dim) == indptr.size(dim) - 1);
src = src.contiguous();
torch::Tensor out;
if (optional_out.has_value()) {
out = optional_out.value().contiguous();
for (auto i = 0; i < out.dim(); i++)
if (i != dim)
CHECK_INPUT(src.size(i) == out.size(i));
} else {
auto sizes = src.sizes().vec();
if (src.numel() > 0) {
sizes[dim] = indptr.flatten()[-1].cpu().data_ptr<int64_t>()[0];
} else {
sizes[dim] = 0;
}
out = torch::empty(sizes, src.options());
}
if (src.numel() == 0) {
if (!optional_out.has_value())
out.fill_(0);
return out;
}
auto N = src.size(dim) * (indptr.numel() / indptr.size(-1));
auto K = src.numel() / N;
auto E = out.size(dim);
auto indptr_info = at::cuda::detail::getTensorInfo<int64_t, int>(indptr);
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
if (K == 1)
hipLaunchKernelGGL(( gather_csr_kernel<scalar_t, 4>), dim3(BLOCKS(1, 4 * N)), dim3(THREADS), 0, stream,
src_data, indptr_info, out_data, N, E);
else
hipLaunchKernelGGL(( gather_csr_broadcast_kernel<scalar_t>)
, dim3(BLOCKS(1, N * K)), dim3(THREADS), 0, stream, src_data, indptr_info,
out_data, N, K, E);
});
return out;
}
#pragma once
#include <torch/extension.h>
#define CHECK_CUDA(x) \
AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
__device__ __inline__ at::Half __shfl_up_sync(const unsigned mask,
const at::Half var,
const unsigned int delta) {
return __shfl_up_sync(mask, (__half)var, delta);
}
__device__ __inline__ at::Half __shfl_down_sync(const unsigned mask,
const at::Half var,
const unsigned int delta) {
return __shfl_down_sync(mask, (__half)var, delta);
}
#include <Python.h>
#include <torch/script.h>
#include "cpu/scatter_cpu.h"
#include "utils.h"
#ifdef WITH_HIP
#include "hip/scatter_hip.h"
#endif
#ifdef _WIN32
#ifdef WITH_HIP
PyMODINIT_FUNC PyInit__scatter_cuda(void) { return NULL; }
#else
PyMODINIT_FUNC PyInit__scatter_cpu(void) { return NULL; }
#endif
#endif
torch::Tensor broadcast(torch::Tensor src, torch::Tensor other, int64_t dim) {
if (src.dim() == 1)
for (auto i = 0; i < dim; i++)
src = src.unsqueeze(0);
for (auto i = src.dim(); i < other.dim(); i++)
src = src.unsqueeze(-1);
src = src.expand(other.sizes().vec());
return src;
}
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
scatter_fw(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size, std::string reduce) {
if (src.device().is_cuda()) {
#ifdef WITH_HIP
return scatter_cuda(src, index, dim, optional_out, dim_size, reduce);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return scatter_cpu(src, index, dim, optional_out, dim_size, reduce);
}
}
using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;
class ScatterSum : public torch::autograd::Function<ScatterSum> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable index, int64_t dim,
torch::optional<Variable> optional_out,
torch::optional<int64_t> dim_size) {
dim = dim < 0 ? src.dim() + dim : dim;
ctx->saved_data["dim"] = dim;
ctx->saved_data["src_shape"] = src.sizes();
index = broadcast(index, src, dim);
auto result = scatter_fw(src, index, dim, optional_out, dim_size, "sum");
auto out = std::get<0>(result);
ctx->save_for_backward({index});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto index = saved[0];
auto dim = ctx->saved_data["dim"].toInt();
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
auto grad_in = torch::gather(grad_out, dim, index, false);
return {grad_in, Variable(), Variable(), Variable(), Variable()};
}
};
class ScatterMul : public torch::autograd::Function<ScatterMul> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable index, int64_t dim,
torch::optional<Variable> optional_out,
torch::optional<int64_t> dim_size) {
dim = dim < 0 ? src.dim() + dim : dim;
ctx->saved_data["dim"] = dim;
ctx->saved_data["src_shape"] = src.sizes();
index = broadcast(index, src, dim);
auto result = scatter_fw(src, index, dim, optional_out, dim_size, "mul");
auto out = std::get<0>(result);
ctx->save_for_backward({src, index, out});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto src = saved[0];
auto index = saved[1];
auto out = saved[2];
auto dim = ctx->saved_data["dim"].toInt();
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
auto grad_in = torch::gather(grad_out * out, dim, index, false).div_(src);
grad_in.masked_fill_(grad_in.isnan(), 0);
return {grad_in, Variable(), Variable(), Variable(), Variable()};
}
};
class ScatterMean : public torch::autograd::Function<ScatterMean> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable index, int64_t dim,
torch::optional<Variable> optional_out,
torch::optional<int64_t> dim_size) {
dim = dim < 0 ? src.dim() + dim : dim;
ctx->saved_data["dim"] = dim;
ctx->saved_data["src_shape"] = src.sizes();
auto old_index = index;
index = broadcast(index, src, dim);
auto result = scatter_fw(src, index, dim, optional_out, dim_size, "sum");
auto out = std::get<0>(result);
auto ones = torch::ones(old_index.sizes(), src.options());
result = scatter_fw(ones, old_index,
old_index.dim() <= dim ? old_index.dim() - 1 : dim,
torch::nullopt, out.size(dim), "sum");
auto count = std::get<0>(result);
count.masked_fill_(count < 1, 1);
count = broadcast(count, out, dim);
if (out.is_floating_point())
out.true_divide_(count);
else
out.div_(count, "floor");
ctx->save_for_backward({index, count});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto index = saved[0];
auto count = saved[1];
auto dim = ctx->saved_data["dim"].toInt();
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
count = torch::gather(count, dim, index, false);
auto grad_in = torch::gather(grad_out, dim, index, false);
grad_in.true_divide_(count);
return {grad_in, Variable(), Variable(), Variable(), Variable()};
}
};
class ScatterMin : public torch::autograd::Function<ScatterMin> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable index, int64_t dim,
torch::optional<Variable> optional_out,
torch::optional<int64_t> dim_size) {
dim = dim < 0 ? src.dim() + dim : dim;
ctx->saved_data["dim"] = dim;
ctx->saved_data["src_shape"] = src.sizes();
index = broadcast(index, src, dim);
auto result = scatter_fw(src, index, dim, optional_out, dim_size, "min");
auto out = std::get<0>(result);
auto arg_out = std::get<1>(result).value();
ctx->save_for_backward({index, arg_out});
ctx->mark_non_differentiable({arg_out});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out, arg_out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto index = saved[0];
auto arg_out = saved[1];
auto dim = ctx->saved_data["dim"].toInt();
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
src_shape[dim] += 1;
auto grad_in = torch::zeros(src_shape, grad_out.options());
grad_in.scatter_(dim, arg_out, grad_out);
grad_in = grad_in.narrow(dim, 0, src_shape[dim] - 1);
return {grad_in, Variable(), Variable(), Variable(), Variable()};
}
};
class ScatterMax : public torch::autograd::Function<ScatterMax> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable index, int64_t dim,
torch::optional<Variable> optional_out,
torch::optional<int64_t> dim_size) {
dim = dim < 0 ? src.dim() + dim : dim;
ctx->saved_data["dim"] = dim;
ctx->saved_data["src_shape"] = src.sizes();
index = broadcast(index, src, dim);
auto result = scatter_fw(src, index, dim, optional_out, dim_size, "max");
auto out = std::get<0>(result);
auto arg_out = std::get<1>(result).value();
ctx->save_for_backward({index, arg_out});
ctx->mark_non_differentiable({arg_out});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out, arg_out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto index = saved[0];
auto arg_out = saved[1];
auto dim = ctx->saved_data["dim"].toInt();
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
src_shape[dim] += 1;
auto grad_in = torch::zeros(src_shape, grad_out.options());
grad_in.scatter_(dim, arg_out, grad_out);
grad_in = grad_in.narrow(dim, 0, src_shape[dim] - 1);
return {grad_in, Variable(), Variable(), Variable(), Variable()};
}
};
torch::Tensor scatter_sum(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
return ScatterSum::apply(src, index, dim, optional_out, dim_size)[0];
}
torch::Tensor scatter_mul(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
return ScatterMul::apply(src, index, dim, optional_out, dim_size)[0];
}
torch::Tensor scatter_mean(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
return ScatterMean::apply(src, index, dim, optional_out, dim_size)[0];
}
std::tuple<torch::Tensor, torch::Tensor>
scatter_min(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
auto result = ScatterMin::apply(src, index, dim, optional_out, dim_size);
return std::make_tuple(result[0], result[1]);
}
std::tuple<torch::Tensor, torch::Tensor>
scatter_max(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
auto result = ScatterMax::apply(src, index, dim, optional_out, dim_size);
return std::make_tuple(result[0], result[1]);
}
static auto registry = torch::RegisterOperators()
.op("torch_scatter::scatter_sum", &scatter_sum)
.op("torch_scatter::scatter_mul", &scatter_mul)
.op("torch_scatter::scatter_mean", &scatter_mean)
.op("torch_scatter::scatter_min", &scatter_min)
.op("torch_scatter::scatter_max", &scatter_max);
#pragma once
#include <torch/extension.h>
int64_t cuda_version();
torch::Tensor scatter_sum(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size);
torch::Tensor scatter_mean(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size);
std::tuple<torch::Tensor, torch::Tensor>
scatter_min(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size);
std::tuple<torch::Tensor, torch::Tensor>
scatter_max(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size);
torch::Tensor segment_sum_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size);
torch::Tensor segment_mean_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size);
std::tuple<torch::Tensor, torch::Tensor>
segment_min_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size);
std::tuple<torch::Tensor, torch::Tensor>
segment_max_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size);
torch::Tensor gather_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out);
torch::Tensor segment_sum_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out);
torch::Tensor segment_mean_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out);
std::tuple<torch::Tensor, torch::Tensor>
segment_min_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out);
std::tuple<torch::Tensor, torch::Tensor>
segment_max_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out);
torch::Tensor gather_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out);
#include <Python.h>
#include <torch/script.h>
#include "cpu/segment_coo_cpu.h"
#include "utils.h"
#ifdef WITH_HIP
#include "hip/segment_coo_hip.h"
#endif
#ifdef _WIN32
#ifdef WITH_HIP
PyMODINIT_FUNC PyInit__segment_coo_cuda(void) { return NULL; }
#else
PyMODINIT_FUNC PyInit__segment_coo_cpu(void) { return NULL; }
#endif
#endif
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_coo_fw(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size, std::string reduce) {
if (src.device().is_cuda()) {
#ifdef WITH_HIP
return segment_coo_cuda(src, index, optional_out, dim_size, reduce);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return segment_coo_cpu(src, index, optional_out, dim_size, reduce);
}
}
torch::Tensor gather_coo_fw(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out) {
if (src.device().is_cuda()) {
#ifdef WITH_HIP
return gather_coo_cuda(src, index, optional_out);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return gather_coo_cpu(src, index, optional_out);
}
}
using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;
class SegmentSumCOO : public torch::autograd::Function<SegmentSumCOO> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable index,
torch::optional<Variable> optional_out,
torch::optional<int64_t> dim_size) {
ctx->saved_data["src_shape"] = src.sizes();
auto result = segment_coo_fw(src, index, optional_out, dim_size, "sum");
auto out = std::get<0>(result);
ctx->save_for_backward({index});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto index = saved[0];
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
auto grad_in = torch::empty(src_shape, grad_out.options());
gather_coo_fw(grad_out, index, grad_in);
return {grad_in, Variable(), Variable(), Variable()};
}
};
class SegmentMeanCOO : public torch::autograd::Function<SegmentMeanCOO> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable index,
torch::optional<Variable> optional_out,
torch::optional<int64_t> dim_size) {
ctx->saved_data["src_shape"] = src.sizes();
auto result = segment_coo_fw(src, index, optional_out, dim_size, "mean");
auto out = std::get<0>(result);
auto count = std::get<1>(result).value();
ctx->save_for_backward({index, count});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto index = saved[0];
auto count = saved[1];
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
auto grad_in = torch::empty(src_shape, grad_out.options());
gather_coo_fw(grad_out, index, grad_in);
count = gather_coo_fw(count, index, torch::nullopt);
for (auto i = 0; i < grad_out.dim() - index.dim(); i++)
count = count.unsqueeze(-1);
grad_in.true_divide_(count);
return {grad_in, Variable(), Variable(), Variable()};
}
};
class SegmentMinCOO : public torch::autograd::Function<SegmentMinCOO> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable index,
torch::optional<Variable> optional_out,
torch::optional<int64_t> dim_size) {
ctx->saved_data["src_shape"] = src.sizes();
auto result = segment_coo_fw(src, index, optional_out, dim_size, "min");
auto out = std::get<0>(result);
auto arg_out = std::get<1>(result).value();
ctx->save_for_backward({index, arg_out});
ctx->mark_non_differentiable({arg_out});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out, arg_out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto index = saved[0];
auto arg_out = saved[1];
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
src_shape[index.dim() - 1] += 1;
auto grad_in = torch::zeros(src_shape, grad_out.options());
grad_in.scatter_(index.dim() - 1, arg_out, grad_out);
grad_in =
grad_in.narrow(index.dim() - 1, 0, src_shape[index.dim() - 1] - 1);
return {grad_in, Variable(), Variable(), Variable()};
}
};
class SegmentMaxCOO : public torch::autograd::Function<SegmentMaxCOO> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable index,
torch::optional<Variable> optional_out,
torch::optional<int64_t> dim_size) {
ctx->saved_data["src_shape"] = src.sizes();
auto result = segment_coo_fw(src, index, optional_out, dim_size, "max");
auto out = std::get<0>(result);
auto arg_out = std::get<1>(result).value();
ctx->save_for_backward({index, arg_out});
ctx->mark_non_differentiable({arg_out});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out, arg_out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto index = saved[0];
auto arg_out = saved[1];
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
src_shape[index.dim() - 1] += 1;
auto grad_in = torch::zeros(src_shape, grad_out.options());
grad_in.scatter_(index.dim() - 1, arg_out, grad_out);
grad_in =
grad_in.narrow(index.dim() - 1, 0, src_shape[index.dim() - 1] - 1);
return {grad_in, Variable(), Variable(), Variable()};
}
};
class GatherCOO : public torch::autograd::Function<GatherCOO> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable index,
torch::optional<Variable> optional_out) {
ctx->saved_data["src_shape"] = src.sizes();
auto out = gather_coo_fw(src, index, optional_out);
ctx->save_for_backward({index});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto index = saved[0];
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
auto grad_in = torch::zeros(src_shape, grad_out.options());
segment_coo_fw(grad_out, index, grad_in, torch::nullopt, "sum");
return {grad_in, Variable(), Variable()};
}
};
torch::Tensor segment_sum_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
return SegmentSumCOO::apply(src, index, optional_out, dim_size)[0];
}
torch::Tensor segment_mean_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
return SegmentMeanCOO::apply(src, index, optional_out, dim_size)[0];
}
std::tuple<torch::Tensor, torch::Tensor>
segment_min_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
auto result = SegmentMinCOO::apply(src, index, optional_out, dim_size);
return std::make_tuple(result[0], result[1]);
}
std::tuple<torch::Tensor, torch::Tensor>
segment_max_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
auto result = SegmentMaxCOO::apply(src, index, optional_out, dim_size);
return std::make_tuple(result[0], result[1]);
}
torch::Tensor gather_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out) {
return GatherCOO::apply(src, index, optional_out)[0];
}
static auto registry =
torch::RegisterOperators()
.op("torch_scatter::segment_sum_coo", &segment_sum_coo)
.op("torch_scatter::segment_mean_coo", &segment_mean_coo)
.op("torch_scatter::segment_min_coo", &segment_min_coo)
.op("torch_scatter::segment_max_coo", &segment_max_coo)
.op("torch_scatter::gather_coo", &gather_coo);
#include <Python.h>
#include <torch/script.h>
#include "cpu/segment_csr_cpu.h"
#include "utils.h"
#ifdef WITH_HIP
#include "hip/segment_csr_hip.h"
#endif
#ifdef _WIN32
#ifdef WITH_HIP
PyMODINIT_FUNC PyInit__segment_csr_cuda(void) { return NULL; }
#else
PyMODINIT_FUNC PyInit__segment_csr_cpu(void) { return NULL; }
#endif
#endif
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_csr_fw(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out,
std::string reduce) {
if (src.device().is_cuda()) {
#ifdef WITH_HIP
return segment_csr_cuda(src, indptr, optional_out, reduce);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return segment_csr_cpu(src, indptr, optional_out, reduce);
}
}
torch::Tensor gather_csr_fw(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out) {
if (src.device().is_cuda()) {
#ifdef WITH_HIP
return gather_csr_cuda(src, indptr, optional_out);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return gather_csr_cpu(src, indptr, optional_out);
}
}
using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;
class SegmentSumCSR : public torch::autograd::Function<SegmentSumCSR> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable indptr,
torch::optional<Variable> optional_out) {
ctx->saved_data["src_shape"] = src.sizes();
auto out = std::get<0>(segment_csr_fw(src, indptr, optional_out, "sum"));
ctx->save_for_backward({indptr});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto indptr = saved[0];
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
auto grad_in = torch::empty(src_shape, grad_out.options());
gather_csr_fw(grad_out, indptr, grad_in);
return {grad_in, Variable(), Variable()};
}
};
class SegmentMeanCSR : public torch::autograd::Function<SegmentMeanCSR> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable indptr,
torch::optional<Variable> optional_out) {
ctx->saved_data["src_shape"] = src.sizes();
auto out = std::get<0>(segment_csr_fw(src, indptr, optional_out, "mean"));
ctx->save_for_backward({indptr});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto indptr = saved[0];
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
auto grad_in = torch::empty(src_shape, grad_out.options());
if (grad_in.numel() > 0) {
gather_csr_fw(grad_out, indptr, grad_in);
auto indptr1 = indptr.narrow(-1, 0, indptr.size(-1) - 1);
auto indptr2 = indptr.narrow(-1, 1, indptr.size(-1) - 1);
auto count = (indptr2 - indptr1).to(grad_in.options());
count = gather_csr_fw(count, indptr, torch::nullopt);
for (auto i = 0; i < grad_out.dim() - indptr.dim(); i++)
count = count.unsqueeze(-1);
grad_in.true_divide_(count);
}
return {grad_in, Variable(), Variable()};
}
};
class SegmentMinCSR : public torch::autograd::Function<SegmentMinCSR> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable indptr,
torch::optional<Variable> optional_out) {
ctx->saved_data["src_shape"] = src.sizes();
auto result = segment_csr_fw(src, indptr, optional_out, "min");
auto out = std::get<0>(result);
auto arg_out = std::get<1>(result).value();
ctx->save_for_backward({indptr, arg_out});
ctx->mark_non_differentiable({arg_out});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out, arg_out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto indptr = saved[0];
auto arg_out = saved[1];
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
src_shape[indptr.dim() - 1] += 1;
auto grad_in = torch::zeros(src_shape, grad_out.options());
grad_in.scatter_(indptr.dim() - 1, arg_out, grad_out);
grad_in =
grad_in.narrow(indptr.dim() - 1, 0, src_shape[indptr.dim() - 1] - 1);
return {grad_in, Variable(), Variable()};
}
};
class SegmentMaxCSR : public torch::autograd::Function<SegmentMaxCSR> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable indptr,
torch::optional<Variable> optional_out) {
ctx->saved_data["src_shape"] = src.sizes();
auto result = segment_csr_fw(src, indptr, optional_out, "max");
auto out = std::get<0>(result);
auto arg_out = std::get<1>(result).value();
ctx->save_for_backward({indptr, arg_out});
ctx->mark_non_differentiable({arg_out});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out, arg_out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto indptr = saved[0];
auto arg_out = saved[1];
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
src_shape[indptr.dim() - 1] += 1;
auto grad_in = torch::zeros(src_shape, grad_out.options());
grad_in.scatter_(indptr.dim() - 1, arg_out, grad_out);
grad_in =
grad_in.narrow(indptr.dim() - 1, 0, src_shape[indptr.dim() - 1] - 1);
return {grad_in, Variable(), Variable()};
}
};
class GatherCSR : public torch::autograd::Function<GatherCSR> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable indptr,
torch::optional<Variable> optional_out) {
ctx->saved_data["src_shape"] = src.sizes();
auto out = gather_csr_fw(src, indptr, optional_out);
ctx->save_for_backward({indptr});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto indptr = saved[0];
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
auto grad_in = torch::empty(src_shape, grad_out.options());
segment_csr_fw(grad_out, indptr, grad_in, "sum");
return {grad_in, Variable(), Variable()};
}
};
torch::Tensor segment_sum_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out) {
return SegmentSumCSR::apply(src, indptr, optional_out)[0];
}
torch::Tensor segment_mean_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out) {
return SegmentMeanCSR::apply(src, indptr, optional_out)[0];
}
std::tuple<torch::Tensor, torch::Tensor>
segment_min_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out) {
auto result = SegmentMinCSR::apply(src, indptr, optional_out);
return std::make_tuple(result[0], result[1]);
}
std::tuple<torch::Tensor, torch::Tensor>
segment_max_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out) {
auto result = SegmentMaxCSR::apply(src, indptr, optional_out);
return std::make_tuple(result[0], result[1]);
}
torch::Tensor gather_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out) {
return GatherCSR::apply(src, indptr, optional_out)[0];
}
static auto registry =
torch::RegisterOperators()
.op("torch_scatter::segment_sum_csr", &segment_sum_csr)
.op("torch_scatter::segment_mean_csr", &segment_mean_csr)
.op("torch_scatter::segment_min_csr", &segment_min_csr)
.op("torch_scatter::segment_max_csr", &segment_max_csr)
.op("torch_scatter::gather_csr", &gather_csr);
#pragma once
#include <torch/script.h>
#include <vector>
inline std::vector<int64_t> list2vec(const c10::List<int64_t> list) {
std::vector<int64_t> result;
result.reserve(list.size());
for (size_t i = 0; i < list.size(); i++)
result.push_back(list[i]);
return result;
}
#include <Python.h>
#include <torch/script.h>
#ifdef WITH_HIP
#include <hip/hip_runtime.h>
#endif
#ifdef _WIN32
#ifdef WITH_HIP
PyMODINIT_FUNC PyInit__version_cuda(void) { return NULL; }
#else
PyMODINIT_FUNC PyInit__version_cpu(void) { return NULL; }
#endif
#endif
int64_t cuda_version() {
#ifdef WITH_HIP
return TORCH_HIP_VERSION;
#else
return -1;
#endif
}
static auto registry =
torch::RegisterOperators().op("torch_scatter::cuda_version", &cuda_version);
#!/bin/bash
source ~/miniconda3/etc/profile.d/conda.sh
conda activate torch1.10_py39_dtk22.10
module purge
module load compiler/devtoolset/7.3.1 mpi/hpcx/gcc-7.3.1 #compiler/dtk/22.10.1
module list
source ~/dtk-22.10.1/env.sh
export C_INCLUDE_PATH=/public/software/apps/DeepLearning/PyTorch_Lib/gflags-2.1.2-build/include:$C_INCLUDE_PATH
export CPLUS_INCLUDE_PATH=/public/software/apps/DeepLearning/PyTorch_Lib/gflags-2.1.2-build/include:$CPLUS_INCLUDE_PATH
export C_INCLUDE_PATH=/public/software/apps/DeepLearning/PyTorch_Lib/glog-build/include:$C_INCLUDE_PATH
export CPLUS_INCLUDE_PATH=/public/software/apps/DeepLearning/PyTorch_Lib/glog-build/include:$CPLUS_INCLUDE_PATH
export C_INCLUDE_PATH=$ROCM_PATH/rocrand/include:$C_INCLUDE_PATH
export CPLUS_INCLUDE_PATH=$ROCM_PATH/rocrand/include:$CPLUS_INCLUDE_PATH
export LD_LIBRARY_PATH=$ROCM_PATH/rocrand/lib:$LD_LIBRARY_PATH
export FORCE_ONLY_HIP=1
export CC=hipcc
export CXX=hipcc
[metadata]
description-file = README.md
[aliases]
test = pytest
[tool:pytest]
addopts = --capture=no --cov
[egg_info]
tag_build =
tag_date = 0
import os
import sys
import glob
import os.path as osp
from itertools import product
from setuptools import setup, find_packages
import torch
from torch.__config__ import parallel_info
from torch.utils.cpp_extension import BuildExtension
from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME
WITH_HIP = torch.cuda.is_available() and CUDA_HOME is not None
suffices = ['cpu', 'cuda'] if WITH_HIP else ['cpu']
if os.getenv('FORCE_CUDA', '0') == '1':
suffices = ['cuda', 'cpu']
if os.getenv('FORCE_ONLY_HIP', '0') == '1':
suffices = ['hip']
if os.getenv('FORCE_ONLY_CPU', '0') == '1':
suffices = ['cpu']
ROCM_PATH = os.getenv('ROCM_PATH')
HIPLIB = osp.join(ROCM_PATH, 'hipsparse', 'include')
BUILD_DOCS = os.getenv('BUILD_DOCS', '0') == '1'
def get_extensions():
extensions = []
extensions_dir = osp.join('csrc')
main_files = glob.glob(osp.join(extensions_dir, '*.cpp'))
for main, suffix in product(main_files, suffices):
define_macros = []
extra_compile_args = {'cxx': ['-O2']}
extra_link_args = ['-s']
info = parallel_info()
if ('backend: OpenMP' in info and 'OpenMP not found' not in info
and sys.platform != 'darwin'):
extra_compile_args['cxx'] += ['-DAT_PARALLEL_OPENMP']
if sys.platform == 'win32':
extra_compile_args['cxx'] += ['/openmp']
else:
extra_compile_args['cxx'] += ['-fopenmp']
else:
print('Compiling without OpenMP...')
if suffix == 'hip':
define_macros += [('WITH_HIP', None)]
hipcc_flags = os.getenv('HIPCC_FLAGS', '')
hipcc_flags = [] if hipcc_flags == '' else hipcc_flags.split(' ')
hipcc_flags += ['--expt-relaxed-constexpr', '-O2']
extra_compile_args['hipcc'] = hipcc_flags
name = main.split(os.sep)[-1][:-4]
sources = [main]
path = osp.join(extensions_dir, 'cpu', f'{name}_cpu.cpp')
if osp.exists(path):
sources += [path]
path = osp.join(extensions_dir, 'hip', f'{name}_hip.hip')
if suffix == 'hip' and osp.exists(path):
sources += [path]
Extension = CppExtension if suffix == 'cpu' else CUDAExtension
define_macros += [('TORCH_HIP_VERSION', 10000), ('__HIP__', None), ('__HCC__', None)]
extension = Extension(
f'torch_scatter._{name}_{suffix}',
sources,
include_dirs=[extensions_dir, HIPLIB],
define_macros=define_macros,
extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args,
)
extensions += [extension]
return extensions
install_requires = []
setup_requires = []
tests_require = ['pytest', 'pytest-runner', 'pytest-cov']
setup(
name='torch_scatter',
version='2.0.9',
author='Matthias Fey',
author_email='matthias.fey@tu-dortmund.de',
url='https://github.com/rusty1s/pytorch_scatter',
description='PyTorch Extension Library of Optimized Scatter Operations',
keywords=['pytorch', 'scatter', 'segment', 'gather'],
license='MIT',
python_requires='>=3.6',
install_requires=install_requires,
setup_requires=setup_requires,
tests_require=tests_require,
extras_require={'test': tests_require},
ext_modules=get_extensions() if not BUILD_DOCS else [],
cmdclass={
'build_ext':
BuildExtension.with_options(no_python_abi_suffix=True, use_ninja=False)
},
packages=find_packages(),
)
Metadata-Version: 2.1
Name: torch-scatter
Version: 2.0.9
Summary: PyTorch Extension Library of Optimized Scatter Operations
Home-page: https://github.com/rusty1s/pytorch_scatter
Author: Matthias Fey
Author-email: matthias.fey@tu-dortmund.de
License: MIT
Keywords: pytorch,scatter,segment,gather
Requires-Python: >=3.6
Provides-Extra: test
License-File: LICENSE
LICENSE
MANIFEST.in
README.md
setup.cfg
setup.py
/work/home/quyuanhao123/software/test_ocp/torch_scatter-2.0.9/csrc/scatter.cpp
/work/home/quyuanhao123/software/test_ocp/torch_scatter-2.0.9/csrc/segment_coo.cpp
/work/home/quyuanhao123/software/test_ocp/torch_scatter-2.0.9/csrc/segment_csr.cpp
/work/home/quyuanhao123/software/test_ocp/torch_scatter-2.0.9/csrc/version.cpp
/work/home/quyuanhao123/software/test_ocp/torch_scatter-2.0.9/csrc/cpu/scatter_cpu.cpp
/work/home/quyuanhao123/software/test_ocp/torch_scatter-2.0.9/csrc/cpu/segment_coo_cpu.cpp
/work/home/quyuanhao123/software/test_ocp/torch_scatter-2.0.9/csrc/cpu/segment_csr_cpu.cpp
/work/home/quyuanhao123/software/test_ocp/torch_scatter-2.0.9/csrc/hip/scatter_hip_hip.hip
/work/home/quyuanhao123/software/test_ocp/torch_scatter-2.0.9/csrc/hip/segment_coo_hip_hip.hip
/work/home/quyuanhao123/software/test_ocp/torch_scatter-2.0.9/csrc/hip/segment_csr_hip_hip.hip
csrc/scatter.cpp
csrc/scatter.h
csrc/segment_coo.cpp
csrc/segment_csr.cpp
csrc/utils.h
csrc/version.cpp
csrc/cpu/index_info.h
csrc/cpu/reducer.h
csrc/cpu/scatter_cpu.cpp
csrc/cpu/scatter_cpu.h
csrc/cpu/segment_coo_cpu.cpp
csrc/cpu/segment_coo_cpu.h
csrc/cpu/segment_csr_cpu.cpp
csrc/cpu/segment_csr_cpu.h
csrc/cpu/utils.h
csrc/hip/atomics.cuh
csrc/hip/index_info.cuh
csrc/hip/reducer.cuh
csrc/hip/scatter_hip.h
csrc/hip/scatter_hip.hip
csrc/hip/scatter_hip_hip.hip
csrc/hip/segment_coo_hip.h
csrc/hip/segment_coo_hip.hip
csrc/hip/segment_coo_hip_hip.hip
csrc/hip/segment_csr_hip.h
csrc/hip/segment_csr_hip.hip
csrc/hip/segment_csr_hip_hip.hip
csrc/hip/utils.cuh
torch_scatter/__init__.py
torch_scatter/placeholder.py
torch_scatter/scatter.py
torch_scatter/segment_coo.py
torch_scatter/segment_csr.py
torch_scatter/utils.py
torch_scatter.egg-info/PKG-INFO
torch_scatter.egg-info/SOURCES.txt
torch_scatter.egg-info/dependency_links.txt
torch_scatter.egg-info/requires.txt
torch_scatter.egg-info/top_level.txt
torch_scatter/composite/__init__.py
torch_scatter/composite/logsumexp.py
torch_scatter/composite/softmax.py
torch_scatter/composite/std.py
\ No newline at end of file
[test]
pytest
pytest-runner
pytest-cov
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment