Unverified Commit f1ee3e31 authored by Quan (Andy) Gan and committed by GitHub

[Sparse] Sparse matrix reduction C++ part (#5012)

* initial commit

* lint

* rename

* address changes

* oops

* fix?

* fix?
parent 00f20999
/**
* Copyright (c) 2022 by Contributors
* @file sparse/reduction.h
* @brief DGL C++ sparse matrix reduction operators.
*/
#ifndef SPARSE_REDUCTION_H_
#define SPARSE_REDUCTION_H_
#include <sparse/sparse_matrix.h>
#include <string>
namespace dgl {
namespace sparse {
/**
 * @brief Reduces a sparse matrix along the specified sparse dimension.
 *
 * @param A The sparse matrix.
 * @param reduce The reduce operator. Must be one of "sum", "smin", "smax",
 * "smean", or "sprod".
 * @param dim The sparse dimension to reduce along. Must be either 0 (rows) or
 * 1 (columns). If not given, the reduction is applied over all non-zero
 * values of the matrix.
 *
 * @return The reduced dense tensor.
 */
torch::Tensor Reduce(
const c10::intrusive_ptr<SparseMatrix>& A, const std::string& reduce,
const torch::optional<int64_t>& dim = torch::nullopt);
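
// Usage sketch (illustrative only; `A` denotes an existing
// c10::intrusive_ptr<SparseMatrix> and the variable names are placeholders):
//
//   torch::Tensor col_sum = Reduce(A, "sum", 0);   // one entry per column
//   torch::Tensor row_max = Reduce(A, "smax", 1);  // one entry per row
//   torch::Tensor total = Reduce(A, "sum");        // reduce over all non-zeros

/** @brief Convenience wrapper for Reduce(A, "sum", dim). */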
inline torch::Tensor ReduceSum(
const c10::intrusive_ptr<SparseMatrix>& A,
const torch::optional<int64_t>& dim = torch::nullopt) {
return Reduce(A, "sum", dim);
}
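
/** @brief Convenience wrapper for Reduce(A, "smin", dim). */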
inline torch::Tensor ReduceMin(
const c10::intrusive_ptr<SparseMatrix>& A,
const torch::optional<int64_t>& dim = torch::nullopt) {
return Reduce(A, "smin", dim);
}
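
/** @brief Convenience wrapper for Reduce(A, "smax", dim). */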
inline torch::Tensor ReduceMax(
const c10::intrusive_ptr<SparseMatrix>& A,
const torch::optional<int64_t>& dim = torch::nullopt) {
return Reduce(A, "smax", dim);
}
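
/** @brief Convenience wrapper for Reduce(A, "smean", dim). */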
inline torch::Tensor ReduceMean(
const c10::intrusive_ptr<SparseMatrix>& A,
const torch::optional<int64_t>& dim = torch::nullopt) {
return Reduce(A, "smean", dim);
}
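
/** @brief Convenience wrapper for Reduce(A, "sprod", dim). */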
inline torch::Tensor ReduceProd(
const c10::intrusive_ptr<SparseMatrix>& A,
const torch::optional<int64_t>& dim = torch::nullopt) {
return Reduce(A, "sprod", dim);
}
} // namespace sparse
} // namespace dgl
#endif // SPARSE_REDUCTION_H_
@@ -8,6 +8,7 @@
// clang-format on
#include <sparse/elementwise_op.h>
#include <sparse/reduction.h>
#include <sparse/sparse_matrix.h>
#include <torch/custom_class.h>
#include <torch/script.h>
@@ -29,6 +30,12 @@ TORCH_LIBRARY(dgl_sparse, m) {
      .def("create_from_csr", &CreateFromCSR)
      .def("create_from_csc", &CreateFromCSC)
      .def("spsp_add", &SpSpAdd)
      .def("reduce", &Reduce)
      .def("sum", &ReduceSum)
      .def("smean", &ReduceMean)
      .def("smin", &ReduceMin)
      .def("smax", &ReduceMax)
      .def("sprod", &ReduceProd)
      .def("val_like", &CreateValLike);
}
/**
* Copyright (c) 2022 by Contributors
* @file reduction.cc
* @brief DGL C++ sparse matrix reduction operator implementation.
*/
// clang-format off
#include <sparse/dgl_headers.h>
// clang-format on
#include <sparse/elementwise_op.h>
#include <sparse/reduction.h>
#include <sparse/sparse_matrix.h>
#include <torch/script.h>
#include <string>
#include <vector>
namespace dgl {
namespace sparse {
namespace {
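
// Reduces the non-zero values of the sparse matrix along one sparse dimension
// by scattering them into a dense output indexed by the column (dim == 0) or
// row (dim == 1) coordinate of each non-zero entry.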
torch::Tensor ReduceAlong(
const c10::intrusive_ptr<SparseMatrix>& A, const std::string& reduce,
int64_t dim) {
auto value = A->value();
auto coo = A->COOPtr();
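  // Map the DGL reduce name to the reduce argument accepted by
  // torch::Tensor::scatter_reduce_.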
std::string reduce_op;
if (reduce == "sum") {
reduce_op = "sum";
} else if (reduce == "smin") {
reduce_op = "amin";
} else if (reduce == "smax") {
reduce_op = "amax";
} else if (reduce == "smean") {
reduce_op = "mean";
} else if (reduce == "sprod") {
reduce_op = "prod";
} else {
TORCH_CHECK(false, "unknown reduce function ", reduce);
return torch::Tensor();
}
// Create the output tensor with shape
//
// [A.num_rows if dim == 1 else A.num_cols] + A.val.shape[1:]
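  //
  // e.g. reducing a 4 x 5 matrix with scalar non-zero values along dim == 0
  // yields a tensor of shape [5].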
std::vector<int64_t> output_shape = value.sizes().vec();
std::vector<int64_t> view_dims(output_shape.size(), 1);
view_dims[0] = -1;
  torch::Tensor idx;
  if (dim == 0) {
    output_shape[0] = coo->num_cols;
    idx = coo->col.view(view_dims).expand_as(value);
  } else if (dim == 1) {
    output_shape[0] = coo->num_rows;
    idx = coo->row.view(view_dims).expand_as(value);
  } else {
    TORCH_CHECK(false, "Reduce expects dim to be 0 or 1, but got ", dim);
  }
  // Scatter-reduce the non-zero values into the dense output. With
  // include_self = false, positions that receive no non-zero entries keep the
  // initial zero instead of folding it into the reduction.
  torch::Tensor out = torch::zeros(output_shape, value.options());
  out.scatter_reduce_(0, idx, value, reduce_op, false);
return out;
}
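
// Reduces all non-zero values of the sparse matrix at once: both sparse
// dimensions collapse, so this is a plain dense reduction over dim 0 of the
// value tensor.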
torch::Tensor ReduceAll(
const c10::intrusive_ptr<SparseMatrix>& A, const std::string& reduce) {
if (reduce == "sum") {
return A->value().sum(0);
} else if (reduce == "smin") {
return A->value().amin(0);
} else if (reduce == "smax") {
return A->value().amax(0);
} else if (reduce == "smean") {
return A->value().mean(0);
} else if (reduce == "sprod") {
return A->value().prod(0);
}
TORCH_CHECK(false, "unknown reduce function ", reduce);
return torch::Tensor();
}
} // namespace
torch::Tensor Reduce(
const c10::intrusive_ptr<SparseMatrix>& A, const std::string& reduce,
const torch::optional<int64_t>& dim) {
return dim.has_value() ? ReduceAlong(A, reduce, dim.value())
: ReduceAll(A, reduce);
}
} // namespace sparse
} // namespace dgl