Commit 26aee002 authored by rusty1s's avatar rusty1s
Browse files

spspmm done

parent e44a639f
#include "spspmm_cpu.h"
#include "utils.h"
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
spspmm_cpu(torch::Tensor rowptrA, torch::Tensor colA,
torch::optional<torch::Tensor> optional_valueA,
torch::Tensor rowptrB, torch::Tensor colB,
torch::optional<torch::Tensor> optional_valueB, int64_t K,
std::string reduce) {
CHECK_CPU(rowptrA);
CHECK_CPU(colA);
if (optional_valueA.has_value())
CHECK_CPU(optional_valueA.value());
CHECK_CPU(rowptrB);
CHECK_CPU(colB);
if (optional_valueB.has_value())
CHECK_CPU(optional_valueB.value());
CHECK_INPUT(rowptrA.dim() == 1);
CHECK_INPUT(colA.dim() == 1);
if (optional_valueA.has_value()) {
CHECK_INPUT(optional_valueA.value().dim() == 1);
CHECK_INPUT(optional_valueA.value().size(0) == colA.size(0));
}
CHECK_INPUT(rowptrB.dim() == 1);
CHECK_INPUT(colB.dim() == 1);
if (optional_valueB.has_value()) {
CHECK_INPUT(optional_valueB.value().dim() == 1);
CHECK_INPUT(optional_valueB.value().size(0) == colB.size(0));
}
if (!optional_valueA.has_value() && optional_valueB.has_value())
optional_valueA =
torch::ones(colA.numel(), optional_valueB.value().options());
if (!optional_valueB.has_value() && optional_valueA.has_value())
optional_valueB =
torch::ones(colB.numel(), optional_valueA.value().options());
auto scalar_type = torch::ScalarType::Float;
if (optional_valueA.has_value())
scalar_type = optional_valueA.value().scalar_type();
auto rowptrA_data = rowptrA.data_ptr<int64_t>();
auto colA_data = colA.data_ptr<int64_t>();
auto rowptrB_data = rowptrB.data_ptr<int64_t>();
auto colB_data = colB.data_ptr<int64_t>();
// Pass 1: Compute CSR row pointer.
auto rowptrC = torch::empty_like(rowptrA);
auto rowptrC_data = rowptrC.data_ptr<int64_t>();
rowptrC_data[0] = 0;
int64_t rowA_start = 0, rowA_end, rowB_start, rowB_end, cA, cB;
int64_t nnz = 0, row_nnz;
for (auto n = 1; n < rowptrA.numel(); n++) {
rowA_end = rowptrA_data[n];
for (auto eA = rowA_start; eA < rowA_end; eA++) {
cA = colA_data[eA];
row_nnz = rowptrB_data[cA + 1] - rowptrB_data[cA];
}
nnz += row_nnz;
rowptrC_data[n] = nnz;
rowA_start = rowA_end;
}
// Pass 2: Compute CSR entries.
auto colC = torch::empty(nnz, rowptrC.options());
auto colC_data = colC.data_ptr<int64_t>();
torch::optional<torch::Tensor> optional_valueC = torch::nullopt;
if (optional_valueA.has_value())
optional_valueC = torch::empty(nnz, optional_valueA.value().options());
AT_DISPATCH_ALL_TYPES(scalar_type, "spspmm", [&] {
AT_DISPATCH_HAS_VALUE(optional_valueC, [&] {
scalar_t *valA_data = nullptr, *valB_data = nullptr, *valC_data = nullptr;
if (HAS_VALUE) {
valA_data = optional_valueA.value().data_ptr<scalar_t>();
valB_data = optional_valueB.value().data_ptr<scalar_t>();
valC_data = optional_valueC.value().data_ptr<scalar_t>();
}
scalar_t valA;
rowA_start = 0, nnz = 0;
std::vector<scalar_t> vals(K, 0);
for (auto n = 1; n < rowptrA.numel(); n++) {
rowA_end = rowptrA_data[n];
for (auto eA = rowA_start; eA < rowA_end; eA++) {
cA = colA_data[eA];
if (HAS_VALUE)
valA = valA_data[eA];
rowB_start = rowptrB_data[cA], rowB_end = rowptrB_data[cA + 1];
for (auto eB = rowB_start; eB < rowB_end; eB++) {
cB = colB_data[eB];
if (HAS_VALUE)
vals[cB] += valA * valB_data[eB];
else
vals[cB] += 1;
}
}
for (auto k = 0; k < K; k++) {
if (vals[k] != 0) {
colC_data[nnz] = k;
if (HAS_VALUE)
valC_data[nnz] = vals[k];
nnz++;
}
vals[k] = (scalar_t)0;
}
rowA_start = rowA_end;
}
});
});
return std::make_tuple(rowptrC, colC, optional_valueC);
}
#pragma once
#include <torch/extension.h>
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
spspmm_cpu(torch::Tensor rowptrA, torch::Tensor colA,
torch::optional<torch::Tensor> optional_valueA,
torch::Tensor rowptrB, torch::Tensor colB,
torch::optional<torch::Tensor> optional_valueB, int64_t K,
std::string reduce);
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "compat.cuh"
#define THREADS 256
__global__ void ind2ptr_kernel(const int64_t *ind_data, int64_t *out_data,
int64_t M, int64_t numel) {
int64_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
if (thread_idx == 0) {
for (int64_t i = 0; i <= ind_data[0]; i++)
out_data[i] = 0;
} else if (thread_idx < numel) {
for (int64_t i = ind_data[thread_idx - 1]; i < ind_data[thread_idx]; i++)
out_data[i + 1] = thread_idx;
} else if (thread_idx == numel) {
for (int64_t i = ind_data[numel - 1] + 1; i < M + 1; i++)
out_data[i] = numel;
}
}
torch::Tensor ind2ptr_cuda(torch::Tensor ind, int64_t M) {
cudaSetDevice(ind.get_device());
auto out = torch::empty(M + 1, ind.options());
auto ind_data = ind.DATA_PTR<int64_t>();
auto out_data = out.DATA_PTR<int64_t>();
auto stream = at::cuda::getCurrentCUDAStream();
ind2ptr_kernel<<<(ind.numel() + 2 + THREADS - 1) / THREADS, THREADS, 0,
stream>>>(ind_data, out_data, M, ind.numel());
return out;
}
__global__ void ptr2ind_kernel(const int64_t *ptr_data, int64_t *out_data,
int64_t E, int64_t numel) {
int64_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
if (thread_idx < numel) {
int64_t idx = ptr_data[thread_idx], next_idx = ptr_data[thread_idx + 1];
for (int64_t i = idx; i < next_idx; i++) {
out_data[i] = thread_idx;
}
}
}
torch::Tensor ptr2ind_cuda(torch::Tensor ptr, int64_t E) {
cudaSetDevice(ptr.get_device());
auto out = torch::empty(E, ptr.options());
auto ptr_data = ptr.DATA_PTR<int64_t>();
auto out_data = out.DATA_PTR<int64_t>();
auto stream = at::cuda::getCurrentCUDAStream();
ptr2ind_kernel<<<(ptr.numel() + THREADS - 1) / THREADS, THREADS, 0, stream>>>(
ptr_data, out_data, E, ptr.numel());
return out;
}
#include <torch/script.h>
#define CHECK_CUDA(x) \
AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
spmm_cuda(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> value_opt, torch::Tensor mat,
std::string reduce);
torch::Tensor spmm_val_bw_cuda(torch::Tensor row, torch::Tensor rowptr,
torch::Tensor col, torch::Tensor mat,
torch::Tensor grad, std::string reduce);
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
spmm(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> value_opt, torch::Tensor mat,
std::string reduce) {
CHECK_CUDA(rowptr);
CHECK_CUDA(col);
if (value_opt.has_value())
CHECK_CUDA(value_opt.value());
CHECK_CUDA(mat);
return spmm_cuda(rowptr, col, value_opt, mat, reduce);
}
torch::Tensor spmm_val_bw(torch::Tensor row, torch::Tensor rowptr,
torch::Tensor col, torch::Tensor mat,
torch::Tensor grad, std::string reduce) {
CHECK_CUDA(row);
CHECK_CUDA(rowptr);
CHECK_CUDA(col);
CHECK_CUDA(mat);
CHECK_CUDA(grad);
return spmm_val_bw_cuda(row, rowptr, col, mat, grad, reduce);
}
static auto registry =
torch::RegisterOperators("torch_sparse_cuda::spmm", &spmm)
.op("torch_sparse_cuda::spmm_val_bw", &spmm_val_bw);
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "compat.cuh"
#define THREADS 256
#define FULL_MASK 0xffffffff
enum ReductionType { SUM, MEAN, MIN, MAX };
const std::map<std::string, ReductionType> reduce2REDUCE = {
{"sum", SUM}, {"add", SUM}, {"mean", MEAN}, {"min", MIN}, {"max", MAX},
};
#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...) \
[&] { \
switch (reduce2REDUCE.at(reduce)) { \
case SUM: { \
const ReductionType REDUCE = SUM; \
return __VA_ARGS__(); \
} \
case MEAN: { \
const ReductionType REDUCE = MEAN; \
return __VA_ARGS__(); \
} \
case MIN: { \
const ReductionType REDUCE = MIN; \
return __VA_ARGS__(); \
} \
case MAX: { \
const ReductionType REDUCE = MAX; \
return __VA_ARGS__(); \
} \
} \
}()
template <typename scalar_t, ReductionType REDUCE> struct Reducer {
static inline __host__ __device__ scalar_t init() {
if (REDUCE == MIN) {
return std::numeric_limits<scalar_t>::max();
} else if (REDUCE == MAX) {
return std::numeric_limits<scalar_t>::lowest();
} else {
return (scalar_t)0;
}
}
static inline __host__ __device__ void update(scalar_t *val, scalar_t new_val,
int64_t *arg, int64_t new_arg) {
if (REDUCE == SUM || REDUCE == MEAN) {
*val = *val + new_val;
} else if ((REDUCE == MIN && new_val < *val) ||
(REDUCE == MAX && new_val > *val)) {
*val = new_val;
*arg = new_arg;
}
}
static inline __host__ __device__ void write(scalar_t *address, scalar_t val,
int64_t *arg_address,
int64_t arg, int count) {
if (REDUCE == SUM) {
*address = val;
} else if (REDUCE == MEAN) {
*address = val / (scalar_t)max(count, 1);
} else if (REDUCE == MIN || REDUCE == MAX) {
if (count > 0) {
*address = val;
*arg_address = arg;
} else {
*address = (scalar_t)0;
}
}
}
};
// Paper: Design Principles for Sparse Matrix Multiplication on the GPU
// Code: https://github.com/owensgroup/merge-spmm
template <typename scalar_t, ReductionType REDUCE, bool HAS_VAL>
__global__ void spmm_kernel(const int64_t *rowptr_data, const int64_t *col_data,
const scalar_t *value_data,
const scalar_t *mat_data, scalar_t *out_data,
int64_t *arg_out_data, int B, int M, int N, int K) {
// We ignore blockIdx.y here, because threads
// across `blockIdx.y` are treated equally.
int thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
int row = thread_idx >> 5; // thread_idx / 32
int lane_idx = thread_idx & (32 - 1); // thread_idx % 32
int batch_idx = row / M;
// Compute the column index of `mat` in which the thread is operating.
int mat_col_idx = lane_idx + (blockIdx.y << 5);
// Compute the output index (row-major order).
int out_idx = row * K + mat_col_idx;
// Helper arrays for warp communication.
int mat_row, mat_rows[32];
scalar_t val, vals[HAS_VAL ? 32 : 1];
// Do not aggregate/write across the Y-axis (lane_idx < leftover).
int leftover = K - (blockIdx.y << 5);
if (batch_idx < B) {
int row_start = __ldg(rowptr_data + (row % M));
int row_end = __ldg(rowptr_data + (row % M) + 1);
int col_idx = row_start + lane_idx;
scalar_t result = Reducer<scalar_t, REDUCE>::init();
int64_t arg;
// Iterate over all `col` indices in parallel within a warp.
for (int c = row_start; c < row_end; c += 32) {
if (col_idx < row_end) {
// Coalesced memory access into `col` and `val`.
mat_row = __ldg(col_data + col_idx) * K;
if (HAS_VAL)
val = __ldg(value_data + col_idx);
} else {
mat_row = -1;
if (HAS_VAL)
val = (scalar_t)0;
}
col_idx += 32;
#pragma unroll
for (int i = 0; i < 32; i++) {
// Communication between all threads in a warp.
mat_rows[i] = __shfl_sync(FULL_MASK, mat_row, i);
if (HAS_VAL)
vals[i] = __shfl_sync(FULL_MASK, val, i);
}
#pragma unroll
for (int i = 0; i < 32; i++) {
if (lane_idx < leftover && mat_rows[i] != -1) {
// Coalesced memory access into `mat`.
val = __ldg(mat_data + batch_idx * N * K + mat_rows[i] + mat_col_idx);
if (HAS_VAL)
val = vals[i] * val;
Reducer<scalar_t, REDUCE>::update(&result, val, &arg, c + i);
}
}
}
if (lane_idx < leftover) {
// Coalesced write into `out`.
Reducer<scalar_t, REDUCE>::write(out_data + out_idx, result,
arg_out_data + out_idx, arg,
row_end - row_start);
}
}
}
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
spmm_cuda(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> value_opt, torch::Tensor mat,
std::string reduce) {
cudaSetDevice(rowptr.get_device());
AT_ASSERTM(rowptr.dim() == 1, "Input mismatch");
AT_ASSERTM(col.dim() == 1, "Input mismatch");
if (value_opt.has_value())
AT_ASSERTM(value_opt.value().dim() == 1);
AT_ASSERTM(mat.dim() >= 2, "Input mismatch");
mat = mat.contiguous();
auto sizes = mat.sizes().vec();
sizes[mat.dim() - 2] = rowptr.numel() - 1;
auto out = torch::empty(sizes, mat.options());
torch::optional<torch::Tensor> arg_out = torch::nullopt;
int64_t *arg_out_data = nullptr;
if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
arg_out = torch::full_like(out, col.numel(), rowptr.options());
arg_out_data = arg_out.value().DATA_PTR<int64_t>();
}
auto M = rowptr.numel() - 1;
auto N = mat.size(-2);
auto K = mat.size(-1);
auto B = mat.numel() / (N * K);
auto BLOCKS = dim3((32 * B * M + THREADS - 1) / THREADS, (K + 31) / 32);
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm_kernel", [&] {
auto rowptr_data = rowptr.DATA_PTR<int64_t>();
auto col_data = col.DATA_PTR<int64_t>();
auto mat_data = mat.DATA_PTR<scalar_t>();
auto out_data = out.DATA_PTR<scalar_t>();
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
if (value_opt.has_value()) {
auto value_data = value_opt.value().DATA_PTR<scalar_t>();
spmm_kernel<scalar_t, REDUCE, true><<<BLOCKS, THREADS, 0, stream>>>(
rowptr_data, col_data, value_data, mat_data, out_data, arg_out_data,
B, M, N, K);
} else {
spmm_kernel<scalar_t, REDUCE, false><<<BLOCKS, THREADS, 0, stream>>>(
rowptr_data, col_data, nullptr, mat_data, out_data, arg_out_data, B,
M, N, K);
}
});
});
return std::make_tuple(out, arg_out);
}
template <typename scalar_t, ReductionType REDUCE>
__global__ void
spmm_val_bw_kernel(const int64_t *row_data, const int64_t *rowptr_data,
const int64_t *col_data, const scalar_t *mat_data,
const scalar_t *grad_data, scalar_t *out_data, int B, int M,
int N, int E, int K) {
int thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
int index_idx = (thread_idx >> 5); // thread_idx / 32
int lane_idx = thread_idx & (32 - 1); // thread_idx % 32
if (index_idx < E) {
int row = __ldg(row_data + index_idx);
int col = __ldg(col_data + index_idx);
scalar_t val = (scalar_t)0;
for (int b = 0; b < B; b++) {
for (int k = lane_idx; k < K; k += 32) {
val += mat_data[b * N * K + col * K + k] *
grad_data[b * M * K + row * K + k];
}
}
#pragma unroll
for (int i = 32 / 2; i > 0; i /= 2) { // Parallel reduction inside a warp.
val += __shfl_down_sync(FULL_MASK, val, i);
}
if (lane_idx == 0) {
if (REDUCE == MEAN) {
int row_start = __ldg(rowptr_data + row);
int row_end = __ldg(rowptr_data + row + 1);
val /= (scalar_t)max(row_end - row_start, 1);
}
out_data[index_idx] = val;
}
}
}
torch::Tensor spmm_val_bw_cuda(torch::Tensor row, torch::Tensor rowptr,
torch::Tensor col, torch::Tensor mat,
torch::Tensor grad, std::string reduce) {
cudaSetDevice(row.get_device());
mat = mat.contiguous();
grad = grad.contiguous();
auto M = grad.size(-2);
auto N = mat.size(-2);
auto E = row.numel();
auto K = mat.size(-1);
auto B = mat.numel() / (N * K);
auto BLOCKS = dim3((E * 32 + THREADS - 1) / THREADS);
auto out = torch::zeros(row.numel(), grad.options());
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm_val_bw_kernel", [&] {
auto row_data = row.DATA_PTR<int64_t>();
auto rowptr_data = rowptr.DATA_PTR<int64_t>();
auto col_data = col.DATA_PTR<int64_t>();
auto mat_data = mat.DATA_PTR<scalar_t>();
auto grad_data = grad.DATA_PTR<scalar_t>();
auto out_data = out.DATA_PTR<scalar_t>();
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
spmm_val_bw_kernel<scalar_t, REDUCE><<<BLOCKS, THREADS, 0, stream>>>(
row_data, rowptr_data, col_data, mat_data, grad_data, out_data, B, M,
N, E, K);
});
});
return out;
}
#include <torch/script.h>
#define CHECK_CUDA(x) \
AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
spspmm_cuda(torch::Tensor rowptrA, torch::Tensor colA,
torch::optional<torch::Tensor> valueA, torch::Tensor rowptrB,
torch::Tensor colB, torch::optional<torch::Tensor> valueB,
int64_t M, int64_t N, int64_t K);
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
spspmm(torch::Tensor rowptrA, torch::Tensor colA,
torch::optional<torch::Tensor> valueA, torch::Tensor rowptrB,
torch::Tensor colB, torch::optional<torch::Tensor> valueB, int64_t M,
int64_t N, int64_t K) {
CHECK_CUDA(rowptrA);
CHECK_CUDA(colA);
if (valueA.has_value())
CHECK_CUDA(valueA.value());
CHECK_CUDA(rowptrB);
CHECK_CUDA(colB);
if (valueB.has_value())
CHECK_CUDA(valueB.value());
return spspmm_cuda(rowptrA, colA, valueA, rowptrB, colB, valueB, M, N, K);
}
static auto registry =
torch::RegisterOperators("torch_sparse_cuda::spspmm", &spspmm);
#include <ATen/cuda/CUDAContext.h> #include "spspmm_cuda.h"
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <cusparse.h> #include <cusparse.h>
#include "compat.cuh" #include "utils.cuh"
#define AT_DISPATCH_CUSPARSE_TYPES(TYPE, ...) \
[&] { \
switch (TYPE) { \
case torch::ScalarType::Float: { \
using scalar_t = float; \
const auto &cusparsecsrgemm2_bufferSizeExt = \
cusparseScsrgemm2_bufferSizeExt; \
const auto &cusparsecsrgemm2 = cusparseScsrgemm2; \
return __VA_ARGS__(); \
} \
case torch::ScalarType::Double: { \
using scalar_t = double; \
const auto &cusparsecsrgemm2_bufferSizeExt = \
cusparseDcsrgemm2_bufferSizeExt; \
const auto &cusparsecsrgemm2 = cusparseDcsrgemm2; \
return __VA_ARGS__(); \
} \
default: \
AT_ERROR("Not implemented for '", toString(TYPE), "'"); \
} \
}()
#define AT_DISPATCH_CUSPARSE_CSR_GEMM2_BUFFER_SIZE_EXT_TYPES(TYPE, ...) \ #define AT_DISPATCH_CUSPARSE_CSR_GEMM2_BUFFER_SIZE_EXT_TYPES(TYPE, ...) \
[&] { \ [&] { \
...@@ -45,87 +67,120 @@ ...@@ -45,87 +67,120 @@
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>> std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
spspmm_cuda(torch::Tensor rowptrA, torch::Tensor colA, spspmm_cuda(torch::Tensor rowptrA, torch::Tensor colA,
torch::optional<torch::Tensor> valueA, torch::Tensor rowptrB, torch::optional<torch::Tensor> optional_valueA,
torch::Tensor colB, torch::optional<torch::Tensor> valueB, torch::Tensor rowptrB, torch::Tensor colB,
int64_t M, int64_t N, int64_t K) { torch::optional<torch::Tensor> optional_valueB, int64_t K,
std::string reduce) {
CHECK_CUDA(rowptrA);
CHECK_CUDA(colA);
if (optional_valueA.has_value())
CHECK_CUDA(optional_valueA.value());
CHECK_CUDA(rowptrB);
CHECK_CUDA(colB);
if (optional_valueB.has_value())
CHECK_CUDA(optional_valueB.value());
cudaSetDevice(rowptrA.get_device()); cudaSetDevice(rowptrA.get_device());
cusparseMatDescr_t descr = 0; CHECK_INPUT(rowptrA.dim() == 1);
cusparseCreateMatDescr(&descr); CHECK_INPUT(colA.dim() == 1);
if (optional_valueA.has_value()) {
CHECK_INPUT(optional_valueA.value().dim() == 1);
CHECK_INPUT(optional_valueA.value().size(0) == colA.size(0));
}
CHECK_INPUT(rowptrB.dim() == 1);
CHECK_INPUT(colB.dim() == 1);
if (optional_valueB.has_value()) {
CHECK_INPUT(optional_valueB.value().dim() == 1);
CHECK_INPUT(optional_valueB.value().size(0) == colB.size(0));
}
if (!optional_valueA.has_value() && optional_valueB.has_value())
optional_valueA =
torch::ones(colA.numel(), optional_valueB.value().options());
if (!optional_valueB.has_value() && optional_valueA.has_value())
optional_valueB =
torch::ones(colB.numel(), optional_valueA.value().options());
auto scalar_type = torch::ScalarType::Float;
if (optional_valueA.has_value())
scalar_type = optional_valueA.value().scalar_type();
auto handle = at::cuda::getCurrentCUDASparseHandle(); auto handle = at::cuda::getCurrentCUDASparseHandle();
cusparseSetPointerMode(handle, CUSPARSE_POINTER_MODE_HOST);
cusparseMatDescr_t descr;
cusparseCreateMatDescr(&descr);
rowptrA = rowptrA.toType(torch::kInt), colA = colA.toType(torch::kInt); rowptrA = rowptrA.toType(torch::kInt);
rowptrB = rowptrB.toType(torch::kInt), colB = colB.toType(torch::kInt); colA = colA.toType(torch::kInt);
rowptrB = rowptrB.toType(torch::kInt);
colB = colB.toType(torch::kInt);
auto rowptrA_data = rowptrA.DATA_PTR<int>(), colA_data = colA.DATA_PTR<int>(); int64_t M = rowptrA.numel() - 1, N = rowptrB.numel() - 1;
auto rowptrB_data = rowptrB.DATA_PTR<int>(), colB_data = colB.DATA_PTR<int>(); auto rowptrA_data = rowptrA.data_ptr<int>();
auto colA_data = colA.data_ptr<int>();
auto rowptrB_data = rowptrB.data_ptr<int>();
auto colB_data = colB.data_ptr<int>();
torch::Tensor rowptrC, colC;
torch::optional<torch::Tensor> optional_valueC = torch::nullopt;
int nnzC;
int *nnzTotalDevHostPtr = &nnzC;
// Step 1: Create an opaque structure.
csrgemm2Info_t info = NULL; csrgemm2Info_t info = NULL;
cusparseCreateCsrgemm2Info(&info); cusparseCreateCsrgemm2Info(&info);
auto scalar_type = torch::ScalarType::Float; // Step 2: Allocate buffer for `csrgemm2Nnz` and `csrgemm2`.
if (valueA.has_value())
scalar_type = valueA.value().scalar_type();
if (valueB.has_value())
scalar_type = valueB.value().scalar_type();
size_t bufferSize; size_t bufferSize;
AT_DISPATCH_CUSPARSE_CSR_GEMM2_BUFFER_SIZE_EXT_TYPES(scalar_type, [&] { AT_DISPATCH_CUSPARSE_TYPES(scalar_type, [&] {
scalar_t alpha = (scalar_t)1; scalar_t alpha = (scalar_t)1.0;
cusparsecsrgemm2_bufferSizeExt(handle, M, N, K, &alpha, descr, colA.numel(), cusparsecsrgemm2_bufferSizeExt(handle, M, N, K, &alpha, descr, colA.numel(),
rowptrA_data, colA_data, descr, colB.numel(), rowptrA_data, colA_data, descr, colB.numel(),
rowptrB_data, colB_data, NULL, descr, 0, rowptrB_data, colB_data, NULL, descr, 0,
NULL, NULL, info, &bufferSize); NULL, NULL, info, &bufferSize);
});
void *buffer = NULL; void *buffer = NULL;
cudaMalloc(&buffer, bufferSize); cudaMalloc(&buffer, bufferSize);
int nnzC; // Step 3: Compute CSR row pointer.
auto rowptrC = torch::empty(M + 1, rowptrA.options()); rowptrC = torch::empty(M + 1, rowptrA.options());
auto rowptrC_data = rowptrC.DATA_PTR<int>(); auto rowptrC_data = rowptrC.data_ptr<int>();
cusparseXcsrgemm2Nnz(handle, M, N, K, descr, colA.numel(), rowptrA_data, cusparseXcsrgemm2Nnz(handle, M, N, K, descr, colA.numel(), rowptrA_data,
colA_data, descr, colB.numel(), rowptrB_data, colB_data, colA_data, descr, colB.numel(), rowptrB_data,
descr, 0, NULL, NULL, descr, rowptrC_data, &nnzC, info, colB_data, descr, 0, NULL, NULL, descr, rowptrC_data,
buffer); nnzTotalDevHostPtr, info, buffer);
auto colC = torch::empty(nnzC, colA.options()); // Step 4: Compute CSR entries.
auto colC_data = colC.DATA_PTR<int>(); colC = torch::empty(nnzC, rowptrC.options());
auto colC_data = colC.data_ptr<int>();
if (!valueA.has_value() && valueB.has_value())
valueA = torch::ones_like(valueB.value()); if (optional_valueA.has_value())
optional_valueC = torch::empty(nnzC, optional_valueA.value().options());
if (!valueB.has_value() && valueA.has_value())
valueB = torch::ones_like(valueA.value()); scalar_t *valA_data = NULL, *valB_data = NULL, *valC_data = NULL;
if (optional_valueA.has_value()) {
torch::optional<torch::Tensor> valueC = torch::nullopt; valA_data = optional_valueA.value().data_ptr<scalar_t>();
if (valueA.has_value()) valB_data = optional_valueB.value().data_ptr<scalar_t>();
valueC = torch::empty(nnzC, valueA.value().options()); valC_data = optional_valueC.value().data_ptr<scalar_t>();
}
AT_DISPATCH_CUSPARSE_CSR_GEMM2_TYPES(scalar_type, [&] {
scalar_t alpha = (scalar_t)1; cusparsecsrgemm2(handle, M, N, K, &alpha, descr, colA.numel(), valA_data,
rowptrA_data, colA_data, descr, colB.numel(), valB_data,
scalar_t *valueA_data = NULL;
if (valueA.has_value())
valueA_data = valueA.value().DATA_PTR<scalar_t>();
scalar_t *valueB_data = NULL;
if (valueB.has_value())
valueB_data = valueB.value().DATA_PTR<scalar_t>();
scalar_t *valueC_data = NULL;
if (valueC.has_value())
valueC_data = valueC.value().DATA_PTR<scalar_t>();
cusparsecsrgemm2(handle, M, N, K, &alpha, descr, colA.numel(), valueA_data,
rowptrA_data, colA_data, descr, colB.numel(), valueB_data,
rowptrB_data, colB_data, NULL, descr, 0, NULL, NULL, NULL, rowptrB_data, colB_data, NULL, descr, 0, NULL, NULL, NULL,
descr, valueC_data, rowptrC_data, colC_data, info, buffer); descr, valC_data, rowptrC_data, colC_data, info, buffer);
cudaFree(buffer);
}); });
// Step 5: Destroy the opaque structure.
cusparseDestroyCsrgemm2Info(info);
rowptrC = rowptrC.toType(torch::kLong); rowptrC = rowptrC.toType(torch::kLong);
colC = colC.toType(torch::kLong); colC = colC.toType(torch::kLong);
return std::make_tuple(rowptrC, colC, valueC); return std::make_tuple(rowptrC, colC, optional_valueC);
} }
#pragma once
#include <torch/extension.h>
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
spspmm_cuda(torch::Tensor rowptrA, torch::Tensor colA,
torch::optional<torch::Tensor> optional_valueA,
torch::Tensor rowptrB, torch::Tensor colB,
torch::optional<torch::Tensor> optional_valueB, int64_t K,
std::string reduce);
#include <torch/extension.h>
#define CHECK_CUDA(x) \
AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
std::tuple<at::Tensor, at::Tensor> unique_cuda(at::Tensor src);
std::tuple<at::Tensor, at::Tensor> unique(at::Tensor src) {
CHECK_CUDA(src);
return unique_cuda(src);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("unique", &unique, "Unique (CUDA)");
}
#include <ATen/ATen.h>
#include "compat.cuh"
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
template <typename scalar_t>
__global__ void unique_cuda_kernel(scalar_t *__restrict__ src, bool *mask,
size_t numel) {
const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = blockDim.x * gridDim.x;
for (ptrdiff_t i = index; i < numel; i += stride) {
if (i == 0 || src[i] != src[i - 1]) {
mask[i] = true;
}
}
}
std::tuple<at::Tensor, at::Tensor> unique_cuda(at::Tensor src) {
cudaSetDevice(src.get_device());
at::Tensor perm;
std::tie(src, perm) = src.sort();
auto mask = at::zeros(src.numel(), src.options().dtype(at::kBool));
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "grid_cuda_kernel", [&] {
unique_cuda_kernel<scalar_t><<<BLOCKS(src.numel()), THREADS>>>(
src.DATA_PTR<scalar_t>(), mask.DATA_PTR<bool>(), src.numel());
});
src = src.masked_select(mask);
perm = perm.masked_select(mask);
return std::make_tuple(src, perm);
}
#include <torch/script.h>
#include "cpu/spspmm_cpu.h"
#ifdef WITH_CUDA
#include "cuda/spspmm_cuda.h"
#endif
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
spspmm_sum(torch::Tensor rowptrA, torch::Tensor colA,
torch::optional<torch::Tensor> optional_valueA,
torch::Tensor rowptrB, torch::Tensor colB,
torch::optional<torch::Tensor> optional_valueB, int64_t K) {
if (rowptrA.device().is_cuda()) {
#ifdef WITH_CUDA
return spspmm_cuda(rowptrA, colA, optional_valueA, rowptrB, colB,
optional_valueB, K, "sum");
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return spspmm_cpu(rowptrA, colA, optional_valueA, rowptrB, colB,
optional_valueB, K, "sum");
}
}
static auto registry =
torch::RegisterOperators().op("torch_sparse::spspmm_sum", &spspmm_sum);
...@@ -41,21 +41,28 @@ def test_spmm(dtype, device, reduce): ...@@ -41,21 +41,28 @@ def test_spmm(dtype, device, reduce):
out.backward(grad_out) out.backward(grad_out)
assert torch.allclose(expected, out) assert torch.allclose(expected, out)
assert torch.allclose(expected_grad_value, value.grad) assert torch.allclose(expected_grad_value, value.grad, atol=1e-6)
assert torch.allclose(expected_grad_other, other.grad) assert torch.allclose(expected_grad_other, other.grad, atol=1e-6)
# @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices)) @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
# def test_spspmm(dtype, device): def test_spspmm(dtype, device):
# src = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=dtype, src = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=dtype,
# device=device) device=device)
# src = SparseTensor.from_dense(src) src = SparseTensor.from_dense(src)
# out = src @ src out = matmul(src, src)
# assert out.size() == (3, 3) assert out.sizes() == [3, 3]
# assert out.has_value() assert out.has_value()
rowptr, col, value = out.csr()
assert rowptr.tolist() == [0, 1, 2, 3]
assert col.tolist() == [0, 1, 2]
assert value.tolist() == [1, 1, 1]
# src.set_value_(None) src.set_value_(None)
# out = src @ src out = matmul(src, src)
# assert out.size() == (3, 3) assert out.sizes() == [3, 3]
# assert not out.has_value() assert not out.has_value()
rowptr, col, value = out.csr()
assert rowptr.tolist() == [0, 1, 2, 3]
assert col.tolist() == [0, 1, 2]
...@@ -36,6 +36,17 @@ except OSError: ...@@ -36,6 +36,17 @@ except OSError:
raise ImportError raise ImportError
return mat, mat return mat, mat
torch.ops.torch_sparse.spmm_sum = spmm_sum_placeholder
torch.ops.torch_sparse.spmm_mean = spmm_mean_placeholder
torch.ops.torch_sparse.spmm_min = spmm_min_max_placeholder
torch.ops.torch_sparse.spmm_max = spmm_min_max_placeholder
try:
torch.ops.load_library(
osp.join(osp.dirname(osp.abspath(__file__)), '_spspmm.so'))
except OSError:
warnings.warn('Failed to load `spspmm` binaries.')
def spspmm_sum_placeholder( def spspmm_sum_placeholder(
rowptrA: torch.Tensor, colA: torch.Tensor, rowptrA: torch.Tensor, colA: torch.Tensor,
valueA: Optional[torch.Tensor], rowptrB: torch.Tensor, valueA: Optional[torch.Tensor], rowptrB: torch.Tensor,
...@@ -44,10 +55,6 @@ except OSError: ...@@ -44,10 +55,6 @@ except OSError:
raise ImportError raise ImportError
return rowptrA, colA, valueA return rowptrA, colA, valueA
torch.ops.torch_sparse.spmm_sum = spmm_sum_placeholder
torch.ops.torch_sparse.spmm_mean = spmm_mean_placeholder
torch.ops.torch_sparse.spmm_min = spmm_min_max_placeholder
torch.ops.torch_sparse.spmm_max = spmm_min_max_placeholder
torch.ops.torch_sparse.spspmm_sum = spspmm_sum_placeholder torch.ops.torch_sparse.spspmm_sum = spspmm_sum_placeholder
...@@ -129,6 +136,7 @@ def spmm(src: SparseTensor, other: torch.Tensor, ...@@ -129,6 +136,7 @@ def spmm(src: SparseTensor, other: torch.Tensor,
@torch.jit.script @torch.jit.script
def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor: def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor:
assert src.sparse_size(1) == other.sparse_size(0)
rowptrA, colA, valueA = src.csr() rowptrA, colA, valueA = src.csr()
rowptrB, colB, valueB = other.csr() rowptrB, colB, valueB = other.csr()
M, K = src.sparse_size(0), other.sparse_size(1) M, K = src.sparse_size(0), other.sparse_size(1)
......
...@@ -311,34 +311,44 @@ class SparseTensor(object): ...@@ -311,34 +311,44 @@ class SparseTensor(object):
return torch.is_floating_point(self.options()) return torch.is_floating_point(self.options())
def bfloat16(self): def bfloat16(self):
return self.type_as(torch.tensor(0, dtype=torch.bfloat16)) return self.type_as(
torch.tensor(0, dtype=torch.bfloat16, device=self.device()))
def bool(self): def bool(self):
return self.type_as(torch.tensor(0, dtype=torch.bool)) return self.type_as(
torch.tensor(0, dtype=torch.bool, device=self.device()))
def byte(self): def byte(self):
return self.type_as(torch.tensor(0, dtype=torch.uint8)) return self.type_as(
torch.tensor(0, dtype=torch.uint8, device=self.device()))
def char(self): def char(self):
return self.type_as(torch.tensor(0, dtype=torch.int8)) return self.type_as(
torch.tensor(0, dtype=torch.int8, device=self.device()))
def half(self): def half(self):
return self.type_as(torch.tensor(0, dtype=torch.half)) return self.type_as(
torch.tensor(0, dtype=torch.half, device=self.device()))
def float(self): def float(self):
return self.type_as(torch.tensor(0, dtype=torch.float)) return self.type_as(
torch.tensor(0, dtype=torch.float, device=self.device()))
def double(self): def double(self):
return self.type_as(torch.tensor(0, dtype=torch.double)) return self.type_as(
torch.tensor(0, dtype=torch.double, device=self.device()))
def short(self): def short(self):
return self.type_as(torch.tensor(0, dtype=torch.short)) return self.type_as(
torch.tensor(0, dtype=torch.short, device=self.device()))
def int(self): def int(self):
return self.type_as(torch.tensor(0, dtype=torch.int)) return self.type_as(
torch.tensor(0, dtype=torch.int, device=self.device()))
def long(self): def long(self):
return self.type_as(torch.tensor(0, dtype=torch.long)) return self.type_as(
torch.tensor(0, dtype=torch.long, device=self.device()))
# Conversions ############################################################# # Conversions #############################################################
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment