/**
 *  Copyright (c) 2020 by Contributors
 * @file array/kernel_decl.h
 * @brief Sparse matrix format-specific operator declarations.
 */
#ifndef DGL_ARRAY_KERNEL_DECL_H_
#define DGL_ARRAY_KERNEL_DECL_H_

// Project headers: heterograph types (HeteroGraphPtr, dgl_type_t, sparse
// matrix formats), broadcast offsets (BcastOff) and NDArray.
#include <dgl/base_heterograph.h>
#include <dgl/bcast.h>
#include <dgl/runtime/ndarray.h>

// Standard library: op/reduce names, CSRMM/CSRSum return pairs, and the
// per-relation argument vectors of the heterograph kernels.
#include <string>
#include <utility>
#include <vector>

namespace dgl {
namespace aten {

20
/**
21
 * @brief Generalized Sparse Matrix Dense Matrix Multiplication on Csr format.
22
 */
23
template <int XPU, typename IdType, typename DType>
24
25
26
27
void SpMMCsr(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const aten::CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
28

29
/**
30
 * @brief Generalized Sparse Matrix Dense Matrix Multiplication on Csr format
31
 * with heterograph support.
32
 */
33
template <int XPU, typename IdType, typename DType>
34
35
36
37
38
39
40
void SpMMCsrHetero(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_eid,
    const std::vector<dgl_type_t>& out_eid);
41
/**
42
 * @brief Generalized Sparse Matrix Dense Matrix Multiplication on Coo format.
43
 */
44
template <int XPU, typename IdType, typename DType>
45
46
47
48
void SpMMCoo(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const aten::COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
49

50
/**
51
 * @brief Generalized Sampled Dense-Dense Matrix Multiplication on Csr format.
52
 */
53
template <int XPU, typename IdType, typename DType>
54
55
56
void SDDMMCsr(
    const std::string& op, const BcastOff& bcast, const aten::CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
57
/**
58
59
60
 * @brief Generalized Sampled Dense-Dense Matrix Multiplication on Csr format
 * with heterograph support.
 */
61
template <int XPU, typename IdType, typename DType>
62
63
64
65
66
67
void SDDMMCsrHetero(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& ufeat_eid,
    const std::vector<dgl_type_t>& out_eid);
68

69
/**
70
 * @brief Generalized Sampled Dense-Dense Matrix Multiplication on Coo format.
71
 */
72
template <int XPU, typename IdType, typename DType>
73
74
75
void SDDMMCoo(
    const std::string& op, const BcastOff& bcast, const aten::COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
76

77
/**
78
79
80
 * @brief Generalized Sampled Dense-Dense Matrix Multiplication on Coo format
 * with heterograph support.
 */
81
template <int XPU, typename IdType, typename DType>
82
83
84
85
86
87
void SDDMMCooHetero(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_eid,
    const std::vector<dgl_type_t>& rhs_eid);
88

89
/**
90
91
 * @brief Generalized Dense Matrix-Matrix Multiplication according to relation
 * types.
Israt Nisa's avatar
Israt Nisa committed
92
 */
93
template <int XPU, typename IdType, typename DType>
94
95
96
void GatherMM(
    const NDArray A, const NDArray B, NDArray out, const NDArray idx_a,
    const NDArray idx_b);
Israt Nisa's avatar
Israt Nisa committed
97

98
/**
99
100
 * @brief Generalized Dense Matrix-Matrix Multiplication according to relation
 * types.
Israt Nisa's avatar
Israt Nisa committed
101
 */
102
template <int XPU, typename IdType, typename DType>
103
104
105
void GatherMMScatter(
    const NDArray A, const NDArray B, NDArray out, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c);
Israt Nisa's avatar
Israt Nisa committed
106

107
/**
108
 * @brief Generalized segmented dense Matrix-Matrix Multiplication.
Israt Nisa's avatar
Israt Nisa committed
109
 */
110
template <int XPU, typename IdType, typename DType>
111
112
113
void SegmentMM(
    const NDArray A, const NDArray B, NDArray out, const NDArray seglen_A,
    bool a_trans, bool b_trans);
114

115
template <int XPU, typename IdType, typename DType>
116
117
void SegmentMMBackwardB(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
Israt Nisa's avatar
Israt Nisa committed
118

119
/**
120
 * @brief Segment reduce.
121
 */
122
template <int XPU, typename IdType, typename DType>
123
124
125
void SegmentReduce(
    const std::string& op, NDArray feat, NDArray offsets, NDArray out,
    NDArray arg);
126

127
/**
128
 * @brief Scatter Add on first dimension.
129
 */
130
template <int XPU, typename IdType, typename DType>
131
void ScatterAdd(NDArray feat, NDArray idx, NDArray out);
132

133
/**
134
 * @brief Update gradients for reduce operator max and min on first dimension.
135
 */
136
template <int XPU, typename IdType, typename DType>
137
138
139
140
void UpdateGradMinMax_hetero(
    const HeteroGraphPtr& g, const std::string& op,
    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
141

142
/**
143
 * @brief Backward function of segment cmp.
144
 */
145
template <int XPU, typename IdType, typename DType>
146
void BackwardSegmentCmp(NDArray feat, NDArray arg, NDArray out);
147

148
/**
149
 * @brief Sparse-sparse matrix multiplication
150
 *
151
152
153
154
 * @param A The left operand.
 * @param A_weights The weights of matrix as a 1D tensor.
 * @param B The right operand.
 * @param B_weights The weights of matrix as a 1D tensor.
155
 *
156
157
158
 * @note GPU implementation will cast the indices to 32 bit.
 * @note The zero entries in the result are not removed.
 * @note The CSR matrix should not have duplicate entries.
159
160
161
 */
template <int XPU, typename IdType, typename DType>
std::pair<CSRMatrix, NDArray> CSRMM(
162
    const CSRMatrix& A, NDArray A_weights, const CSRMatrix& B,
163
164
    NDArray B_weights);

165
/**
166
 * @brief Sparse-sparse matrix summation.
167
 *
168
169
 * @param A The sparse matrices with the same size.
 * @param A_weights The weights of each sparse matrix as a 1D tensor.
170
 *
171
172
173
 * @note GPU implementation will cast the indices to 32 bit.
 * @note The zero entries in the result are not removed.
 * @note The CSR matrix should not have duplicate entries.
174
175
176
 */
template <int XPU, typename IdType, typename DType>
std::pair<CSRMatrix, NDArray> CSRSum(
177
    const std::vector<CSRMatrix>& A, const std::vector<NDArray>& A_weights);
178

179
/**
180
 * @brief Edge_softmax_csr forward function on Csr format.
181
 */
182
template <int XPU, typename IdType, typename DType>
183
184
185
void Edge_softmax_csr_forward(
    const std::string& op, const BcastOff& bcast, const aten::CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);
186
/**
187
 * @brief Edge_softmax_csr backward function on Csr format.
188
 */
189
template <int XPU, typename IdType, typename DType>
190
191
192
void Edge_softmax_csr_backward(
    const std::string& op, const BcastOff& bcast, const aten::CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);
193
194
195
196
}  // namespace aten
}  // namespace dgl

#endif  // DGL_ARRAY_KERNEL_DECL_H_