"git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "8273c3f4500ac9c01e95f1d78fa3f5752aafdca7"
kernel_decl.h 7.53 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
/*!
 *  Copyright (c) 2020 by Contributors
 * \file array/kernel_decl.h
 * \brief Sparse matrix format-specific operator declarations.
 */
#ifndef DGL_ARRAY_KERNEL_DECL_H_
#define DGL_ARRAY_KERNEL_DECL_H_

#include <dgl/bcast.h>
#include <dgl/base_heterograph.h>
11
#include <dgl/runtime/ndarray.h>
12
13
14

#include <string>
#include <vector>
15
#include <utility>
16
17
18
19
20
21
22

namespace dgl {
namespace aten {

/*!
 * \brief Generalized Sparse Matrix Dense Matrix Multiplication on Csr format.
 */
23
template <int XPU, typename IdType, int bits>
24
25
26
27
28
29
30
31
void SpMMCsr(const std::string& op, const std::string& reduce,
             const BcastOff& bcast,
             const aten::CSRMatrix& csr,
             NDArray ufeat,
             NDArray efeat,
             NDArray out,
             std::vector<NDArray> out_aux);

32
33
34
35
36
37
38
39
40
41
/*!
 * \brief Generalized Sparse Matrix Dense Matrix Multiplication on Csr format
 with heterograph support.
 */
template <int XPU, typename IdType, int bits>
void SpMMCsrHetero(const std::string& op, const std::string& reduce,
             const BcastOff& bcast,
             const std::vector<CSRMatrix>& csr,
             const std::vector<NDArray>& ufeat,
             const std::vector<NDArray>& efeat,
42
43
             std::vector<NDArray>* out,
             std::vector<std::vector<NDArray>>* out_aux,
44
45
             const std::vector<dgl_type_t>& ufeat_eid,
             const std::vector<dgl_type_t>& out_eid);
46
47
48
/*!
 * \brief Generalized Sparse Matrix Dense Matrix Multiplication on Coo format.
 */
49
template <int XPU, typename IdType, int bits>
50
51
52
53
54
55
56
57
58
59
60
void SpMMCoo(const std::string& op, const std::string& reduce,
             const BcastOff& bcast,
             const aten::COOMatrix& coo,
             NDArray ufeat,
             NDArray efeat,
             NDArray out,
             std::vector<NDArray> out_aux);

/*!
 * \brief Generalized Sampled Dense-Dense Matrix Multiplication on Csr format.
 */
61
template <int XPU, typename IdType, int bits>
62
63
64
void SDDMMCsr(const std::string& op,
              const BcastOff& bcast,
              const aten::CSRMatrix& csr,
65
66
67
68
69
              NDArray lhs,
              NDArray rhs,
              NDArray out,
              int lhs_target,
              int rhs_target);
70
/*!
71
 * \brief Generalized Sampled Dense-Dense Matrix Multiplication on Csr
72
73
74
75
76
77
78
79
80
81
82
83
84
 format with heterograph support.
  */
template <int XPU, typename IdType, int bits>
void SDDMMCsrHetero(const std::string& op,
              const BcastOff& bcast,
              const std::vector<CSRMatrix>& vec_csr,
              const std::vector<NDArray>& vec_lhs,
              const std::vector<NDArray>& vec_rhs,
              std::vector<NDArray> vec_out,
              int lhs_target,
              int rhs_target,
              const std::vector<dgl_type_t>& ufeat_eid,
              const std::vector<dgl_type_t>& out_eid);
85
86
87
88

/*!
 * \brief Generalized Sampled Dense-Dense Matrix Multiplication on Coo format.
 */
89
template <int XPU, typename IdType, int bits>
90
91
92
void SDDMMCoo(const std::string& op,
              const BcastOff& bcast,
              const aten::COOMatrix& coo,
93
94
95
96
97
              NDArray lhs,
              NDArray rhs,
              NDArray out,
              int lhs_target,
              int rhs_target);
98

99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
/*!
 * \brief Generalized Sampled Dense-Dense Matrix Multiplication on Coo
 format with heterograph support.
  */
template <int XPU, typename IdType, int bits>
void SDDMMCooHetero(const std::string& op,
              const BcastOff& bcast,
              const std::vector<COOMatrix>& vec_coo,
              const std::vector<NDArray>& vec_lhs,
              const std::vector<NDArray>& vec_rhs,
              std::vector<NDArray> vec_out,
              int lhs_target,
              int rhs_target,
              const std::vector<dgl_type_t>& lhs_eid,
              const std::vector<dgl_type_t>& rhs_eid);

Israt Nisa's avatar
Israt Nisa committed
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
/*!
 * \brief Generalized Dense Matrix-Matrix Multiplication according to relation types.
 */
template <int XPU, typename IdType, int bits>
void gatherMM(const NDArray A,
          const NDArray B,
          NDArray out,
          const NDArray idx_a,
          const NDArray idx_b,
          const int num_rel);

/*!
 * \brief Generalized Dense Matrix-Matrix Multiplication according to relation types.
 */
template <int XPU, typename IdType, int bits>
void gatherMM_scatter(const NDArray A,
          const NDArray B,
          NDArray out,
          const NDArray idx_a,
          const NDArray idx_b,
          const NDArray idx_c,
          const int num_rel, bool a_trans, bool b_trans);

/*!
 * \brief Generalized segmented dense Matrix-Matrix Multiplication.
 */
template <int XPU, typename IdType, int bits>
void segmentMM(const NDArray A,
          const NDArray B,
          NDArray out,
          const NDArray seglen_A,
          bool a_trans, bool b_trans);

148
149
150
/*!
 * \brief Segment reduce.
 */
151
template <int XPU, typename IdType, int bits>
152
153
154
155
156
157
void SegmentReduce(const std::string& op,
                   NDArray feat,
                   NDArray offsets,
                   NDArray out,
                   NDArray arg);

158
159
160
161
162
163
164
165
/*!
 * \brief Scatter Add on first dimension.
 */
template <int XPU, typename IdType, int bits>
void ScatterAdd(NDArray feat,
                NDArray idx,
                NDArray out);

166
167
168
169
170
171
172
173
174
175
176
/*!
 * \brief Update gradients for reduce operator max and min on first dimension.
 */
template <int XPU, typename IdType, int bits>
void UpdateGradMinMax_hetero(const HeteroGraphPtr& g,
                const std::string& op,
                const std::vector<NDArray>& feat,
                const std::vector<NDArray>& idx,
                const std::vector<NDArray>& idx_etype,
                std::vector<NDArray>* out);

177
178
179
/*!
 * \brief Backward function of segment cmp.
 */
180
template <int XPU, typename IdType, int bits>
181
182
183
184
void BackwardSegmentCmp(NDArray feat,
                        NDArray arg,
                        NDArray out);

185
186
187
/*!
 * \brief Sparse-sparse matrix multiplication
 *
188
189
190
191
192
193
194
195
 * \param A The left operand.
 * \param A_weights The weights of matrix as a 1D tensor.
 * \param B The right operand.
 * \param B_weights The weights of matrix as a 1D tensor.
 *
 * \note GPU implementation will cast the indices to 32 bit.
 * \note The zero entries in the result are not removed.
 * \note The CSR matrix should not have duplicate entries.
196
197
198
199
200
201
202
203
204
205
 */
template <int XPU, typename IdType, typename DType>
std::pair<CSRMatrix, NDArray> CSRMM(
    const CSRMatrix& A,
    NDArray A_weights,
    const CSRMatrix& B,
    NDArray B_weights);

/*!
 * \brief Sparse-sparse matrix summation.
206
207
208
209
210
211
212
 *
 * \param A The sparse matrices with the same size.
 * \param A_weights The weights of each sparse matrix as a 1D tensor.
 *
 * \note GPU implementation will cast the indices to 32 bit.
 * \note The zero entries in the result are not removed.
 * \note The CSR matrix should not have duplicate entries.
213
214
215
216
217
218
 */
template <int XPU, typename IdType, typename DType>
std::pair<CSRMatrix, NDArray> CSRSum(
    const std::vector<CSRMatrix>& A,
    const std::vector<NDArray>& A_weights);

219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
/*!
 * \brief Edge_softmax_csr forward function on Csr format.
 */
template <int XPU, typename IdType, int bits>
void Edge_softmax_csr_forward(const std::string& op,
             const BcastOff& bcast,
             const aten::CSRMatrix& csr,
             NDArray ufeat,
             NDArray efeat,
             NDArray out);
/*!
 * \brief Edge_softmax_csr backward function on Csr format.
 */
template <int XPU, typename IdType, int bits>
void Edge_softmax_csr_backward(const std::string& op,
             const BcastOff& bcast,
             const aten::CSRMatrix& csr,
             NDArray ufeat,
             NDArray efeat,
             NDArray out);
239
240
241
242
}  // namespace aten
}  // namespace dgl

#endif  // DGL_ARRAY_KERNEL_DECL_H_