Unverified commit 7671fcb0, authored by Matthias Fey, committed by GitHub

Merge pull request #33 from rusty1s/adj

[WIP] SparseTensor Format
parents 1fb5fa4f 704ad420
#include "diag_cuda.h"
#include <ATen/cuda/CUDAContext.h>
#include "utils.cuh"
#define THREADS 1024
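// non_diag_mask: given the COO indices of a sparse (M, N) matrix and a diagonal
// offset k, compute a boolean mask of length nnz + num_diag that marks the slots
// the existing non-diagonal entries occupy once the k-th diagonal is inserted in
// row-major order; the remaining false slots are left for the diagonal entries.
// Example (M = N = 3, k = 0): row = [0, 0, 1], col = [1, 2, 0] yields
// [false, true, true, true, false, false], reserving slots 0, 4 and 5 for the
// diagonal entries (0, 0), (1, 1) and (2, 2).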
__global__ void non_diag_mask_kernel(const int64_t *row_data,
const int64_t *col_data, bool *out_data,
int64_t N, int64_t k, int64_t num_diag,
int64_t numel) {
int64_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
if (thread_idx < numel) {
int64_t r = row_data[thread_idx], c = col_data[thread_idx];
if (k < 0) {
if (r + k < 0) {
out_data[thread_idx] = true;
} else if (r + k >= N) {
out_data[thread_idx + num_diag] = true;
} else if (r + k > c) {
out_data[thread_idx + r + k] = true;
} else if (r + k < c) {
out_data[thread_idx + r + k + 1] = true;
}
} else {
if (r + k >= N) {
out_data[thread_idx + num_diag] = true;
} else if (r + k > c) {
out_data[thread_idx + r] = true;
} else if (r + k < c) {
out_data[thread_idx + r + 1] = true;
}
}
}
}
torch::Tensor non_diag_mask_cuda(torch::Tensor row, torch::Tensor col,
int64_t M, int64_t N, int64_t k) {
CHECK_CUDA(row);
CHECK_CUDA(col);
cudaSetDevice(row.get_device());
auto E = row.size(0);
auto num_diag = k < 0 ? std::min(M + k, N) : std::min(M, N - k);
auto row_data = row.data_ptr<int64_t>();
auto col_data = col.data_ptr<int64_t>();
auto mask = torch::zeros(E + num_diag, row.options().dtype(torch::kBool));
auto mask_data = mask.data_ptr<bool>();
auto stream = at::cuda::getCurrentCUDAStream();
non_diag_mask_kernel<<<(E + THREADS - 1) / THREADS, THREADS, 0, stream>>>(
row_data, col_data, mask_data, N, k, num_diag, E);
return mask;
}
#pragma once
#include <torch/extension.h>
torch::Tensor non_diag_mask_cuda(torch::Tensor row, torch::Tensor col,
int64_t M, int64_t N, int64_t k);
#pragma once
#include <limits>
#include <map>
enum ReductionType { SUM, MEAN, MUL, DIV, MIN, MAX };
const std::map<std::string, ReductionType> reduce2REDUCE = {
{"sum", SUM}, {"mean", MEAN}, {"mul", MUL},
{"div", DIV}, {"min", MIN}, {"max", MAX},
};
#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...) \
[&] { \
switch (reduce2REDUCE.at(reduce)) { \
case SUM: { \
const ReductionType REDUCE = SUM; \
return __VA_ARGS__(); \
} \
case MEAN: { \
const ReductionType REDUCE = MEAN; \
return __VA_ARGS__(); \
} \
case MUL: { \
const ReductionType REDUCE = MUL; \
return __VA_ARGS__(); \
} \
case DIV: { \
const ReductionType REDUCE = DIV; \
return __VA_ARGS__(); \
} \
case MIN: { \
const ReductionType REDUCE = MIN; \
return __VA_ARGS__(); \
} \
case MAX: { \
const ReductionType REDUCE = MAX; \
return __VA_ARGS__(); \
} \
} \
}()
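// AT_DISPATCH_REDUCTION_TYPES materializes the requested reduction as the
// compile-time constant REDUCE inside the lambda body. The Reducer struct below
// bundles the per-reduction behavior used by the SpMM kernels: init() returns
// the neutral element, update() folds in a new value (tracking the argmin/argmax
// index for MIN/MAX), and write() finalizes the result, dividing by the entry
// count for MEAN and writing 0 for empty rows under MIN/MAX so that the sentinel
// arg value stays untouched.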
template <typename scalar_t, ReductionType REDUCE> struct Reducer {
static inline __host__ __device__ scalar_t init() {
if (REDUCE == MUL || REDUCE == DIV)
return (scalar_t)1;
else if (REDUCE == MIN)
return std::numeric_limits<scalar_t>::max();
else if (REDUCE == MAX)
return std::numeric_limits<scalar_t>::lowest();
else
return (scalar_t)0;
}
static inline __host__ __device__ void update(scalar_t *val, scalar_t new_val,
int64_t *arg, int64_t new_arg) {
if (REDUCE == SUM || REDUCE == MEAN)
*val = *val + new_val;
else if (REDUCE == MUL)
*val = *val * new_val;
else if (REDUCE == DIV)
*val = *val / new_val;
else if ((REDUCE == MIN && new_val < *val) ||
(REDUCE == MAX && new_val > *val)) {
*val = new_val;
*arg = new_arg;
}
}
static inline __host__ __device__ void write(scalar_t *address, scalar_t val,
int64_t *arg_address,
int64_t arg, int count) {
if (REDUCE == SUM || REDUCE == MUL || REDUCE == DIV)
*address = val;
else if (REDUCE == MEAN)
*address = val / (count > 0 ? count : (scalar_t)1);
else if (REDUCE == MIN || REDUCE == MAX) {
if (count > 0) {
*address = val;
*arg_address = arg;
} else
*address = (scalar_t)0;
}
}
};
#include "spmm_cuda.h"
#include <ATen/cuda/CUDAContext.h>
#include "reducer.cuh"
#include "utils.cuh"
#define THREADS 256
#define FULL_MASK 0xffffffff
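// All 32 bits set: every lane of a warp takes part in the __shfl_sync calls below.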
// Paper: Design Principles for Sparse Matrix Multiplication on the GPU
// Code: https://github.com/owensgroup/merge-spmm
template <typename scalar_t, ReductionType REDUCE, bool HAS_VALUE>
__global__ void spmm_kernel(const int64_t *rowptr_data, const int64_t *col_data,
const scalar_t *value_data,
const scalar_t *mat_data, scalar_t *out_data,
int64_t *arg_out_data, int B, int M, int N, int K) {
// We ignore blockIdx.y here, because threads
// across `blockIdx.y` are treated equally.
int thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
int row = thread_idx >> 5; // thread_idx / 32
int lane_idx = thread_idx & (32 - 1); // thread_idx % 32
int batch_idx = row / M;
// Compute the column index of `mat` in which the thread is operating.
int mat_col_idx = lane_idx + (blockIdx.y << 5);
// Compute the output index (row-major order).
int out_idx = row * K + mat_col_idx;
// Helper arrays for warp communication.
int mat_row, mat_rows[32];
scalar_t val, vals[HAS_VALUE ? 32 : 1];
// Do not aggregate/write across the Y-axis (lane_idx < leftover).
int leftover = K - (blockIdx.y << 5);
if (batch_idx < B) {
int row_start = __ldg(rowptr_data + (row % M));
int row_end = __ldg(rowptr_data + (row % M) + 1);
int col_idx = row_start + lane_idx;
scalar_t result = Reducer<scalar_t, REDUCE>::init();
int64_t arg;
// Iterate over all `col` indices in parallel within a warp.
for (int c = row_start; c < row_end; c += 32) {
if (col_idx < row_end) {
// Coalesced memory access into `col` and `val`.
mat_row = __ldg(col_data + col_idx) * K;
if (HAS_VALUE)
val = __ldg(value_data + col_idx);
} else {
mat_row = -1;
if (HAS_VALUE)
val = (scalar_t)0;
}
col_idx += 32;
#pragma unroll
for (int i = 0; i < 32; i++) {
// Communication between all threads in a warp.
mat_rows[i] = __shfl_sync(FULL_MASK, mat_row, i);
if (HAS_VALUE)
vals[i] = __shfl_sync(FULL_MASK, val, i);
}
#pragma unroll
for (int i = 0; i < 32; i++) {
if (lane_idx < leftover && mat_rows[i] != -1) {
// Coalesced memory access into `mat`.
val = __ldg(mat_data + batch_idx * N * K + mat_rows[i] + mat_col_idx);
if (HAS_VALUE)
val = vals[i] * val;
Reducer<scalar_t, REDUCE>::update(&result, val, &arg, c + i);
}
}
}
if (lane_idx < leftover) {
// Coalesced write into `out`.
Reducer<scalar_t, REDUCE>::write(out_data + out_idx, result,
arg_out_data + out_idx, arg,
row_end - row_start);
}
}
}
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
spmm_cuda(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value, torch::Tensor mat,
std::string reduce) {
CHECK_CUDA(rowptr);
CHECK_CUDA(col);
if (optional_value.has_value())
CHECK_CUDA(optional_value.value());
CHECK_CUDA(mat);
cudaSetDevice(rowptr.get_device());
CHECK_INPUT(rowptr.dim() == 1);
CHECK_INPUT(col.dim() == 1);
if (optional_value.has_value()) {
CHECK_INPUT(optional_value.value().dim() == 1);
CHECK_INPUT(optional_value.value().size(0) == col.size(0));
}
CHECK_INPUT(mat.dim() >= 2);
mat = mat.contiguous();
auto sizes = mat.sizes().vec();
sizes[mat.dim() - 2] = rowptr.numel() - 1;
auto out = torch::empty(sizes, mat.options());
torch::optional<torch::Tensor> arg_out = torch::nullopt;
int64_t *arg_out_data = nullptr;
if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
arg_out = torch::full_like(out, col.numel(), rowptr.options());
arg_out_data = arg_out.value().data_ptr<int64_t>();
}
auto rowptr_data = rowptr.data_ptr<int64_t>();
auto col_data = col.data_ptr<int64_t>();
auto M = rowptr.numel() - 1;
auto N = mat.size(-2);
auto K = mat.size(-1);
auto B = mat.numel() / (N * K);
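// Launch geometry: each warp (32 threads) produces one output row, its lanes
// covering 32 adjacent columns of `mat`; blockIdx.y tiles the K columns in
// chunks of 32.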
auto BLOCKS = dim3((32 * B * M + THREADS - 1) / THREADS, (K + 31) / 32);
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm_kernel", [&] {
auto mat_data = mat.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
if (optional_value.has_value()) {
auto value_data = optional_value.value().data_ptr<scalar_t>();
spmm_kernel<scalar_t, REDUCE, true><<<BLOCKS, THREADS, 0, stream>>>(
rowptr_data, col_data, value_data, mat_data, out_data, arg_out_data,
B, M, N, K);
} else {
spmm_kernel<scalar_t, REDUCE, false><<<BLOCKS, THREADS, 0, stream>>>(
rowptr_data, col_data, nullptr, mat_data, out_data, arg_out_data, B,
M, N, K);
}
});
});
return std::make_tuple(out, arg_out);
}
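// Gradient with respect to the non-zero values: for every edge e = (row, col), a
// warp accumulates the dot product between row `row` of `grad` and row `col` of
// `mat` (summed over all batches), reduces it across the warp, and divides by
// the row's entry count when the forward reduction was "mean".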
template <typename scalar_t, ReductionType REDUCE>
__global__ void
spmm_value_bw_kernel(const int64_t *row_data, const int64_t *rowptr_data,
const int64_t *col_data, const scalar_t *mat_data,
const scalar_t *grad_data, scalar_t *out_data, int B,
int M, int N, int E, int K) {
int thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
int index_idx = (thread_idx >> 5); // thread_idx / 32
int lane_idx = thread_idx & (32 - 1); // thread_idx % 32
if (index_idx < E) {
int row = __ldg(row_data + index_idx);
int col = __ldg(col_data + index_idx);
scalar_t val = (scalar_t)0;
for (int b = 0; b < B; b++) {
for (int k = lane_idx; k < K; k += 32) {
val += mat_data[b * N * K + col * K + k] *
grad_data[b * M * K + row * K + k];
}
}
#pragma unroll
for (int i = 32 / 2; i > 0; i /= 2) { // Parallel reduction inside a warp.
val += __shfl_down_sync(FULL_MASK, val, i);
}
if (lane_idx == 0) {
if (REDUCE == MEAN) {
int row_start = __ldg(rowptr_data + row);
int row_end = __ldg(rowptr_data + row + 1);
val /= (scalar_t)max(row_end - row_start, 1);
}
out_data[index_idx] = val;
}
}
}
torch::Tensor spmm_value_bw_cuda(torch::Tensor row, torch::Tensor rowptr,
torch::Tensor col, torch::Tensor mat,
torch::Tensor grad, std::string reduce) {
CHECK_CUDA(row);
CHECK_CUDA(rowptr);
CHECK_CUDA(col);
CHECK_CUDA(mat);
CHECK_CUDA(grad);
cudaSetDevice(row.get_device());
mat = mat.contiguous();
grad = grad.contiguous();
auto M = grad.size(-2);
auto N = mat.size(-2);
auto E = row.numel();
auto K = mat.size(-1);
auto B = mat.numel() / (N * K);
auto BLOCKS = dim3((E * 32 + THREADS - 1) / THREADS);
auto out = torch::zeros(row.numel(), grad.options());
auto row_data = row.data_ptr<int64_t>();
auto rowptr_data = rowptr.data_ptr<int64_t>();
auto col_data = col.data_ptr<int64_t>();
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm_val_bw_kernel", [&] {
auto mat_data = mat.data_ptr<scalar_t>();
auto grad_data = grad.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
spmm_value_bw_kernel<scalar_t, REDUCE><<<BLOCKS, THREADS, 0, stream>>>(
row_data, rowptr_data, col_data, mat_data, grad_data, out_data, B, M,
N, E, K);
});
});
return out;
}
#pragma once
#include <torch/extension.h>
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
spmm_cuda(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value, torch::Tensor mat,
std::string reduce);
torch::Tensor spmm_value_bw_cuda(torch::Tensor row, torch::Tensor rowptr,
torch::Tensor col, torch::Tensor mat,
torch::Tensor grad, std::string reduce);
#include "spspmm_cuda.h"
#include <ATen/cuda/CUDAContext.h>
#include <cusparse.h>
#include "utils.cuh"
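// Dispatches on the value type and binds `scalar_t` together with the matching
// single- or double-precision cuSPARSE csrgemm2 routines.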
#define AT_DISPATCH_CUSPARSE_TYPES(TYPE, ...) \
[&] { \
switch (TYPE) { \
case torch::ScalarType::Float: { \
using scalar_t = float; \
const auto &cusparsecsrgemm2_bufferSizeExt = \
cusparseScsrgemm2_bufferSizeExt; \
const auto &cusparsecsrgemm2 = cusparseScsrgemm2; \
return __VA_ARGS__(); \
} \
case torch::ScalarType::Double: { \
using scalar_t = double; \
const auto &cusparsecsrgemm2_bufferSizeExt = \
cusparseDcsrgemm2_bufferSizeExt; \
const auto &cusparsecsrgemm2 = cusparseDcsrgemm2; \
return __VA_ARGS__(); \
} \
default: \
AT_ERROR("Not implemented for '", toString(TYPE), "'"); \
} \
}()
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
spspmm_cuda(torch::Tensor rowptrA, torch::Tensor colA,
torch::optional<torch::Tensor> optional_valueA,
torch::Tensor rowptrB, torch::Tensor colB,
torch::optional<torch::Tensor> optional_valueB, int64_t K,
std::string reduce) {
CHECK_CUDA(rowptrA);
CHECK_CUDA(colA);
if (optional_valueA.has_value())
CHECK_CUDA(optional_valueA.value());
CHECK_CUDA(rowptrB);
CHECK_CUDA(colB);
if (optional_valueB.has_value())
CHECK_CUDA(optional_valueB.value());
cudaSetDevice(rowptrA.get_device());
CHECK_INPUT(rowptrA.dim() == 1);
CHECK_INPUT(colA.dim() == 1);
if (optional_valueA.has_value()) {
CHECK_INPUT(optional_valueA.value().dim() == 1);
CHECK_INPUT(optional_valueA.value().size(0) == colA.size(0));
}
CHECK_INPUT(rowptrB.dim() == 1);
CHECK_INPUT(colB.dim() == 1);
if (optional_valueB.has_value()) {
CHECK_INPUT(optional_valueB.value().dim() == 1);
CHECK_INPUT(optional_valueB.value().size(0) == colB.size(0));
}
if (!optional_valueA.has_value() && optional_valueB.has_value())
optional_valueA =
torch::ones(colA.numel(), optional_valueB.value().options());
if (!optional_valueB.has_value() && optional_valueA.has_value())
optional_valueB =
torch::ones(colB.numel(), optional_valueA.value().options());
auto scalar_type = torch::ScalarType::Float;
if (optional_valueA.has_value())
scalar_type = optional_valueA.value().scalar_type();
auto handle = at::cuda::getCurrentCUDASparseHandle();
cusparseMatDescr_t descr;
cusparseCreateMatDescr(&descr);
rowptrA = rowptrA.toType(torch::kInt);
colA = colA.toType(torch::kInt);
rowptrB = rowptrB.toType(torch::kInt);
colB = colB.toType(torch::kInt);
int64_t M = rowptrA.numel() - 1, N = rowptrB.numel() - 1;
auto rowptrA_data = rowptrA.data_ptr<int>();
auto colA_data = colA.data_ptr<int>();
auto rowptrB_data = rowptrB.data_ptr<int>();
auto colB_data = colB.data_ptr<int>();
torch::Tensor rowptrC, colC;
torch::optional<torch::Tensor> optional_valueC = torch::nullopt;
int nnzC;
int *nnzTotalDevHostPtr = &nnzC;
// Step 1: Create an opaque structure.
csrgemm2Info_t info = NULL;
cusparseCreateCsrgemm2Info(&info);
// Step 2: Allocate buffer for `csrgemm2Nnz` and `csrgemm2`.
size_t bufferSize;
AT_DISPATCH_CUSPARSE_TYPES(scalar_type, [&] {
scalar_t alpha = (scalar_t)1.0;
cusparsecsrgemm2_bufferSizeExt(handle, M, N, K, &alpha, descr, colA.numel(),
rowptrA_data, colA_data, descr, colB.numel(),
rowptrB_data, colB_data, NULL, descr, 0,
NULL, NULL, info, &bufferSize);
void *buffer = NULL;
cudaMalloc(&buffer, bufferSize);
// Step 3: Compute CSR row pointer.
rowptrC = torch::empty(M + 1, rowptrA.options());
auto rowptrC_data = rowptrC.data_ptr<int>();
cusparseXcsrgemm2Nnz(handle, M, N, K, descr, colA.numel(), rowptrA_data,
colA_data, descr, colB.numel(), rowptrB_data,
colB_data, descr, 0, NULL, NULL, descr, rowptrC_data,
nnzTotalDevHostPtr, info, buffer);
// Step 4: Compute CSR entries.
colC = torch::empty(nnzC, rowptrC.options());
auto colC_data = colC.data_ptr<int>();
if (optional_valueA.has_value())
optional_valueC = torch::empty(nnzC, optional_valueA.value().options());
scalar_t *valA_data = NULL, *valB_data = NULL, *valC_data = NULL;
if (optional_valueA.has_value()) {
valA_data = optional_valueA.value().data_ptr<scalar_t>();
valB_data = optional_valueB.value().data_ptr<scalar_t>();
valC_data = optional_valueC.value().data_ptr<scalar_t>();
}
cusparsecsrgemm2(handle, M, N, K, &alpha, descr, colA.numel(), valA_data,
rowptrA_data, colA_data, descr, colB.numel(), valB_data,
rowptrB_data, colB_data, NULL, descr, 0, NULL, NULL, NULL,
descr, valC_data, rowptrC_data, colC_data, info, buffer);
cudaFree(buffer);
});
// Step 5: Destroy the opaque structure.
cusparseDestroyCsrgemm2Info(info);
rowptrC = rowptrC.toType(torch::kLong);
colC = colC.toType(torch::kLong);
return std::make_tuple(rowptrC, colC, optional_valueC);
}
#pragma once
#include <torch/extension.h>
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
spspmm_cuda(torch::Tensor rowptrA, torch::Tensor colA,
torch::optional<torch::Tensor> optional_valueA,
torch::Tensor rowptrB, torch::Tensor colB,
torch::optional<torch::Tensor> optional_valueB, int64_t K,
std::string reduce);
#pragma once
#include <torch/extension.h>
#define CHECK_CUDA(x) \
AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
#include <torch/script.h>
#include "cpu/diag_cpu.h"
#ifdef WITH_CUDA
#include "cuda/diag_cuda.h"
#endif
torch::Tensor non_diag_mask(torch::Tensor row, torch::Tensor col, int64_t M,
int64_t N, int64_t k) {
if (row.device().is_cuda()) {
#ifdef WITH_CUDA
return non_diag_mask_cuda(row, col, M, N, k);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return non_diag_mask_cpu(row, col, M, N, k);
}
}
static auto registry = torch::RegisterOperators().op(
"torch_sparse::non_diag_mask", &non_diag_mask);
#include <torch/script.h>
#include "cpu/spmm_cpu.h"
#ifdef WITH_CUDA
#include "cuda/spmm_cuda.h"
#endif
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
spmm_fw(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value, torch::Tensor mat,
std::string reduce) {
if (rowptr.device().is_cuda()) {
#ifdef WITH_CUDA
return spmm_cuda(rowptr, col, optional_value, mat, reduce);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return spmm_cpu(rowptr, col, optional_value, mat, reduce);
}
}
torch::Tensor spmm_value_bw(torch::Tensor row, torch::Tensor rowptr,
torch::Tensor col, torch::Tensor mat,
torch::Tensor grad, std::string reduce) {
if (row.device().is_cuda()) {
#ifdef WITH_CUDA
return spmm_value_bw_cuda(row, rowptr, col, mat, grad, reduce);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return spmm_value_bw_cpu(row, rowptr, col, mat, grad, reduce);
}
}
using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;
class SPMMSum : public torch::autograd::Function<SPMMSum> {
public:
static variable_list forward(AutogradContext *ctx,
torch::optional<Variable> opt_row,
Variable rowptr, Variable col, Variable value,
torch::optional<Variable> opt_colptr,
torch::optional<Variable> opt_csr2csc,
Variable mat, bool has_value) {
if (has_value && torch::autograd::any_variable_requires_grad({value})) {
AT_ASSERTM(opt_row.has_value(), "Argument `row` is missing");
}
if (torch::autograd::any_variable_requires_grad({mat})) {
AT_ASSERTM(opt_row.has_value(), "Argument `row` is missing");
AT_ASSERTM(opt_colptr.has_value(), "Argument `colptr` is missing");
AT_ASSERTM(opt_csr2csc.has_value(), "Argument `csr2csc` is missing");
}
auto row = opt_row.has_value() ? opt_row.value() : col;
auto colptr = opt_colptr.has_value() ? opt_colptr.value() : col;
auto csr2csc = opt_csr2csc.has_value() ? opt_csr2csc.value() : col;
torch::optional<torch::Tensor> opt_value = torch::nullopt;
if (has_value)
opt_value = value;
auto out = std::get<0>(spmm_fw(rowptr, col, opt_value, mat, "sum"));
ctx->saved_data["has_value"] = has_value;
ctx->save_for_backward({row, rowptr, col, value, colptr, csr2csc, mat});
return {out};
}
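// Backward of sum-aggregation: the gradient w.r.t. `value` is the per-edge dot
// product of the matching rows of `grad_out` and `mat` (spmm_value_bw), and the
// gradient w.r.t. `mat` is A^T * grad_out, computed as another SpMM over the
// transposed (CSC) layout via `colptr` and the `csr2csc` permutation.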
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto has_value = ctx->saved_data["has_value"].toBool();
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto row = saved[0], rowptr = saved[1], col = saved[2], value = saved[3],
colptr = saved[4], csr2csc = saved[5], mat = saved[6];
auto grad_value = Variable();
if (has_value > 0 && torch::autograd::any_variable_requires_grad({value})) {
grad_value = spmm_value_bw(row, rowptr, col, mat, grad_out, "sum");
}
auto grad_mat = Variable();
if (torch::autograd::any_variable_requires_grad({mat})) {
torch::optional<torch::Tensor> opt_value = torch::nullopt;
if (has_value)
opt_value = value.index_select(0, csr2csc);
grad_mat = std::get<0>(spmm_fw(colptr, row.index_select(0, csr2csc),
opt_value, grad_out, "sum"));
}
return {Variable(), Variable(), Variable(), grad_value,
Variable(), Variable(), grad_mat, Variable()};
}
};
class SPMMMean : public torch::autograd::Function<SPMMMean> {
public:
static variable_list forward(AutogradContext *ctx,
torch::optional<Variable> opt_row,
Variable rowptr, Variable col, Variable value,
torch::optional<Variable> opt_rowcount,
torch::optional<Variable> opt_colptr,
torch::optional<Variable> opt_csr2csc,
Variable mat, bool has_value) {
if (has_value && torch::autograd::any_variable_requires_grad({value})) {
AT_ASSERTM(opt_row.has_value(), "Argument `row` is missing");
}
if (torch::autograd::any_variable_requires_grad({mat})) {
AT_ASSERTM(opt_row.has_value(), "Argument `row` is missing");
AT_ASSERTM(opt_rowcount.has_value(), "Argument `rowcount` is missing");
AT_ASSERTM(opt_colptr.has_value(), "Argument `colptr` is missing");
AT_ASSERTM(opt_csr2csc.has_value(), "Argument `csr2csc` is missing");
}
auto row = opt_row.has_value() ? opt_row.value() : col;
auto rowcount = opt_rowcount.has_value() ? opt_rowcount.value() : col;
auto colptr = opt_colptr.has_value() ? opt_colptr.value() : col;
auto csr2csc = opt_csr2csc.has_value() ? opt_csr2csc.value() : col;
torch::optional<torch::Tensor> opt_value = torch::nullopt;
if (has_value)
opt_value = value;
auto out = std::get<0>(spmm_fw(rowptr, col, opt_value, mat, "mean"));
ctx->saved_data["has_value"] = has_value;
ctx->save_for_backward(
{row, rowptr, col, value, rowcount, colptr, csr2csc, mat});
return {out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto has_value = ctx->saved_data["has_value"].toBool();
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto row = saved[0], rowptr = saved[1], col = saved[2], value = saved[3],
rowcount = saved[4], colptr = saved[5], csr2csc = saved[6],
mat = saved[7];
auto grad_value = Variable();
if (has_value > 0 && torch::autograd::any_variable_requires_grad({value})) {
grad_value = spmm_value_bw(row, rowptr, col, mat, grad_out, "mean");
}
auto grad_mat = Variable();
if (torch::autograd::any_variable_requires_grad({mat})) {
row = row.index_select(0, csr2csc);
rowcount = rowcount.toType(mat.scalar_type()).index_select(0, row);
rowcount.clamp_(1);
if (has_value > 0)
rowcount = value.index_select(0, csr2csc).div(rowcount);
else
rowcount.pow_(-1);
grad_mat = std::get<0>(spmm_fw(colptr, row, rowcount, grad_out, "sum"));
}
return {Variable(), Variable(), Variable(), grad_value, Variable(),
Variable(), Variable(), grad_mat, Variable()};
}
};
class SPMMMin : public torch::autograd::Function<SPMMMin> {
public:
static variable_list forward(AutogradContext *ctx, Variable rowptr,
Variable col, Variable value, Variable mat,
bool has_value) {
torch::optional<torch::Tensor> opt_value = torch::nullopt;
if (has_value)
opt_value = value;
auto result = spmm_fw(rowptr, col, opt_value, mat, "min");
auto out = std::get<0>(result);
auto arg_out = std::get<1>(result).value();
ctx->saved_data["has_value"] = has_value;
ctx->save_for_backward({col, value, mat, arg_out});
ctx->mark_non_differentiable({arg_out});
return {out, arg_out};
}
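// Backward of min-aggregation: gradients flow only through the argmin entries
// recorded in `arg_out`; positions equal to `col.size(0)` mark empty rows and
// are masked out before scattering into `grad_value` and `grad_mat`. The max
// variant below is identical except for the reduction.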
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto has_value = ctx->saved_data["has_value"].toBool();
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto col = saved[0], value = saved[1], mat = saved[2], arg_out = saved[3];
auto invalid_arg_mask = arg_out == col.size(0);
arg_out = arg_out.masked_fill(invalid_arg_mask, 0);
auto grad_value = Variable();
if (has_value > 0 && torch::autograd::any_variable_requires_grad({value})) {
auto ind = col.index_select(0, arg_out.flatten()).view_as(arg_out);
auto out = mat.gather(-2, ind);
out.mul_(grad_out);
out.masked_fill_(invalid_arg_mask, 0);
grad_value = torch::zeros_like(value);
grad_value.scatter_add_(0, arg_out.flatten(), out.flatten());
}
auto grad_mat = Variable();
if (torch::autograd::any_variable_requires_grad({mat})) {
if (has_value > 0) {
value = value.index_select(0, arg_out.flatten()).view_as(arg_out);
value.mul_(grad_out);
} else
value = grad_out;
value.masked_fill_(invalid_arg_mask, 0);
auto ind = col.index_select(0, arg_out.flatten()).view_as(arg_out);
grad_mat = torch::zeros_like(mat);
grad_mat.scatter_add_(-2, ind, value);
}
return {Variable(), Variable(), grad_value, grad_mat, Variable()};
}
};
class SPMMMax : public torch::autograd::Function<SPMMMax> {
public:
static variable_list forward(AutogradContext *ctx, Variable rowptr,
Variable col, Variable value, Variable mat,
bool has_value) {
torch::optional<torch::Tensor> opt_value = torch::nullopt;
if (has_value)
opt_value = value;
auto result = spmm_fw(rowptr, col, opt_value, mat, "max");
auto out = std::get<0>(result);
auto arg_out = std::get<1>(result).value();
ctx->saved_data["has_value"] = has_value;
ctx->save_for_backward({col, value, mat, arg_out});
ctx->mark_non_differentiable({arg_out});
return {out, arg_out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto has_value = ctx->saved_data["has_value"].toBool();
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto col = saved[0], value = saved[1], mat = saved[2], arg_out = saved[3];
auto invalid_arg_mask = arg_out == col.size(0);
arg_out = arg_out.masked_fill(invalid_arg_mask, 0);
auto grad_value = Variable();
if (has_value > 0 && torch::autograd::any_variable_requires_grad({value})) {
auto ind = col.index_select(0, arg_out.flatten()).view_as(arg_out);
auto out = mat.gather(-2, ind);
out.mul_(grad_out);
out.masked_fill_(invalid_arg_mask, 0);
grad_value = torch::zeros_like(value);
grad_value.scatter_add_(0, arg_out.flatten(), out.flatten());
}
auto grad_mat = Variable();
if (torch::autograd::any_variable_requires_grad({mat})) {
if (has_value > 0) {
value = value.index_select(0, arg_out.flatten()).view_as(arg_out);
value.mul_(grad_out);
} else
value = grad_out;
value.masked_fill_(invalid_arg_mask, 0);
auto ind = col.index_select(0, arg_out.flatten()).view_as(arg_out);
grad_mat = torch::zeros_like(mat);
grad_mat.scatter_add_(-2, ind, value);
}
return {Variable(), Variable(), grad_value, grad_mat, Variable()};
}
};
torch::Tensor spmm_sum(torch::optional<torch::Tensor> opt_row,
torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> opt_value,
torch::optional<torch::Tensor> opt_colptr,
torch::optional<torch::Tensor> opt_csr2csc,
torch::Tensor mat) {
auto value = opt_value.has_value() ? opt_value.value() : col;
return SPMMSum::apply(opt_row, rowptr, col, value, opt_colptr, opt_csr2csc,
mat, opt_value.has_value())[0];
}
torch::Tensor spmm_mean(torch::optional<torch::Tensor> opt_row,
torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> opt_value,
torch::optional<torch::Tensor> opt_rowcount,
torch::optional<torch::Tensor> opt_colptr,
torch::optional<torch::Tensor> opt_csr2csc,
torch::Tensor mat) {
auto value = opt_value.has_value() ? opt_value.value() : col;
return SPMMMean::apply(opt_row, rowptr, col, value, opt_rowcount, opt_colptr,
opt_csr2csc, mat, opt_value.has_value())[0];
}
std::tuple<torch::Tensor, torch::Tensor>
spmm_min(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> opt_value, torch::Tensor mat) {
auto value = opt_value.has_value() ? opt_value.value() : col;
auto result = SPMMMin::apply(rowptr, col, value, mat, opt_value.has_value());
return std::make_tuple(result[0], result[1]);
}
std::tuple<torch::Tensor, torch::Tensor>
spmm_max(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> opt_value, torch::Tensor mat) {
auto value = opt_value.has_value() ? opt_value.value() : col;
auto result = SPMMMax::apply(rowptr, col, value, mat, opt_value.has_value());
return std::make_tuple(result[0], result[1]);
}
static auto registry = torch::RegisterOperators()
.op("torch_sparse::spmm_sum", &spmm_sum)
.op("torch_sparse::spmm_mean", &spmm_mean)
.op("torch_sparse::spmm_min", &spmm_min)
.op("torch_sparse::spmm_max", &spmm_max);
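For orientation, a minimal usage sketch of the operators registered above, assuming the extension has been compiled and loaded so that they are visible under torch.ops.torch_sparse; the tensors are illustrative only.

import torch

rowptr = torch.tensor([0, 2, 3])    # CSR row pointers of a 2 x 3 sparse matrix
col = torch.tensor([0, 2, 1])       # column indices
value = torch.tensor([1., 2., 3.])  # non-zero values
mat = torch.eye(3)                  # dense right-hand side
# The optional row/colptr/csr2csc arguments are only needed for the backward pass.
out = torch.ops.torch_sparse.spmm_sum(None, rowptr, col, value, None, None, mat)
# out should equal [[1., 0., 2.], [0., 3., 0.]]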
#include <torch/script.h>
#include "cpu/spspmm_cpu.h"
#ifdef WITH_CUDA
#include "cuda/spspmm_cuda.h"
#endif
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
spspmm_sum(torch::Tensor rowptrA, torch::Tensor colA,
torch::optional<torch::Tensor> optional_valueA,
torch::Tensor rowptrB, torch::Tensor colB,
torch::optional<torch::Tensor> optional_valueB, int64_t K) {
if (rowptrA.device().is_cuda()) {
#ifdef WITH_CUDA
return spspmm_cuda(rowptrA, colA, optional_valueA, rowptrB, colB,
optional_valueB, K, "sum");
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return spspmm_cpu(rowptrA, colA, optional_valueA, rowptrB, colB,
optional_valueB, K, "sum");
}
}
static auto registry =
torch::RegisterOperators().op("torch_sparse::spspmm_sum", &spspmm_sum);
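Along the same lines, a hypothetical sketch of the sparse-sparse product registered above; shapes and values are made up for illustration.

import torch

rowptrA = torch.tensor([0, 1, 2])   # A: 2 x 2 identity in CSR form
colA = torch.tensor([0, 1])
valueA = torch.tensor([1., 1.])
rowptrB = torch.tensor([0, 2, 3])   # B: 2 x 3 sparse matrix
colB = torch.tensor([0, 2, 1])
valueB = torch.tensor([4., 5., 6.])
rowptrC, colC, valueC = torch.ops.torch_sparse.spspmm_sum(
    rowptrA, colA, valueA, rowptrB, colB, valueB, 3)
# Since A is the identity, C should reproduce B:
# rowptrC == [0, 2, 3], colC == [0, 2, 1], valueC == [4., 5., 6.]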
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
#include <torch/extension.h>
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")
std::tuple<at::Tensor, at::Tensor>
spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB,
at::Tensor valueB, size_t m, size_t k, size_t n);
at::Tensor spspmm_bw_cuda(at::Tensor index, at::Tensor indexA,
at::Tensor valueA, at::Tensor indexB,
at::Tensor valueB, size_t rowA_max, size_t rowB_max);
std::tuple<at::Tensor, at::Tensor> spspmm(at::Tensor indexA, at::Tensor valueA,
at::Tensor indexB, at::Tensor valueB,
size_t m, size_t k, size_t n) {
CHECK_CUDA(indexA);
CHECK_CUDA(valueA);
CHECK_CUDA(indexB);
CHECK_CUDA(valueB);
return spspmm_cuda(indexA, valueA, indexB, valueB, m, k, n);
}
at::Tensor spspmm_bw(at::Tensor index, at::Tensor indexA, at::Tensor valueA,
at::Tensor indexB, at::Tensor valueB, size_t rowA_max,
size_t rowB_max) {
CHECK_CUDA(index);
CHECK_CUDA(indexA);
CHECK_CUDA(valueA);
CHECK_CUDA(indexB);
CHECK_CUDA(valueB);
return spspmm_bw_cuda(index, indexA, valueA, indexB, valueB, rowA_max,
rowB_max);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("spspmm", &spspmm, "Sparse-Sparse Matrix Multiplication (CUDA)");
m.def("spspmm_bw", &spspmm_bw,
"Sparse-Sparse Matrix Multiplication Backward (CUDA)");
}
#include <ATen/ATen.h>
#include <cusparse.h>
#include "compat.cuh"
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
#define CSRGEMM(TYPE, ...) \
[&] { \
const auto &the_type = TYPE; \
(void)the_type; \
at::ScalarType _st = ::detail::scalar_type(TYPE); \
switch (_st) { \
case at::ScalarType::Float: { \
using scalar_t = float; \
return cusparseScsrgemm(__VA_ARGS__); \
} \
case at::ScalarType::Double: { \
using scalar_t = double; \
return cusparseDcsrgemm(__VA_ARGS__); \
} \
default: \
AT_ERROR("Not implemented for '", toString(_st), "'"); \
} \
}()
static cusparseHandle_t cusparse_handle = 0;
static void init_cusparse() {
if (cusparse_handle == 0) {
cusparseStatus_t status = cusparseCreate(&cusparse_handle);
}
}
std::tuple<at::Tensor, at::Tensor>
spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB,
at::Tensor valueB, size_t m, size_t k, size_t n) {
cudaSetDevice(indexA.get_device());
init_cusparse();
indexA = indexA.contiguous();
valueA = valueA.contiguous();
indexB = indexB.contiguous();
valueB = valueB.contiguous();
auto nnzA = valueA.size(0);
auto nnzB = valueB.size(0);
indexA = indexA.toType(at::kInt);
indexB = indexB.toType(at::kInt);
// Convert A to CSR format.
auto row_ptrA = at::empty(m + 1, indexA.options());
cusparseXcoo2csr(cusparse_handle, indexA[0].DATA_PTR<int>(), nnzA, m,
row_ptrA.DATA_PTR<int>(), CUSPARSE_INDEX_BASE_ZERO);
auto colA = indexA[1];
cudaMemcpy(row_ptrA.DATA_PTR<int>() + m, &nnzA, sizeof(int),
cudaMemcpyHostToDevice);
// Convert B to CSR format.
auto row_ptrB = at::empty(k + 1, indexB.options());
cusparseXcoo2csr(cusparse_handle, indexB[0].DATA_PTR<int>(), nnzB, k,
row_ptrB.DATA_PTR<int>(), CUSPARSE_INDEX_BASE_ZERO);
auto colB = indexB[1];
cudaMemcpy(row_ptrB.DATA_PTR<int>() + k, &nnzB, sizeof(int),
cudaMemcpyHostToDevice);
cusparseMatDescr_t descr = 0;
cusparseCreateMatDescr(&descr);
cusparseSetMatType(descr, CUSPARSE_MATRIX_TYPE_GENERAL);
cusparseSetMatIndexBase(descr, CUSPARSE_INDEX_BASE_ZERO);
int nnzC;
auto row_ptrC = at::empty(m + 1, indexB.options());
cusparseXcsrgemmNnz(cusparse_handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
CUSPARSE_OPERATION_NON_TRANSPOSE, m, n, k, descr, nnzA,
row_ptrA.DATA_PTR<int>(), colA.DATA_PTR<int>(), descr,
nnzB, row_ptrB.DATA_PTR<int>(), colB.DATA_PTR<int>(),
descr, row_ptrC.DATA_PTR<int>(), &nnzC);
auto colC = at::empty(nnzC, indexA.options());
auto valueC = at::empty(nnzC, valueA.options());
CSRGEMM(valueC.scalar_type(), cusparse_handle,
CUSPARSE_OPERATION_NON_TRANSPOSE, CUSPARSE_OPERATION_NON_TRANSPOSE, m,
n, k, descr, nnzA, valueA.DATA_PTR<scalar_t>(),
row_ptrA.DATA_PTR<int>(), colA.DATA_PTR<int>(), descr, nnzB,
valueB.DATA_PTR<scalar_t>(), row_ptrB.DATA_PTR<int>(),
colB.DATA_PTR<int>(), descr, valueC.DATA_PTR<scalar_t>(),
row_ptrC.DATA_PTR<int>(), colC.DATA_PTR<int>());
auto rowC = at::empty(nnzC, indexA.options());
cusparseXcsr2coo(cusparse_handle, row_ptrC.DATA_PTR<int>(), nnzC, m,
rowC.DATA_PTR<int>(), CUSPARSE_INDEX_BASE_ZERO);
auto indexC = at::stack({rowC, colC}, 0).toType(at::kLong);
return std::make_tuple(indexC, valueC);
}
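// degree() counts the number of entries per row, e.g. row = [0, 0, 1] with
// num_nodes = 3 gives [2, 1, 0]; to_csr() then builds the CSR row pointer via a
// cumulative sum with a leading zero, here [0, 2, 3, 3].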
at::Tensor degree(at::Tensor row, int64_t num_nodes) {
auto zero = at::zeros(num_nodes, row.options());
auto one = at::ones(row.size(0), row.options());
return zero.scatter_add_(0, row, one);
}
std::tuple<at::Tensor, at::Tensor> to_csr(at::Tensor row, at::Tensor col,
int64_t num_nodes) {
// Assert already coalesced input.
row = degree(row, num_nodes).cumsum(0);
row = at::cat({at::zeros(1, row.options()), row}, 0); // Prepend zero.
return std::make_tuple(row, col);
}
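// For every output index pair (i, j), accumulate sum_c A[i, c] * B[j, c] by
// intersecting the sorted column lists of row i of A and row j of B; this is the
// value gradient used in the backward pass of spspmm.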
template <typename scalar_t>
__global__ void spspmm_bw_kernel(
const int64_t *__restrict__ index, scalar_t *__restrict__ value,
const int64_t *__restrict__ rowA, const int64_t *__restrict__ colA,
const scalar_t *__restrict__ valueA, const int64_t *__restrict__ rowB,
const int64_t *__restrict__ colB, const scalar_t *__restrict__ valueB,
const size_t numel) {
const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = blockDim.x * gridDim.x;
for (ptrdiff_t e = idx; e < numel; e += stride) {
int64_t i = index[e], j = index[numel + e];
for (ptrdiff_t dA = rowA[i]; dA < rowA[i + 1]; dA++) {
int64_t cA = colA[dA];
for (ptrdiff_t dB = rowB[j]; dB < rowB[j + 1]; dB++) {
int64_t cB = colB[dB];
if (cA == cB) {
value[e] += valueA[dA] * valueB[dB];
}
if (cB >= cA) {
break;
}
}
}
}
}
at::Tensor spspmm_bw_cuda(at::Tensor index, at::Tensor indexA,
at::Tensor valueA, at::Tensor indexB,
at::Tensor valueB, size_t rowA_max, size_t rowB_max) {
cudaSetDevice(index.get_device());
auto value = at::zeros(index.size(1), valueA.options());
at::Tensor rowA, colA;
std::tie(rowA, colA) = to_csr(indexA[0], indexA[1], rowA_max);
at::Tensor rowB, colB;
std::tie(rowB, colB) = to_csr(indexB[0], indexB[1], rowB_max);
AT_DISPATCH_FLOATING_TYPES(valueA.scalar_type(), "spspmm_bw", [&] {
spspmm_bw_kernel<scalar_t><<<BLOCKS(value.numel()), THREADS>>>(
index.DATA_PTR<int64_t>(), value.DATA_PTR<scalar_t>(),
rowA.DATA_PTR<int64_t>(), colA.DATA_PTR<int64_t>(),
valueA.DATA_PTR<scalar_t>(), rowB.DATA_PTR<int64_t>(),
colB.DATA_PTR<int64_t>(), valueB.DATA_PTR<scalar_t>(), value.numel());
});
return value;
}
#include <torch/extension.h>
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")
std::tuple<at::Tensor, at::Tensor> unique_cuda(at::Tensor src);
std::tuple<at::Tensor, at::Tensor> unique(at::Tensor src) {
CHECK_CUDA(src);
return unique_cuda(src);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("unique", &unique, "Unique (CUDA)");
}
#include <ATen/ATen.h>
#include "compat.cuh"
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
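// After sorting, an element is unique iff it differs from its predecessor; the
// kernel marks these first occurrences in `mask`.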
template <typename scalar_t>
__global__ void unique_cuda_kernel(scalar_t *__restrict__ src, bool *mask,
size_t numel) {
const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = blockDim.x * gridDim.x;
for (ptrdiff_t i = index; i < numel; i += stride) {
if (i == 0 || src[i] != src[i - 1]) {
mask[i] = true;
}
}
}
std::tuple<at::Tensor, at::Tensor> unique_cuda(at::Tensor src) {
cudaSetDevice(src.get_device());
at::Tensor perm;
std::tie(src, perm) = src.sort();
auto mask = at::zeros(src.numel(), src.options().dtype(at::kBool));
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "grid_cuda_kernel", [&] {
unique_cuda_kernel<scalar_t><<<BLOCKS(src.numel()), THREADS>>>(
src.DATA_PTR<scalar_t>(), mask.DATA_PTR<bool>(), src.numel());
});
src = src.masked_select(mask);
perm = perm.masked_select(mask);
return std::make_tuple(src, perm);
}
import platform
import os
import os.path as osp
import sys
import glob
from setuptools import setup, find_packages
from sys import argv
import torch
from torch.utils.cpp_extension import BuildExtension
from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
extra_compile_args = []
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
extra_compile_args += ['-DVERSION_GE_1_3']
ext_modules = [
CppExtension('torch_sparse.spspmm_cpu', ['cpu/spspmm.cpp'],
extra_compile_args=extra_compile_args)
]
cmdclass = {'build_ext': torch.utils.cpp_extension.BuildExtension}
GPU = True
for arg in argv:
if arg == '--cpu':
GPU = False
argv.remove(arg)
if CUDA_HOME is not None and GPU:
if platform.system() == 'Windows':
extra_link_args = ['cusparse.lib']
else:
extra_link_args = ['-lcusparse', '-l', 'cusparse']
ext_modules += [
CUDAExtension('torch_sparse.spspmm_cuda',
['cuda/spspmm.cpp', 'cuda/spspmm_kernel.cu'],
extra_link_args=extra_link_args,
extra_compile_args=extra_compile_args),
CUDAExtension('torch_sparse.unique_cuda',
['cuda/unique.cpp', 'cuda/unique_kernel.cu'],
extra_compile_args=extra_compile_args),
]
__version__ = '0.4.4'
url = 'https://github.com/rusty1s/pytorch_sparse'
WITH_CUDA = torch.cuda.is_available() and CUDA_HOME is not None
if os.getenv('FORCE_CUDA', '0') == '1':
WITH_CUDA = True
if os.getenv('FORCE_NON_CUDA', '0') == '1':
WITH_CUDA = False
BUILD_DOCS = os.getenv('BUILD_DOCS', '0') == '1'
def get_extensions():
Extension = CppExtension
define_macros = []
extra_compile_args = {'cxx': [], 'nvcc': []}
extra_link_args = []
# Windows users: Edit both of these to contain your VS include path, i.e.:
# extra_compile_args['cxx'] += ['-I{VISUAL_STUDIO_DIR}\\include']
# extra_compile_args['nvcc'] += ['-I{VISUAL_STUDIO_DIR}\\include']
if WITH_CUDA:
Extension = CUDAExtension
define_macros += [('WITH_CUDA', None)]
nvcc_flags = os.getenv('NVCC_FLAGS', '')
nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ')
nvcc_flags += ['-arch=sm_35', '--expt-relaxed-constexpr']
extra_compile_args['cxx'] += ['-O0']
extra_compile_args['nvcc'] += nvcc_flags
if sys.platform == 'win32':
extra_link_args = ['cusparse.lib']
else:
extra_link_args = ['-lcusparse', '-l', 'cusparse']
if sys.platform == 'win32':
extra_compile_args['cxx'] += ['/MP']
extensions_dir = osp.join(osp.dirname(osp.abspath(__file__)), 'csrc')
main_files = glob.glob(osp.join(extensions_dir, '*.cpp'))
extensions = []
for main in main_files:
name = main.split(os.sep)[-1][:-4]
sources = [main, osp.join(extensions_dir, 'cpu', f'{name}_cpu.cpp')]
if WITH_CUDA:
sources += [osp.join(extensions_dir, 'cuda', f'{name}_cuda.cu')]
extension = Extension(
f'torch_sparse._{name}',
sources,
include_dirs=[extensions_dir],
define_macros=define_macros,
extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args,
)
extensions += [extension]
return extensions
install_requires = ['scipy']
setup_requires = ['pytest-runner']
tests_require = ['pytest', 'pytest-cov']
setup(
    name='torch_sparse',
    version='1.0.0',
    author='Matthias Fey',
    author_email='matthias.fey@tu-dortmund.de',
    url='https://github.com/rusty1s/pytorch_sparse',
    description=('PyTorch Extension Library of Optimized Autograd Sparse '
                 'Matrix Operations'),
    keywords=['pytorch', 'sparse', 'sparse-matrices', 'autograd'],
    license='MIT',
    install_requires=install_requires,
    setup_requires=setup_requires,
    tests_require=tests_require,
    ext_modules=get_extensions() if not BUILD_DOCS else [],
    cmdclass={
        'build_ext': BuildExtension.with_options(no_python_abi_suffix=True)
    },
    packages=find_packages(),
)
import pytest
import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.cat import cat, cat_diag
from .utils import devices, tensor
@pytest.mark.parametrize('device', devices)
def test_cat(device):
row, col = tensor([[0, 0, 1], [0, 1, 2]], torch.long, device)
mat1 = SparseTensor(row=row, col=col)
mat1.fill_cache_()
row, col = tensor([[0, 0, 1, 2], [0, 1, 1, 0]], torch.long, device)
mat2 = SparseTensor(row=row, col=col)
mat2.fill_cache_()
out = cat([mat1, mat2], dim=0)
assert out.to_dense().tolist() == [[1, 1, 0], [0, 0, 1], [1, 1, 0],
[0, 1, 0], [1, 0, 0]]
assert out.storage.has_row()
assert out.storage.has_rowptr()
assert out.storage.has_rowcount()
assert out.storage.num_cached_keys() == 1
out = cat([mat1, mat2], dim=1)
assert out.to_dense().tolist() == [[1, 1, 0, 1, 1], [0, 0, 1, 0, 1],
[0, 0, 0, 1, 0]]
assert out.storage.has_row()
assert not out.storage.has_rowptr()
assert out.storage.num_cached_keys() == 2
out = cat_diag([mat1, mat2])
assert out.to_dense().tolist() == [[1, 1, 0, 0, 0], [0, 0, 1, 0, 0],
[0, 0, 0, 1, 1], [0, 0, 0, 0, 1],
[0, 0, 0, 1, 0]]
assert out.storage.has_row()
assert out.storage.has_rowptr()
assert out.storage.num_cached_keys() == 5
value = torch.randn((mat1.nnz(), 4), device=device)
mat1 = mat1.set_value_(value, layout='coo')
out = cat([mat1, mat1], dim=-1)
assert out.storage.value().size() == (mat1.nnz(), 8)
assert out.storage.has_row()
assert out.storage.has_rowptr()
assert out.storage.num_cached_keys() == 5
from itertools import product
import pytest
import torch
from torch_sparse.tensor import SparseTensor
from .utils import dtypes, devices, tensor
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_remove_diag(dtype, device):
row, col = tensor([[0, 0, 1, 2], [0, 1, 2, 2]], torch.long, device)
value = tensor([1, 2, 3, 4], dtype, device)
mat = SparseTensor(row=row, col=col, value=value)
mat.fill_cache_()
mat = mat.remove_diag()
assert mat.storage.row().tolist() == [0, 1]
assert mat.storage.col().tolist() == [1, 2]
assert mat.storage.value().tolist() == [2, 3]
assert mat.storage.num_cached_keys() == 2
assert mat.storage.rowcount().tolist() == [1, 1, 0]
assert mat.storage.colcount().tolist() == [0, 1, 1]
mat = SparseTensor(row=row, col=col, value=value)
mat.fill_cache_()
mat = mat.remove_diag(k=1)
assert mat.storage.row().tolist() == [0, 2]
assert mat.storage.col().tolist() == [0, 2]
assert mat.storage.value().tolist() == [1, 4]
assert mat.storage.num_cached_keys() == 2
assert mat.storage.rowcount().tolist() == [1, 0, 1]
assert mat.storage.colcount().tolist() == [1, 0, 1]
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_set_diag(dtype, device):
row, col = tensor([[0, 0, 9, 9], [0, 1, 0, 1]], torch.long, device)
value = tensor([1, 2, 3, 4], dtype, device)
mat = SparseTensor(row=row, col=col, value=value)
mat = mat.set_diag(tensor([-8, -8], dtype, device), k=-1)
mat = mat.set_diag(tensor([-8], dtype, device), k=1)
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_fill_diag(dtype, device):
row, col = tensor([[0, 0, 9, 9], [0, 1, 0, 1]], torch.long, device)
value = tensor([1, 2, 3, 4], dtype, device)
mat = SparseTensor(row=row, col=col, value=value)
mat = mat.fill_diag(-8, k=-1)
mat = mat.fill_diag(-8, k=1)
from itertools import product
import pytest
import torch
from torch_sparse.tensor import SparseTensor
from .utils import dtypes, devices
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_eye(dtype, device):
options = torch.tensor(0, dtype=dtype, device=device)
mat = SparseTensor.eye(3, options=options)
assert mat.storage.sparse_sizes() == (3, 3)
assert mat.storage.row().tolist() == [0, 1, 2]
assert mat.storage.rowptr().tolist() == [0, 1, 2, 3]
assert mat.storage.col().tolist() == [0, 1, 2]
assert mat.storage.value().tolist() == [1, 1, 1]
assert mat.storage.num_cached_keys() == 0
mat = SparseTensor.eye(3, options=options, has_value=False)
assert mat.storage.sparse_sizes() == (3, 3)
assert mat.storage.row().tolist() == [0, 1, 2]
assert mat.storage.rowptr().tolist() == [0, 1, 2, 3]
assert mat.storage.col().tolist() == [0, 1, 2]
assert mat.storage.value() is None
assert mat.storage.num_cached_keys() == 0
mat = SparseTensor.eye(3, 4, options=options, fill_cache=True)
assert mat.storage.sparse_sizes() == (3, 4)
assert mat.storage.row().tolist() == [0, 1, 2]
assert mat.storage.rowptr().tolist() == [0, 1, 2, 3]
assert mat.storage.col().tolist() == [0, 1, 2]
assert mat.storage.num_cached_keys() == 5
assert mat.storage.rowcount().tolist() == [1, 1, 1]
assert mat.storage.colptr().tolist() == [0, 1, 2, 3, 3]
assert mat.storage.colcount().tolist() == [1, 1, 1, 0]
assert mat.storage.csr2csc().tolist() == [0, 1, 2]
assert mat.storage.csc2csr().tolist() == [0, 1, 2]
mat = SparseTensor.eye(4, 3, options=options, fill_cache=True)
assert mat.storage.sparse_sizes() == (4, 3)
assert mat.storage.row().tolist() == [0, 1, 2]
assert mat.storage.rowptr().tolist() == [0, 1, 2, 3, 3]
assert mat.storage.col().tolist() == [0, 1, 2]
assert mat.storage.num_cached_keys() == 5
assert mat.storage.rowcount().tolist() == [1, 1, 1, 0]
assert mat.storage.colptr().tolist() == [0, 1, 2, 3]
assert mat.storage.colcount().tolist() == [1, 1, 1]
assert mat.storage.csr2csc().tolist() == [0, 1, 2]
assert mat.storage.csc2csr().tolist() == [0, 1, 2]