reduction.h 1.67 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
/**
 *  Copyright (c) 2022 by Contributors
 * @file sparse/reduction.h
 * @brief DGL C++ sparse matrix reduction operators.
 */
#ifndef SPARSE_REDUCTION_H_
#define SPARSE_REDUCTION_H_

#include <sparse/sparse_matrix.h>

#include <string>

namespace dgl {
namespace sparse {

/**
 * @brief Reduces a sparse matrix along the specified sparse dimension.
 *
 * @param A The sparse matrix.
 * @param reduce The reduce operator.  Must be either "sum", "smin", "smax",
 * "smean", or "sprod" -- the same names passed by the ReduceSum, ReduceMin,
 * ReduceMax, ReduceMean, and ReduceProd wrappers declared below.
 * @param dim The sparse dimension to reduce along.  Must be either 0 (rows) or
 * 1 (columns).  Optional; defaults to torch::nullopt.  NOTE(review): the
 * meaning of an omitted dim is defined by the Reduce implementation
 * (presumably a reduction over all stored values) -- confirm there.
 *
 * @return The reduction result as a tensor.
 */
torch::Tensor Reduce(
    const c10::intrusive_ptr<SparseMatrix>& A, const std::string& reduce,
    const torch::optional<int64_t>& dim = torch::nullopt);

/**
 * @brief Applies the "sum" reduction; equivalent to Reduce(A, "sum", dim).
 *
 * @param A The sparse matrix.
 * @param dim The sparse dimension to reduce along (0 or 1), or torch::nullopt.
 *
 * @return The reduction result as a tensor.
 */
inline torch::Tensor ReduceSum(
    const c10::intrusive_ptr<SparseMatrix>& A,
    const torch::optional<int64_t>& dim = torch::nullopt) {
  return Reduce(A, "sum", dim);
}

/**
 * @brief Applies the "smin" reduction; equivalent to Reduce(A, "smin", dim).
 *
 * @param A The sparse matrix.
 * @param dim The sparse dimension to reduce along (0 or 1), or torch::nullopt.
 *
 * @return The reduction result as a tensor.
 */
inline torch::Tensor ReduceMin(
    const c10::intrusive_ptr<SparseMatrix>& A,
    const torch::optional<int64_t>& dim = torch::nullopt) {
  return Reduce(A, "smin", dim);
}

/**
 * @brief Applies the "smax" reduction; equivalent to Reduce(A, "smax", dim).
 *
 * @param A The sparse matrix.
 * @param dim The sparse dimension to reduce along (0 or 1), or torch::nullopt.
 *
 * @return The reduction result as a tensor.
 */
inline torch::Tensor ReduceMax(
    const c10::intrusive_ptr<SparseMatrix>& A,
    const torch::optional<int64_t>& dim = torch::nullopt) {
  return Reduce(A, "smax", dim);
}

/**
 * @brief Applies the "smean" reduction; equivalent to
 * Reduce(A, "smean", dim).
 *
 * @param A The sparse matrix.
 * @param dim The sparse dimension to reduce along (0 or 1), or torch::nullopt.
 *
 * @return The reduction result as a tensor.
 */
inline torch::Tensor ReduceMean(
    const c10::intrusive_ptr<SparseMatrix>& A,
    const torch::optional<int64_t>& dim = torch::nullopt) {
  return Reduce(A, "smean", dim);
}

/**
 * @brief Applies the "sprod" reduction; equivalent to
 * Reduce(A, "sprod", dim).
 *
 * @param A The sparse matrix.
 * @param dim The sparse dimension to reduce along (0 or 1), or torch::nullopt.
 *
 * @return The reduction result as a tensor.
 */
inline torch::Tensor ReduceProd(
    const c10::intrusive_ptr<SparseMatrix>& A,
    const torch::optional<int64_t>& dim = torch::nullopt) {
  return Reduce(A, "sprod", dim);
}

}  // namespace sparse
}  // namespace dgl

#endif  // SPARSE_REDUCTION_H_