/*!
 *  Copyright (c) 2020 by Contributors
 * \file array/kernel_decl.h
 * \brief Sparse matrix format-specific operator declarations.
 */
#ifndef DGL_ARRAY_KERNEL_DECL_H_
#define DGL_ARRAY_KERNEL_DECL_H_

#include <dgl/bcast.h>
#include <dgl/base_heterograph.h>
#include <dgl/runtime/ndarray.h>

#include <string>
#include <vector>
#include <utility>

namespace dgl {
namespace aten {

/*!
 * \brief Generalized Sparse Matrix Dense Matrix Multiplication on Csr format.
 */
23
template <int XPU, typename IdType, int bits>
24
25
26
27
28
29
30
31
32
33
34
void SpMMCsr(const std::string& op, const std::string& reduce,
             const BcastOff& bcast,
             const aten::CSRMatrix& csr,
             NDArray ufeat,
             NDArray efeat,
             NDArray out,
             std::vector<NDArray> out_aux);

/*!
 * \brief Generalized Sparse Matrix Dense Matrix Multiplication on Coo format.
 */
35
template <int XPU, typename IdType, int bits>
36
37
38
39
40
41
42
43
44
45
46
void SpMMCoo(const std::string& op, const std::string& reduce,
             const BcastOff& bcast,
             const aten::COOMatrix& coo,
             NDArray ufeat,
             NDArray efeat,
             NDArray out,
             std::vector<NDArray> out_aux);

/*!
 * \brief Generalized Sampled Dense-Dense Matrix Multiplication on Csr format.
 */
47
template <int XPU, typename IdType, int bits>
48
49
50
void SDDMMCsr(const std::string& op,
              const BcastOff& bcast,
              const aten::CSRMatrix& csr,
51
52
53
54
55
              NDArray lhs,
              NDArray rhs,
              NDArray out,
              int lhs_target,
              int rhs_target);
56
57
58
59

/*!
 * \brief Generalized Sampled Dense-Dense Matrix Multiplication on Coo format.
 */
60
template <int XPU, typename IdType, int bits>
61
62
63
void SDDMMCoo(const std::string& op,
              const BcastOff& bcast,
              const aten::COOMatrix& coo,
64
65
66
67
68
              NDArray lhs,
              NDArray rhs,
              NDArray out,
              int lhs_target,
              int rhs_target);
69

70
71
72
/*!
 * \brief Segment reduce.
 */
73
template <int XPU, typename IdType, int bits>
74
75
76
77
78
79
void SegmentReduce(const std::string& op,
                   NDArray feat,
                   NDArray offsets,
                   NDArray out,
                   NDArray arg);

80
81
82
83
84
85
86
87
/*!
 * \brief Scatter Add on first dimension.
 */
template <int XPU, typename IdType, int bits>
void ScatterAdd(NDArray feat,
                NDArray idx,
                NDArray out);

88
89
90
/*!
 * \brief Backward function of segment cmp.
 */
91
template <int XPU, typename IdType, int bits>
92
93
94
95
void BackwardSegmentCmp(NDArray feat,
                        NDArray arg,
                        NDArray out);

96
97
98
/*!
 * \brief Sparse-sparse matrix multiplication
 *
99
100
101
102
103
104
105
106
 * \param A The left operand.
 * \param A_weights The weights of matrix as a 1D tensor.
 * \param B The right operand.
 * \param B_weights The weights of matrix as a 1D tensor.
 *
 * \note GPU implementation will cast the indices to 32 bit.
 * \note The zero entries in the result are not removed.
 * \note The CSR matrix should not have duplicate entries.
107
108
109
110
111
112
113
114
115
116
 */
template <int XPU, typename IdType, typename DType>
std::pair<CSRMatrix, NDArray> CSRMM(
    const CSRMatrix& A,
    NDArray A_weights,
    const CSRMatrix& B,
    NDArray B_weights);

/*!
 * \brief Sparse-sparse matrix summation.
 *
 * Declaration only; no definition is visible in this header.
 *
 * \tparam XPU Integer tag of the implementation
 *         (NOTE(review): presumably the device type code -- confirm).
 * \tparam IdType NOTE(review): presumably the index type of the CSR
 *         matrices -- confirm.
 * \tparam DType NOTE(review): presumably the element type of the weights --
 *         confirm.
 * \param A The sparse matrices with the same size.
 * \param A_weights The weights of each sparse matrix as a 1D tensor.
 * \return A pair of a CSR matrix and an array (NOTE(review): presumably the
 *         result matrix and its weights -- confirm).
 *
 * \note GPU implementation will cast the indices to 32 bit.
 * \note The zero entries in the result are not removed.
 * \note The CSR matrix should not have duplicate entries.
 */
template <int XPU, typename IdType, typename DType>
std::pair<CSRMatrix, NDArray> CSRSum(
    const std::vector<CSRMatrix>& A,
    const std::vector<NDArray>& A_weights);

}  // namespace aten
}  // namespace dgl

#endif  // DGL_ARRAY_KERNEL_DECL_H_