".github/vscode:/vscode.git/clone" did not exist on "ad2c52d72a996732bdcd3f6bfe2afbdaa4a2b19e"
Commit 97f2f4e9 authored by quyuanhao123's avatar quyuanhao123
Browse files

Initial commit

parents
Pipeline #189 failed with stages
in 0 seconds
#include "hip/hip_runtime.h"
#include "diag_hip.h"
#include <ATen/hip/HIPContext.h>
#include "utils.cuh"
#define THREADS 1024
__global__ void non_diag_mask_kernel(const int64_t *row_data,
const int64_t *col_data, bool *out_data,
int64_t N, int64_t k, int64_t num_diag,
int64_t numel) {
int64_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
if (thread_idx < numel) {
int64_t r = row_data[thread_idx], c = col_data[thread_idx];
if (k < 0) {
if (r + k < 0) {
out_data[thread_idx] = true;
} else if (r + k >= N) {
out_data[thread_idx + num_diag] = true;
} else if (r + k > c) {
out_data[thread_idx + r + k] = true;
} else if (r + k < c) {
out_data[thread_idx + r + k + 1] = true;
}
} else {
if (r + k >= N) {
out_data[thread_idx + num_diag] = true;
} else if (r + k > c) {
out_data[thread_idx + r] = true;
} else if (r + k < c) {
out_data[thread_idx + r + 1] = true;
}
}
}
}
torch::Tensor non_diag_mask_cuda(torch::Tensor row, torch::Tensor col,
int64_t M, int64_t N, int64_t k) {
CHECK_CUDA(row);
CHECK_CUDA(col);
hipSetDevice(row.get_device());
auto E = row.size(0);
auto num_diag = k < 0 ? std::min(M + k, N) : std::min(M, N - k);
auto row_data = row.data_ptr<int64_t>();
auto col_data = col.data_ptr<int64_t>();
auto mask = torch::zeros(E + num_diag, row.options().dtype(torch::kBool));
auto mask_data = mask.data_ptr<bool>();
if (E == 0)
return mask;
auto stream = at::cuda::getCurrentCUDAStream();
non_diag_mask_kernel<<<(E + THREADS - 1) / THREADS, THREADS, 0, stream>>>(
row_data, col_data, mask_data, N, k, num_diag, E);
return mask;
}
#include "hip/hip_runtime.h"
#include "diag_hip.h"
#include <ATen/hip/HIPContext.h>
#include "utils.cuh"
#define THREADS 1024
__global__ void non_diag_mask_kernel(const int64_t *row_data,
const int64_t *col_data, bool *out_data,
int64_t N, int64_t k, int64_t num_diag,
int64_t numel) {
int64_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
if (thread_idx < numel) {
int64_t r = row_data[thread_idx], c = col_data[thread_idx];
if (k < 0) {
if (r + k < 0) {
out_data[thread_idx] = true;
} else if (r + k >= N) {
out_data[thread_idx + num_diag] = true;
} else if (r + k > c) {
out_data[thread_idx + r + k] = true;
} else if (r + k < c) {
out_data[thread_idx + r + k + 1] = true;
}
} else {
if (r + k >= N) {
out_data[thread_idx + num_diag] = true;
} else if (r + k > c) {
out_data[thread_idx + r] = true;
} else if (r + k < c) {
out_data[thread_idx + r + 1] = true;
}
}
}
}
torch::Tensor non_diag_mask_cuda(torch::Tensor row, torch::Tensor col,
int64_t M, int64_t N, int64_t k) {
CHECK_CUDA(row);
CHECK_CUDA(col);
hipSetDevice(row.get_device());
auto E = row.size(0);
auto num_diag = k < 0 ? std::min(M + k, N) : std::min(M, N - k);
auto row_data = row.data_ptr<int64_t>();
auto col_data = col.data_ptr<int64_t>();
auto mask = torch::zeros(E + num_diag, row.options().dtype(torch::kBool));
auto mask_data = mask.data_ptr<bool>();
if (E == 0)
return mask;
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
hipLaunchKernelGGL(( non_diag_mask_kernel), dim3((E + THREADS - 1) / THREADS), dim3(THREADS), 0, stream,
row_data, col_data, mask_data, N, k, num_diag, E);
return mask;
}
#pragma once
#include <limits>
#include <map>
enum ReductionType { SUM, MEAN, MUL, DIV, MIN, MAX };
const std::map<std::string, ReductionType> reduce2REDUCE = {
{"sum", SUM}, {"mean", MEAN}, {"mul", MUL},
{"div", DIV}, {"min", MIN}, {"max", MAX},
};
#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...) \
[&] { \
switch (reduce2REDUCE.at(reduce)) { \
case SUM: { \
const ReductionType REDUCE = SUM; \
return __VA_ARGS__(); \
} \
case MEAN: { \
const ReductionType REDUCE = MEAN; \
return __VA_ARGS__(); \
} \
case MUL: { \
const ReductionType REDUCE = MUL; \
return __VA_ARGS__(); \
} \
case DIV: { \
const ReductionType REDUCE = DIV; \
return __VA_ARGS__(); \
} \
case MIN: { \
const ReductionType REDUCE = MIN; \
return __VA_ARGS__(); \
} \
case MAX: { \
const ReductionType REDUCE = MAX; \
return __VA_ARGS__(); \
} \
} \
}()
template <typename scalar_t, ReductionType REDUCE> struct Reducer {
static inline __host__ __device__ scalar_t init() {
if (REDUCE == MUL || REDUCE == DIV)
return (scalar_t)1;
else if (REDUCE == MIN)
return std::numeric_limits<scalar_t>::max();
else if (REDUCE == MAX)
return std::numeric_limits<scalar_t>::lowest();
else
return (scalar_t)0;
}
static inline __host__ __device__ void update(scalar_t *val, scalar_t new_val,
int64_t *arg, int64_t new_arg) {
if (REDUCE == SUM || REDUCE == MEAN)
*val = *val + new_val;
else if (REDUCE == MUL)
*val = *val * new_val;
else if (REDUCE == DIV)
*val = *val / new_val;
else if ((REDUCE == MIN && new_val < *val) ||
(REDUCE == MAX && new_val > *val)) {
*val = new_val;
*arg = new_arg;
}
}
static inline __host__ __device__ void write(scalar_t *address, scalar_t val,
int64_t *arg_address,
int64_t arg, int count) {
if (REDUCE == SUM || REDUCE == MUL || REDUCE == DIV)
*address = val;
else if (REDUCE == MEAN)
*address = val / (scalar_t)(count > 0 ? count : 1);
else if (REDUCE == MIN || REDUCE == MAX) {
if (count > 0) {
*address = val;
*arg_address = arg;
} else
*address = (scalar_t)0;
}
}
};
#pragma once
#include "../extensions.h"
torch::Tensor random_walk_cuda(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor start, int64_t walk_length);
#include "hip/hip_runtime.h"
#include "rw_hip.h"
#include <ATen/hip/HIPContext.h>
#include "utils.cuh"
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
__global__ void uniform_random_walk_kernel(const int64_t *rowptr,
const int64_t *col,
const int64_t *start,
const float *rand, int64_t *out,
int64_t walk_length, int64_t numel) {
const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (thread_idx < numel) {
int64_t cur = start[thread_idx];
out[thread_idx] = cur;
int64_t row_start, row_end;
for (int64_t l = 0; l < walk_length; l++) {
row_start = rowptr[cur], row_end = rowptr[cur + 1];
cur = col[row_start +
int64_t(rand[l * numel + thread_idx] * (row_end - row_start))];
out[(l + 1) * numel + thread_idx] = cur;
}
}
}
torch::Tensor random_walk_cuda(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor start, int64_t walk_length) {
CHECK_CUDA(rowptr);
CHECK_CUDA(col);
CHECK_CUDA(start);
hipSetDevice(rowptr.get_device());
CHECK_INPUT(rowptr.dim() == 1);
CHECK_INPUT(col.dim() == 1);
CHECK_INPUT(start.dim() == 1);
auto rand = torch::rand({walk_length, start.size(0)},
start.options().dtype(torch::kFloat));
auto out = torch::full({walk_length + 1, start.size(0)}, -1, start.options());
auto stream = at::cuda::getCurrentCUDAStream();
uniform_random_walk_kernel<<<BLOCKS(start.numel()), THREADS, 0, stream>>>(
rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(),
start.data_ptr<int64_t>(), rand.data_ptr<float>(),
out.data_ptr<int64_t>(), walk_length, start.numel());
return out.t().contiguous();
}
#include "hip/hip_runtime.h"
#include "rw_hip.h"
#include <ATen/hip/HIPContext.h>
#include "utils.cuh"
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
__global__ void uniform_random_walk_kernel(const int64_t *rowptr,
const int64_t *col,
const int64_t *start,
const float *rand, int64_t *out,
int64_t walk_length, int64_t numel) {
const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (thread_idx < numel) {
int64_t cur = start[thread_idx];
out[thread_idx] = cur;
int64_t row_start, row_end;
for (int64_t l = 0; l < walk_length; l++) {
row_start = rowptr[cur], row_end = rowptr[cur + 1];
cur = col[row_start +
int64_t(rand[l * numel + thread_idx] * (row_end - row_start))];
out[(l + 1) * numel + thread_idx] = cur;
}
}
}
torch::Tensor random_walk_cuda(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor start, int64_t walk_length) {
CHECK_CUDA(rowptr);
CHECK_CUDA(col);
CHECK_CUDA(start);
hipSetDevice(rowptr.get_device());
CHECK_INPUT(rowptr.dim() == 1);
CHECK_INPUT(col.dim() == 1);
CHECK_INPUT(start.dim() == 1);
auto rand = torch::rand({walk_length, start.size(0)},
start.options().dtype(torch::kFloat));
auto out = torch::full({walk_length + 1, start.size(0)}, -1, start.options());
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
hipLaunchKernelGGL(( uniform_random_walk_kernel), dim3(BLOCKS(start.numel())), dim3(THREADS), 0, stream,
rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(),
start.data_ptr<int64_t>(), rand.data_ptr<float>(),
out.data_ptr<int64_t>(), walk_length, start.numel());
return out.t().contiguous();
}
#pragma once
#include "../extensions.h"
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
spmm_cuda(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value, torch::Tensor mat,
std::string reduce);
torch::Tensor spmm_value_bw_cuda(torch::Tensor row, torch::Tensor rowptr,
torch::Tensor col, torch::Tensor mat,
torch::Tensor grad, std::string reduce);
template<typename T>
__device__ T __ldg(const T* ptr) {
return *ptr;
}
#include "hip/hip_runtime.h"
#include "spmm_hip.h"
#include <ATen/hip/HIPContext.h>
#include "reducer.cuh"
#include "utils.cuh"
#define THREADS 256
#define FULL_MASK 0xffffffff
// Paper: Design Principles for Sparse Matrix Multiplication on the GPU
// Code: https://github.com/owensgroup/merge-spmm
template <typename scalar_t, ReductionType REDUCE, bool HAS_VALUE>
__global__ void spmm_kernel(const int64_t *rowptr_data, const int64_t *col_data,
const scalar_t *value_data,
const scalar_t *mat_data, scalar_t *out_data,
int64_t *arg_out_data, int B, int M, int N, int K) {
// We ignore blockIdx.y here, because threads
// across `blockIdx.y` are treated equally.
int thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
int row = thread_idx >> 5; // thread_idx / 32
int lane_idx = thread_idx & (32 - 1); // thread_idx % 32
int batch_idx = row / M;
// Compute the column index of `mat` in which the thread is operating.
int mat_col_idx = lane_idx + (blockIdx.y << 5);
// Compute the output index (row-major order).
int out_idx = row * K + mat_col_idx;
// Helper arrays for warp communication.
int mat_row, mat_rows[32];
scalar_t val, vals[HAS_VALUE ? 32 : 1];
// Do not aggregate/write across the Y-axis (lane_idx < leftover).
int leftover = K - (blockIdx.y << 5);
if (batch_idx < B) {
int row_start = __ldg(rowptr_data + (row % M));
int row_end = __ldg(rowptr_data + (row % M) + 1);
int col_idx = row_start + lane_idx;
scalar_t result = Reducer<scalar_t, REDUCE>::init();
int64_t arg;
// Iterate over all `col` indices in parallel within a warp.
for (int c = row_start; c < row_end; c += 32) {
if (col_idx < row_end) {
// Coalesced memory access into `col` and `val`.
mat_row = __ldg(col_data + col_idx) * K;
if (HAS_VALUE)
val = __ldg(value_data + col_idx);
} else {
mat_row = -1;
if (HAS_VALUE)
val = (scalar_t)0;
}
col_idx += 32;
#pragma unroll
for (int i = 0; i < 32; i++) {
// Communication between all threads in a warp.
mat_rows[i] = __shfl_sync(FULL_MASK, mat_row, i);
if (HAS_VALUE)
vals[i] = __shfl_sync(FULL_MASK, val, i);
}
#pragma unroll
for (int i = 0; i < 32; i++) {
if (lane_idx < leftover && mat_rows[i] != -1) {
// Coalesced memory access into `mat`.
val = __ldg(mat_data + batch_idx * N * K + mat_rows[i] + mat_col_idx);
if (HAS_VALUE)
val = vals[i] * val;
Reducer<scalar_t, REDUCE>::update(&result, val, &arg, c + i);
}
}
}
if (lane_idx < leftover) {
// Coalesced write into `out`.
Reducer<scalar_t, REDUCE>::write(out_data + out_idx, result,
arg_out_data + out_idx, arg,
row_end - row_start);
}
}
}
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
spmm_cuda(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value, torch::Tensor mat,
std::string reduce) {
CHECK_CUDA(rowptr);
CHECK_CUDA(col);
if (optional_value.has_value())
CHECK_CUDA(optional_value.value());
CHECK_CUDA(mat);
hipSetDevice(rowptr.get_device());
CHECK_INPUT(rowptr.dim() == 1);
CHECK_INPUT(col.dim() == 1);
if (optional_value.has_value()) {
CHECK_INPUT(optional_value.value().dim() == 1);
CHECK_INPUT(optional_value.value().size(0) == col.size(0));
}
CHECK_INPUT(mat.dim() >= 2);
mat = mat.contiguous();
auto sizes = mat.sizes().vec();
sizes[mat.dim() - 2] = rowptr.numel() - 1;
auto out = torch::empty(sizes, mat.options());
torch::optional<torch::Tensor> arg_out = torch::nullopt;
int64_t *arg_out_data = nullptr;
if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
arg_out = torch::full_like(out, col.numel(), rowptr.options());
arg_out_data = arg_out.value().data_ptr<int64_t>();
}
auto rowptr_data = rowptr.data_ptr<int64_t>();
auto col_data = col.data_ptr<int64_t>();
auto M = rowptr.numel() - 1;
auto N = mat.size(-2);
auto K = mat.size(-1);
auto B = mat.numel() / (N * K);
auto BLOCKS = dim3((32 * B * M + THREADS - 1) / THREADS, (K + 31) / 32);
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, mat.scalar_type(), "_", [&] {
auto mat_data = mat.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
if (optional_value.has_value()) {
auto value_data = optional_value.value().data_ptr<scalar_t>();
spmm_kernel<scalar_t, REDUCE, true><<<BLOCKS, THREADS, 0, stream>>>(
rowptr_data, col_data, value_data, mat_data, out_data, arg_out_data,
B, M, N, K);
} else {
spmm_kernel<scalar_t, REDUCE, false><<<BLOCKS, THREADS, 0, stream>>>(
rowptr_data, col_data, nullptr, mat_data, out_data, arg_out_data, B,
M, N, K);
}
});
});
return std::make_tuple(out, arg_out);
}
template <typename scalar_t, ReductionType REDUCE>
__global__ void
spmm_value_bw_kernel(const int64_t *row_data, const int64_t *rowptr_data,
const int64_t *col_data, const scalar_t *mat_data,
const scalar_t *grad_data, scalar_t *out_data, int B,
int M, int N, int E, int K) {
int thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
int index_idx = (thread_idx >> 5); // thread_idx / 32
int lane_idx = thread_idx & (32 - 1); // thread_idx % 32
if (index_idx < E) {
int row = __ldg(row_data + index_idx);
int col = __ldg(col_data + index_idx);
scalar_t val = (scalar_t)0;
for (int b = 0; b < B; b++) {
for (int k = lane_idx; k < K; k += 32) {
val += mat_data[b * N * K + col * K + k] *
grad_data[b * M * K + row * K + k];
}
}
#pragma unroll
for (int i = 32 / 2; i > 0; i /= 2) { // Parallel reduction inside a warp.
val += __shfl_down_sync(FULL_MASK, val, i);
}
if (lane_idx == 0) {
if (REDUCE == MEAN) {
int row_start = __ldg(rowptr_data + row);
int row_end = __ldg(rowptr_data + row + 1);
val /= (scalar_t)max(row_end - row_start, 1);
}
out_data[index_idx] = val;
}
}
}
torch::Tensor spmm_value_bw_cuda(torch::Tensor row, torch::Tensor rowptr,
torch::Tensor col, torch::Tensor mat,
torch::Tensor grad, std::string reduce) {
CHECK_CUDA(row);
CHECK_CUDA(rowptr);
CHECK_CUDA(col);
CHECK_CUDA(mat);
CHECK_CUDA(grad);
hipSetDevice(row.get_device());
mat = mat.contiguous();
grad = grad.contiguous();
auto M = grad.size(-2);
auto N = mat.size(-2);
auto E = row.numel();
auto K = mat.size(-1);
auto B = mat.numel() / (N * K);
auto BLOCKS = dim3((E * 32 + THREADS - 1) / THREADS);
auto out = torch::zeros(row.numel(), grad.options());
auto row_data = row.data_ptr<int64_t>();
auto rowptr_data = rowptr.data_ptr<int64_t>();
auto col_data = col.data_ptr<int64_t>();
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, mat.scalar_type(), "_", [&] {
auto mat_data = mat.data_ptr<scalar_t>();
auto grad_data = grad.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
spmm_value_bw_kernel<scalar_t, REDUCE><<<BLOCKS, THREADS, 0, stream>>>(
row_data, rowptr_data, col_data, mat_data, grad_data, out_data, B, M,
N, E, K);
});
});
return out;
}
#include "hip/hip_runtime.h"
#include "spmm_hip.h"
#include <ATen/hip/HIPContext.h>
#include "reducer.cuh"
#include "utils.cuh"
#define THREADS 256
#define FULL_MASK 0xffffffff
// Paper: Design Principles for Sparse Matrix Multiplication on the GPU
// Code: https://github.com/owensgroup/merge-spmm
template <typename scalar_t, ReductionType REDUCE, bool HAS_VALUE>
__global__ void spmm_kernel(const int64_t *rowptr_data, const int64_t *col_data,
const scalar_t *value_data,
const scalar_t *mat_data, scalar_t *out_data,
int64_t *arg_out_data, int B, int M, int N, int K) {
// We ignore blockIdx.y here, because threads
// across `blockIdx.y` are treated equally.
int thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
int row = thread_idx >> 5; // thread_idx / 32
int lane_idx = thread_idx & (32 - 1); // thread_idx % 32
int batch_idx = row / M;
// Compute the column index of `mat` in which the thread is operating.
int mat_col_idx = lane_idx + (blockIdx.y << 5);
// Compute the output index (row-major order).
int out_idx = row * K + mat_col_idx;
// Helper arrays for warp communication.
int mat_row, mat_rows[32];
scalar_t val, vals[HAS_VALUE ? 32 : 1];
// Do not aggregate/write across the Y-axis (lane_idx < leftover).
int leftover = K - (blockIdx.y << 5);
if (batch_idx < B) {
int row_start = __ldg(rowptr_data + (row % M));
int row_end = __ldg(rowptr_data + (row % M) + 1);
int col_idx = row_start + lane_idx;
scalar_t result = Reducer<scalar_t, REDUCE>::init();
int64_t arg;
// Iterate over all `col` indices in parallel within a warp.
for (int c = row_start; c < row_end; c += 32) {
if (col_idx < row_end) {
// Coalesced memory access into `col` and `val`.
mat_row = __ldg(col_data + col_idx) * K;
if (HAS_VALUE)
val = __ldg(value_data + col_idx);
} else {
mat_row = -1;
if (HAS_VALUE)
val = (scalar_t)0;
}
col_idx += 32;
#pragma unroll
for (int i = 0; i < 32; i++) {
// Communication between all threads in a warp.
mat_rows[i] = __shfl_sync(FULL_MASK, mat_row, i);
if (HAS_VALUE)
vals[i] = __shfl_sync(FULL_MASK, val, i);
}
#pragma unroll
for (int i = 0; i < 32; i++) {
if (lane_idx < leftover && mat_rows[i] != -1) {
// Coalesced memory access into `mat`.
val = __ldg(mat_data + batch_idx * N * K + mat_rows[i] + mat_col_idx);
if (HAS_VALUE)
val = vals[i] * val;
Reducer<scalar_t, REDUCE>::update(&result, val, &arg, c + i);
}
}
}
if (lane_idx < leftover) {
// Coalesced write into `out`.
Reducer<scalar_t, REDUCE>::write(out_data + out_idx, result,
arg_out_data + out_idx, arg,
row_end - row_start);
}
}
}
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
spmm_cuda(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value, torch::Tensor mat,
std::string reduce) {
CHECK_CUDA(rowptr);
CHECK_CUDA(col);
if (optional_value.has_value())
CHECK_CUDA(optional_value.value());
CHECK_CUDA(mat);
hipSetDevice(rowptr.get_device());
CHECK_INPUT(rowptr.dim() == 1);
CHECK_INPUT(col.dim() == 1);
if (optional_value.has_value()) {
CHECK_INPUT(optional_value.value().dim() == 1);
CHECK_INPUT(optional_value.value().size(0) == col.size(0));
}
CHECK_INPUT(mat.dim() >= 2);
mat = mat.contiguous();
auto sizes = mat.sizes().vec();
sizes[mat.dim() - 2] = rowptr.numel() - 1;
auto out = torch::empty(sizes, mat.options());
torch::optional<torch::Tensor> arg_out = torch::nullopt;
int64_t *arg_out_data = nullptr;
if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
arg_out = torch::full_like(out, col.numel(), rowptr.options());
arg_out_data = arg_out.value().data_ptr<int64_t>();
}
auto rowptr_data = rowptr.data_ptr<int64_t>();
auto col_data = col.data_ptr<int64_t>();
auto M = rowptr.numel() - 1;
auto N = mat.size(-2);
auto K = mat.size(-1);
auto B = mat.numel() / (N * K);
auto BLOCKS = dim3((32 * B * M + THREADS - 1) / THREADS, (K + 31) / 32);
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, mat.scalar_type(), "_", [&] {
auto mat_data = mat.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
if (optional_value.has_value()) {
auto value_data = optional_value.value().data_ptr<scalar_t>();
hipLaunchKernelGGL(( spmm_kernel<scalar_t, REDUCE, true>), dim3(BLOCKS), dim3(THREADS), 0, stream,
rowptr_data, col_data, value_data, mat_data, out_data, arg_out_data,
B, M, N, K);
} else {
hipLaunchKernelGGL(( spmm_kernel<scalar_t, REDUCE, false>), dim3(BLOCKS), dim3(THREADS), 0, stream,
rowptr_data, col_data, nullptr, mat_data, out_data, arg_out_data, B,
M, N, K);
}
});
});
return std::make_tuple(out, arg_out);
}
template <typename scalar_t, ReductionType REDUCE>
__global__ void
spmm_value_bw_kernel(const int64_t *row_data, const int64_t *rowptr_data,
const int64_t *col_data, const scalar_t *mat_data,
const scalar_t *grad_data, scalar_t *out_data, int B,
int M, int N, int E, int K) {
int thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
int index_idx = (thread_idx >> 5); // thread_idx / 32
int lane_idx = thread_idx & (32 - 1); // thread_idx % 32
if (index_idx < E) {
int row = __ldg(row_data + index_idx);
int col = __ldg(col_data + index_idx);
scalar_t val = (scalar_t)0;
for (int b = 0; b < B; b++) {
for (int k = lane_idx; k < K; k += 32) {
val += mat_data[b * N * K + col * K + k] *
grad_data[b * M * K + row * K + k];
}
}
#pragma unroll
for (int i = 32 / 2; i > 0; i /= 2) { // Parallel reduction inside a warp.
val += __shfl_down_sync(FULL_MASK, val, i);
}
if (lane_idx == 0) {
if (REDUCE == MEAN) {
int row_start = __ldg(rowptr_data + row);
int row_end = __ldg(rowptr_data + row + 1);
val /= (scalar_t)max(row_end - row_start, 1);
}
out_data[index_idx] = val;
}
}
}
torch::Tensor spmm_value_bw_cuda(torch::Tensor row, torch::Tensor rowptr,
torch::Tensor col, torch::Tensor mat,
torch::Tensor grad, std::string reduce) {
CHECK_CUDA(row);
CHECK_CUDA(rowptr);
CHECK_CUDA(col);
CHECK_CUDA(mat);
CHECK_CUDA(grad);
hipSetDevice(row.get_device());
mat = mat.contiguous();
grad = grad.contiguous();
auto M = grad.size(-2);
auto N = mat.size(-2);
auto E = row.numel();
auto K = mat.size(-1);
auto B = mat.numel() / (N * K);
auto BLOCKS = dim3((E * 32 + THREADS - 1) / THREADS);
auto out = torch::zeros(row.numel(), grad.options());
auto row_data = row.data_ptr<int64_t>();
auto rowptr_data = rowptr.data_ptr<int64_t>();
auto col_data = col.data_ptr<int64_t>();
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, mat.scalar_type(), "_", [&] {
auto mat_data = mat.data_ptr<scalar_t>();
auto grad_data = grad.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
hipLaunchKernelGGL(( spmm_value_bw_kernel<scalar_t, REDUCE>), dim3(BLOCKS), dim3(THREADS), 0, stream,
row_data, rowptr_data, col_data, mat_data, grad_data, out_data, B, M,
N, E, K);
});
});
return out;
}
#pragma once
#include "../extensions.h"
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
spspmm_cuda(torch::Tensor rowptrA, torch::Tensor colA,
torch::optional<torch::Tensor> optional_valueA,
torch::Tensor rowptrB, torch::Tensor colB,
torch::optional<torch::Tensor> optional_valueB, int64_t K,
std::string reduce);
#include "spspmm_hip.h"
#include <ATen/hip/HIPContext.h>
#include <hipsparse.h>
#include "utils.cuh"
#define AT_DISPATCH_CUSPARSE_TYPES(TYPE, ...) \
[&] { \
switch (TYPE) { \
case torch::ScalarType::Float: { \
using scalar_t = float; \
const auto &cusparsecsrgemm2_bufferSizeExt = \
hipsparseScsrgemm2_bufferSizeExt; \
const auto &cusparsecsrgemm2 = hipsparseScsrgemm2; \
return __VA_ARGS__(); \
} \
case torch::ScalarType::Double: { \
using scalar_t = double; \
const auto &cusparsecsrgemm2_bufferSizeExt = \
hipsparseDcsrgemm2_bufferSizeExt; \
const auto &cusparsecsrgemm2 = hipsparseDcsrgemm2; \
return __VA_ARGS__(); \
} \
default: \
AT_ERROR("Not implemented for '", toString(TYPE), "'"); \
} \
}()
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
spspmm_cuda(torch::Tensor rowptrA, torch::Tensor colA,
torch::optional<torch::Tensor> optional_valueA,
torch::Tensor rowptrB, torch::Tensor colB,
torch::optional<torch::Tensor> optional_valueB, int64_t K,
std::string reduce) {
CHECK_CUDA(rowptrA);
CHECK_CUDA(colA);
if (optional_valueA.has_value())
CHECK_CUDA(optional_valueA.value());
CHECK_CUDA(rowptrB);
CHECK_CUDA(colB);
if (optional_valueB.has_value())
CHECK_CUDA(optional_valueB.value());
hipSetDevice(rowptrA.get_device());
CHECK_INPUT(rowptrA.dim() == 1);
CHECK_INPUT(colA.dim() == 1);
if (optional_valueA.has_value()) {
CHECK_INPUT(optional_valueA.value().dim() == 1);
CHECK_INPUT(optional_valueA.value().size(0) == colA.size(0));
}
CHECK_INPUT(rowptrB.dim() == 1);
CHECK_INPUT(colB.dim() == 1);
if (optional_valueB.has_value()) {
CHECK_INPUT(optional_valueB.value().dim() == 1);
CHECK_INPUT(optional_valueB.value().size(0) == colB.size(0));
}
if (!optional_valueA.has_value() && optional_valueB.has_value())
optional_valueA =
torch::ones(colA.numel(), optional_valueB.value().options());
if (!optional_valueB.has_value() && optional_valueA.has_value())
optional_valueB =
torch::ones(colB.numel(), optional_valueA.value().options());
auto scalar_type = torch::ScalarType::Float;
if (optional_valueA.has_value())
scalar_type = optional_valueA.value().scalar_type();
auto handle = at::cuda::getCurrentCUDASparseHandle();
hipsparseMatDescr_t descr;
hipsparseCreateMatDescr(&descr);
rowptrA = rowptrA.toType(torch::kInt);
colA = colA.toType(torch::kInt);
rowptrB = rowptrB.toType(torch::kInt);
colB = colB.toType(torch::kInt);
int64_t M = rowptrA.numel() - 1, N = rowptrB.numel() - 1;
auto rowptrA_data = rowptrA.data_ptr<int>();
auto colA_data = colA.data_ptr<int>();
auto rowptrB_data = rowptrB.data_ptr<int>();
auto colB_data = colB.data_ptr<int>();
torch::Tensor rowptrC, colC;
torch::optional<torch::Tensor> optional_valueC = torch::nullopt;
int nnzC;
int *nnzTotalDevHostPtr = &nnzC;
// Step 1: Create an opaque structure.
csrgemm2Info_t info = NULL;
hipsparseCreateCsrgemm2Info(&info);
// Step 2: Allocate buffer for `csrgemm2Nnz` and `csrgemm2`.
size_t bufferSize;
AT_DISPATCH_CUSPARSE_TYPES(scalar_type, [&] {
scalar_t alpha = (scalar_t)1.0;
cusparsecsrgemm2_bufferSizeExt(handle, M, N, K, &alpha, descr, colA.numel(),
rowptrA_data, colA_data, descr, colB.numel(),
rowptrB_data, colB_data, NULL, descr, 0,
NULL, NULL, info, &bufferSize);
void *buffer = NULL;
hipMalloc(&buffer, bufferSize);
// Step 3: Compute CSR row pointer.
rowptrC = torch::empty(M + 1, rowptrA.options());
auto rowptrC_data = rowptrC.data_ptr<int>();
hipsparseXcsrgemm2Nnz(handle, M, N, K, descr, colA.numel(), rowptrA_data,
colA_data, descr, colB.numel(), rowptrB_data,
colB_data, descr, 0, NULL, NULL, descr, rowptrC_data,
nnzTotalDevHostPtr, info, buffer);
// Step 4: Compute CSR entries.
colC = torch::empty(nnzC, rowptrC.options());
auto colC_data = colC.data_ptr<int>();
if (optional_valueA.has_value())
optional_valueC = torch::empty(nnzC, optional_valueA.value().options());
scalar_t *valA_data = NULL, *valB_data = NULL, *valC_data = NULL;
if (optional_valueA.has_value()) {
valA_data = optional_valueA.value().data_ptr<scalar_t>();
valB_data = optional_valueB.value().data_ptr<scalar_t>();
valC_data = optional_valueC.value().data_ptr<scalar_t>();
}
cusparsecsrgemm2(handle, M, N, K, &alpha, descr, colA.numel(), valA_data,
rowptrA_data, colA_data, descr, colB.numel(), valB_data,
rowptrB_data, colB_data, NULL, descr, 0, NULL, NULL, NULL,
descr, valC_data, rowptrC_data, colC_data, info, buffer);
hipFree(buffer);
});
// Step 5: Destroy the opaque structure.
hipsparseDestroyCsrgemm2Info(info);
rowptrC = rowptrC.toType(torch::kLong);
colC = colC.toType(torch::kLong);
return std::make_tuple(rowptrC, colC, optional_valueC);
}
#pragma once
#include "../extensions.h"
#define CHECK_CUDA(x) \
AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
__device__ __inline__ at::Half
__shfl_sync(const unsigned mask, const at::Half var, const int srcLane) {
return __shfl_sync(mask, (__half)var, srcLane);
}
__device__ __inline__ at::Half __shfl_down_sync(const unsigned mask,
const at::Half var,
const unsigned int delta) {
return __shfl_down_sync(mask, (__half)var, delta);
}
#ifdef WITH_PYTHON
#include <Python.h>
#endif
#include <torch/script.h>
#include "cpu/metis_cpu.h"
#ifdef _WIN32
#ifdef WITH_PYTHON
#ifdef WITH_HIP
PyMODINIT_FUNC PyInit__metis_cuda(void) { return NULL; }
#else
PyMODINIT_FUNC PyInit__metis_cpu(void) { return NULL; }
#endif
#endif
#endif
SPARSE_API torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value,
int64_t num_parts, bool recursive) {
if (rowptr.device().is_cuda()) {
#ifdef WITH_HIP
AT_ERROR("No CUDA version supported");
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return partition_cpu(rowptr, col, optional_value, torch::nullopt, num_parts,
recursive);
}
}
SPARSE_API torch::Tensor partition2(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value,
torch::optional<torch::Tensor> optional_node_weight,
int64_t num_parts, bool recursive) {
if (rowptr.device().is_cuda()) {
#ifdef WITH_HIP
AT_ERROR("No CUDA version supported");
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return partition_cpu(rowptr, col, optional_value, optional_node_weight,
num_parts, recursive);
}
}
SPARSE_API torch::Tensor mt_partition(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value,
torch::optional<torch::Tensor> optional_node_weight,
int64_t num_parts, bool recursive,
int64_t num_workers) {
if (rowptr.device().is_cuda()) {
#ifdef WITH_HIP
AT_ERROR("No CUDA version supported");
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return mt_partition_cpu(rowptr, col, optional_value, optional_node_weight,
num_parts, recursive, num_workers);
}
}
static auto registry = torch::RegisterOperators()
.op("torch_sparse::partition", &partition)
.op("torch_sparse::partition2", &partition2)
.op("torch_sparse::mt_partition", &mt_partition);
#ifdef WITH_PYTHON
#include <Python.h>
#endif
#include <torch/script.h>
#include "cpu/neighbor_sample_cpu.h"
#ifdef _WIN32
#ifdef WITH_PYTHON
#ifdef WITH_HIP
PyMODINIT_FUNC PyInit__neighbor_sample_cuda(void) { return NULL; }
#else
PyMODINIT_FUNC PyInit__neighbor_sample_cpu(void) { return NULL; }
#endif
#endif
#endif
// Returns 'output_node', 'row', 'col', 'output_edge'
SPARSE_API std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
neighbor_sample(const torch::Tensor &colptr, const torch::Tensor &row,
const torch::Tensor &input_node,
const std::vector<int64_t> num_neighbors, const bool replace,
const bool directed) {
return neighbor_sample_cpu(colptr, row, input_node, num_neighbors, replace,
directed);
}
SPARSE_API std::tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
hetero_neighbor_sample(
const std::vector<node_t> &node_types,
const std::vector<edge_t> &edge_types,
const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
const c10::Dict<rel_t, torch::Tensor> &row_dict,
const c10::Dict<node_t, torch::Tensor> &input_node_dict,
const c10::Dict<rel_t, std::vector<int64_t>> &num_neighbors_dict,
const int64_t num_hops, const bool replace, const bool directed) {
return hetero_neighbor_sample_cpu(
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
num_neighbors_dict, num_hops, replace, directed);
}
static auto registry =
torch::RegisterOperators()
.op("torch_sparse::neighbor_sample", &neighbor_sample)
.op("torch_sparse::hetero_neighbor_sample", &hetero_neighbor_sample);
#ifdef WITH_PYTHON
#include <Python.h>
#endif
#include <torch/script.h>
#include "cpu/relabel_cpu.h"
#ifdef _WIN32
#ifdef WITH_PYTHON
#ifdef WITH_HIP
PyMODINIT_FUNC PyInit__relabel_cuda(void) { return NULL; }
#else
PyMODINIT_FUNC PyInit__relabel_cpu(void) { return NULL; }
#endif
#endif
#endif
SPARSE_API std::tuple<torch::Tensor, torch::Tensor> relabel(torch::Tensor col,
torch::Tensor idx) {
if (col.device().is_cuda()) {
#ifdef WITH_HIP
AT_ERROR("No CUDA version supported");
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return relabel_cpu(col, idx);
}
}
SPARSE_API std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>,
torch::Tensor>
relabel_one_hop(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value,
torch::Tensor idx, bool bipartite) {
if (rowptr.device().is_cuda()) {
#ifdef WITH_HIP
AT_ERROR("No CUDA version supported");
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return relabel_one_hop_cpu(rowptr, col, optional_value, idx, bipartite);
}
}
static auto registry =
torch::RegisterOperators()
.op("torch_sparse::relabel", &relabel)
.op("torch_sparse::relabel_one_hop", &relabel_one_hop);
#ifdef WITH_PYTHON
#include <Python.h>
#endif
#include <torch/script.h>
#include "cpu/rw_cpu.h"
#ifdef WITH_HIP
#include "hip/rw_hip.h"
#endif
#ifdef _WIN32
#ifdef WITH_PYTHON
#ifdef WITH_HIP
PyMODINIT_FUNC PyInit__rw_cuda(void) { return NULL; }
#else
PyMODINIT_FUNC PyInit__rw_cpu(void) { return NULL; }
#endif
#endif
#endif
SPARSE_API torch::Tensor random_walk(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor start, int64_t walk_length) {
if (rowptr.device().is_cuda()) {
#ifdef WITH_HIP
return random_walk_cuda(rowptr, col, start, walk_length);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return random_walk_cpu(rowptr, col, start, walk_length);
}
}
static auto registry =
torch::RegisterOperators().op("torch_sparse::random_walk", &random_walk);
#ifdef WITH_PYTHON
#include <Python.h>
#endif
#include <torch/script.h>
#include "cpu/saint_cpu.h"
#ifdef _WIN32
#ifdef WITH_PYTHON
#ifdef WITH_HIP
PyMODINIT_FUNC PyInit__saint_cuda(void) { return NULL; }
#else
PyMODINIT_FUNC PyInit__saint_cpu(void) { return NULL; }
#endif
#endif
#endif
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
subgraph(torch::Tensor idx, torch::Tensor rowptr, torch::Tensor row,
torch::Tensor col) {
if (idx.device().is_cuda()) {
#ifdef WITH_HIP
AT_ERROR("No CUDA version supported");
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return subgraph_cpu(idx, rowptr, row, col);
}
}
static auto registry =
torch::RegisterOperators().op("torch_sparse::saint_subgraph", &subgraph);
#ifdef WITH_PYTHON
#include <Python.h>
#endif
#include <torch/script.h>
#include "cpu/sample_cpu.h"
#ifdef _WIN32
#ifdef WITH_PYTHON
#ifdef WITH_HIP
PyMODINIT_FUNC PyInit__sample_cuda(void) { return NULL; }
#else
PyMODINIT_FUNC PyInit__sample_cpu(void) { return NULL; }
#endif
#endif
#endif
SPARSE_API std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
sample_adj(torch::Tensor rowptr, torch::Tensor col, torch::Tensor idx,
int64_t num_neighbors, bool replace) {
if (rowptr.device().is_cuda()) {
#ifdef WITH_HIP
AT_ERROR("No CUDA version supported");
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return sample_adj_cpu(rowptr, col, idx, num_neighbors, replace);
}
}
static auto registry =
torch::RegisterOperators().op("torch_sparse::sample_adj", &sample_adj);
#pragma once
#include <torch/library.h>
#ifdef _WIN32
#if defined(torchsparse_EXPORTS)
#define SPARSE_API __declspec(dllexport)
#else
#define SPARSE_API __declspec(dllimport)
#endif
#else
#define SPARSE_API
#endif
SPARSE_API int64_t cuda_version();
SPARSE_API torch::Tensor ind2ptr(torch::Tensor ind, int64_t M);
SPARSE_API torch::Tensor ptr2ind(torch::Tensor ptr, int64_t E);
SPARSE_API torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value,
int64_t num_parts, bool recursive);
SPARSE_API torch::Tensor partition2(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value,
torch::optional<torch::Tensor> optional_node_weight,
int64_t num_parts, bool recursive);
SPARSE_API torch::Tensor mt_partition(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value,
torch::optional<torch::Tensor> optional_node_weight,
int64_t num_parts, bool recursive,
int64_t num_workers);
SPARSE_API std::tuple<torch::Tensor, torch::Tensor> relabel(torch::Tensor col,
torch::Tensor idx);
SPARSE_API std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>,
torch::Tensor>
relabel_one_hop(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value,
torch::Tensor idx, bool bipartite);
SPARSE_API torch::Tensor random_walk(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor start, int64_t walk_length);
SPARSE_API std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
subgraph(torch::Tensor idx, torch::Tensor rowptr, torch::Tensor row,
torch::Tensor col);
SPARSE_API std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
sample_adj(torch::Tensor rowptr, torch::Tensor col, torch::Tensor idx,
int64_t num_neighbors, bool replace);
SPARSE_API torch::Tensor spmm_sum(torch::optional<torch::Tensor> opt_row,
torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> opt_value,
torch::optional<torch::Tensor> opt_colptr,
torch::optional<torch::Tensor> opt_csr2csc,
torch::Tensor mat);
SPARSE_API torch::Tensor spmm_mean(torch::optional<torch::Tensor> opt_row,
torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> opt_value,
torch::optional<torch::Tensor> opt_rowcount,
torch::optional<torch::Tensor> opt_colptr,
torch::optional<torch::Tensor> opt_csr2csc,
torch::Tensor mat);
SPARSE_API std::tuple<torch::Tensor, torch::Tensor>
spmm_min(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> opt_value, torch::Tensor mat);
SPARSE_API std::tuple<torch::Tensor, torch::Tensor>
spmm_max(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> opt_value, torch::Tensor mat);
SPARSE_API std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
spspmm_sum(torch::Tensor rowptrA, torch::Tensor colA,
torch::optional<torch::Tensor> optional_valueA,
torch::Tensor rowptrB, torch::Tensor colB,
torch::optional<torch::Tensor> optional_valueB, int64_t K);
#ifdef WITH_PYTHON
#include <Python.h>
#endif
#include <torch/script.h>
#include "cpu/spmm_cpu.h"
#ifdef WITH_HIP
#include "hip/spmm_hip.h"
#endif
#ifdef _WIN32
#ifdef WITH_PYTHON
#ifdef WITH_HIP
PyMODINIT_FUNC PyInit__spmm_cuda(void) { return NULL; }
#else
PyMODINIT_FUNC PyInit__spmm_cpu(void) { return NULL; }
#endif
#endif
#endif
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
spmm_fw(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value, torch::Tensor mat,
std::string reduce) {
if (rowptr.device().is_cuda()) {
#ifdef WITH_HIP
return spmm_cuda(rowptr, col, optional_value, mat, reduce);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return spmm_cpu(rowptr, col, optional_value, mat, reduce);
}
}
torch::Tensor spmm_value_bw(torch::Tensor row, torch::Tensor rowptr,
torch::Tensor col, torch::Tensor mat,
torch::Tensor grad, std::string reduce) {
if (row.device().is_cuda()) {
#ifdef WITH_HIP
return spmm_value_bw_cuda(row, rowptr, col, mat, grad, reduce);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return spmm_value_bw_cpu(row, rowptr, col, mat, grad, reduce);
}
}
using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;
class SPMMSum : public torch::autograd::Function<SPMMSum> {
public:
static variable_list forward(AutogradContext *ctx,
torch::optional<Variable> opt_row,
Variable rowptr, Variable col, Variable value,
torch::optional<Variable> opt_colptr,
torch::optional<Variable> opt_csr2csc,
Variable mat, bool has_value) {
if (has_value && torch::autograd::any_variable_requires_grad({value})) {
AT_ASSERTM(opt_row.has_value(), "Argument `row` is missing");
}
if (torch::autograd::any_variable_requires_grad({mat})) {
AT_ASSERTM(opt_row.has_value(), "Argument `row` is missing");
AT_ASSERTM(opt_colptr.has_value(), "Argument `colptr` is missing");
AT_ASSERTM(opt_csr2csc.has_value(), "Argument `csr2csc` is missing");
}
auto row = opt_row.has_value() ? opt_row.value() : col;
auto colptr = opt_colptr.has_value() ? opt_colptr.value() : col;
auto csr2csc = opt_csr2csc.has_value() ? opt_csr2csc.value() : col;
torch::optional<torch::Tensor> opt_value = torch::nullopt;
if (has_value)
opt_value = value;
auto out = std::get<0>(spmm_fw(rowptr, col, opt_value, mat, "sum"));
ctx->saved_data["has_value"] = has_value;
ctx->save_for_backward({row, rowptr, col, value, colptr, csr2csc, mat});
return {out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto has_value = ctx->saved_data["has_value"].toBool();
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto row = saved[0], rowptr = saved[1], col = saved[2], value = saved[3],
colptr = saved[4], csr2csc = saved[5], mat = saved[6];
auto grad_value = Variable();
if (has_value > 0 && torch::autograd::any_variable_requires_grad({value})) {
grad_value = spmm_value_bw(row, rowptr, col, mat, grad_out, "sum");
}
auto grad_mat = Variable();
if (torch::autograd::any_variable_requires_grad({mat})) {
torch::optional<torch::Tensor> opt_value = torch::nullopt;
if (has_value)
opt_value = value.view({-1, 1}).index_select(0, csr2csc).view(-1);
grad_mat = std::get<0>(spmm_fw(colptr, row.index_select(0, csr2csc),
opt_value, grad_out, "sum"));
}
return {Variable(), Variable(), Variable(), grad_value,
Variable(), Variable(), grad_mat, Variable()};
}
};
class SPMMMean : public torch::autograd::Function<SPMMMean> {
public:
static variable_list forward(AutogradContext *ctx,
torch::optional<Variable> opt_row,
Variable rowptr, Variable col, Variable value,
torch::optional<Variable> opt_rowcount,
torch::optional<Variable> opt_colptr,
torch::optional<Variable> opt_csr2csc,
Variable mat, bool has_value) {
if (has_value && torch::autograd::any_variable_requires_grad({value})) {
AT_ASSERTM(opt_row.has_value(), "Argument `row` is missing");
}
if (torch::autograd::any_variable_requires_grad({mat})) {
AT_ASSERTM(opt_row.has_value(), "Argument `row` is missing");
AT_ASSERTM(opt_rowcount.has_value(), "Argument `rowcount` is missing");
AT_ASSERTM(opt_colptr.has_value(), "Argument `colptr` is missing");
AT_ASSERTM(opt_csr2csc.has_value(), "Argument `csr2csc` is missing");
}
auto row = opt_row.has_value() ? opt_row.value() : col;
auto rowcount = opt_rowcount.has_value() ? opt_rowcount.value() : col;
auto colptr = opt_colptr.has_value() ? opt_colptr.value() : col;
auto csr2csc = opt_csr2csc.has_value() ? opt_csr2csc.value() : col;
torch::optional<torch::Tensor> opt_value = torch::nullopt;
if (has_value)
opt_value = value;
auto out = std::get<0>(spmm_fw(rowptr, col, opt_value, mat, "mean"));
ctx->saved_data["has_value"] = has_value;
ctx->save_for_backward(
{row, rowptr, col, value, rowcount, colptr, csr2csc, mat});
return {out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto has_value = ctx->saved_data["has_value"].toBool();
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto row = saved[0], rowptr = saved[1], col = saved[2], value = saved[3],
rowcount = saved[4], colptr = saved[5], csr2csc = saved[6],
mat = saved[7];
auto grad_value = Variable();
if (has_value > 0 && torch::autograd::any_variable_requires_grad({value})) {
grad_value = spmm_value_bw(row, rowptr, col, mat, grad_out, "mean");
}
auto grad_mat = Variable();
if (torch::autograd::any_variable_requires_grad({mat})) {
row = row.index_select(0, csr2csc);
rowcount = rowcount.index_select(0, row).toType(mat.scalar_type());
rowcount.masked_fill_(rowcount < 1, 1);
if (has_value > 0)
rowcount =
value.view({-1, 1}).index_select(0, csr2csc).view(-1).div(rowcount);
else
rowcount.pow_(-1);
grad_mat = std::get<0>(spmm_fw(colptr, row, rowcount, grad_out, "sum"));
}
return {Variable(), Variable(), Variable(), grad_value, Variable(),
Variable(), Variable(), grad_mat, Variable()};
}
};
class SPMMMin : public torch::autograd::Function<SPMMMin> {
public:
static variable_list forward(AutogradContext *ctx, Variable rowptr,
Variable col, Variable value, Variable mat,
bool has_value) {
torch::optional<torch::Tensor> opt_value = torch::nullopt;
if (has_value)
opt_value = value;
auto result = spmm_fw(rowptr, col, opt_value, mat, "min");
auto out = std::get<0>(result);
auto arg_out = std::get<1>(result).value();
ctx->saved_data["has_value"] = has_value;
ctx->save_for_backward({col, value, mat, arg_out});
ctx->mark_non_differentiable({arg_out});
return {out, arg_out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto has_value = ctx->saved_data["has_value"].toBool();
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto col = saved[0], value = saved[1], mat = saved[2], arg_out = saved[3];
auto invalid_arg_mask = arg_out == col.size(0);
arg_out = arg_out.masked_fill(invalid_arg_mask, 0);
auto grad_value = Variable();
if (has_value > 0 && torch::autograd::any_variable_requires_grad({value})) {
auto ind = col.index_select(0, arg_out.flatten()).view_as(arg_out);
auto out = mat.gather(-2, ind);
out.mul_(grad_out);
out.masked_fill_(invalid_arg_mask, 0);
grad_value = torch::zeros_like(value);
grad_value.scatter_add_(0, arg_out.flatten(), out.flatten());
}
auto grad_mat = Variable();
if (torch::autograd::any_variable_requires_grad({mat})) {
if (has_value > 0) {
value = value.view({-1, 1})
.index_select(0, arg_out.flatten())
.view_as(arg_out)
.mul_(grad_out);
} else
value = grad_out;
value.masked_fill_(invalid_arg_mask, 0);
auto ind = col.index_select(0, arg_out.flatten()).view_as(arg_out);
grad_mat = torch::zeros_like(mat);
grad_mat.scatter_add_(-2, ind, value);
}
return {Variable(), Variable(), grad_value, grad_mat, Variable()};
}
};
class SPMMMax : public torch::autograd::Function<SPMMMax> {
public:
static variable_list forward(AutogradContext *ctx, Variable rowptr,
Variable col, Variable value, Variable mat,
bool has_value) {
torch::optional<torch::Tensor> opt_value = torch::nullopt;
if (has_value)
opt_value = value;
auto result = spmm_fw(rowptr, col, opt_value, mat, "max");
auto out = std::get<0>(result);
auto arg_out = std::get<1>(result).value();
ctx->saved_data["has_value"] = has_value;
ctx->save_for_backward({col, value, mat, arg_out});
ctx->mark_non_differentiable({arg_out});
return {out, arg_out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto has_value = ctx->saved_data["has_value"].toBool();
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto col = saved[0], value = saved[1], mat = saved[2], arg_out = saved[3];
auto invalid_arg_mask = arg_out == col.size(0);
arg_out = arg_out.masked_fill(invalid_arg_mask, 0);
auto grad_value = Variable();
if (has_value > 0 && torch::autograd::any_variable_requires_grad({value})) {
auto ind = col.index_select(0, arg_out.flatten()).view_as(arg_out);
auto out = mat.gather(-2, ind);
out.mul_(grad_out);
out.masked_fill_(invalid_arg_mask, 0);
grad_value = torch::zeros_like(value);
grad_value.scatter_add_(0, arg_out.flatten(), out.flatten());
}
auto grad_mat = Variable();
if (torch::autograd::any_variable_requires_grad({mat})) {
if (has_value > 0) {
value = value.view({-1, 1})
.index_select(0, arg_out.flatten())
.view_as(arg_out)
.mul_(grad_out);
} else
value = grad_out;
value.masked_fill_(invalid_arg_mask, 0);
auto ind = col.index_select(0, arg_out.flatten()).view_as(arg_out);
grad_mat = torch::zeros_like(mat);
grad_mat.scatter_add_(-2, ind, value);
}
return {Variable(), Variable(), grad_value, grad_mat, Variable()};
}
};
SPARSE_API torch::Tensor spmm_sum(torch::optional<torch::Tensor> opt_row,
torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> opt_value,
torch::optional<torch::Tensor> opt_colptr,
torch::optional<torch::Tensor> opt_csr2csc,
torch::Tensor mat) {
auto value = opt_value.has_value() ? opt_value.value() : col;
return SPMMSum::apply(opt_row, rowptr, col, value, opt_colptr, opt_csr2csc,
mat, opt_value.has_value())[0];
}
SPARSE_API torch::Tensor spmm_mean(torch::optional<torch::Tensor> opt_row,
torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> opt_value,
torch::optional<torch::Tensor> opt_rowcount,
torch::optional<torch::Tensor> opt_colptr,
torch::optional<torch::Tensor> opt_csr2csc,
torch::Tensor mat) {
auto value = opt_value.has_value() ? opt_value.value() : col;
return SPMMMean::apply(opt_row, rowptr, col, value, opt_rowcount, opt_colptr,
opt_csr2csc, mat, opt_value.has_value())[0];
}
SPARSE_API std::tuple<torch::Tensor, torch::Tensor>
spmm_min(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> opt_value, torch::Tensor mat) {
auto value = opt_value.has_value() ? opt_value.value() : col;
auto result = SPMMMin::apply(rowptr, col, value, mat, opt_value.has_value());
return std::make_tuple(result[0], result[1]);
}
SPARSE_API std::tuple<torch::Tensor, torch::Tensor>
spmm_max(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> opt_value, torch::Tensor mat) {
auto value = opt_value.has_value() ? opt_value.value() : col;
auto result = SPMMMax::apply(rowptr, col, value, mat, opt_value.has_value());
return std::make_tuple(result[0], result[1]);
}
static auto registry = torch::RegisterOperators()
.op("torch_sparse::spmm_sum", &spmm_sum)
.op("torch_sparse::spmm_mean", &spmm_mean)
.op("torch_sparse::spmm_min", &spmm_min)
.op("torch_sparse::spmm_max", &spmm_max);
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment