array_op.h 12 KB
Newer Older
1
/**
2
 *  Copyright (c) 2019-2022 by Contributors
3
4
 * @file array/array_op.h
 * @brief Array operator templates
5
6
7
8
9
 */
#ifndef DGL_ARRAY_ARRAY_OP_H_
#define DGL_ARRAY_ARRAY_OP_H_

#include <dgl/array.h>
10
#include <dgl/graph_traversal.h>
11

12
13
#include <tuple>
#include <utility>
14
#include <vector>
15
16
17
18
19

namespace dgl {
namespace aten {
namespace impl {

20
21
/**
 * @brief Declares the device-specific implementation of Full; no definition
 *        appears in this header.
 * @tparam XPU Device type the implementation is instantiated for.
 * @tparam IdType Type of the scalar @p val.
 * @note Presumably returns an IdArray of @p length entries equal to @p val,
 *       placed on context @p ctx -- TODO confirm against the definition.
 */
template <DGLDeviceType XPU, typename IdType>
IdArray Full(IdType val, int64_t length, DGLContext ctx);
22

23
24
/**
 * @brief Declares the device-specific implementation of Range; no definition
 *        appears in this header.
 * @tparam XPU Device type the implementation is instantiated for.
 * @tparam IdType Type of the bounds @p low and @p high.
 * @note Presumably yields the consecutive ids between @p low and @p high on
 *       context @p ctx; whether @p high is inclusive is not visible here --
 *       TODO confirm against the definition.
 */
template <DGLDeviceType XPU, typename IdType>
IdArray Range(IdType low, IdType high, DGLContext ctx);
25

26
/**
 * @brief Declares the device-specific implementation of AsNumBits: takes an
 *        IdArray and a bit count and returns an IdArray.
 * @note Presumably converts @p arr to an integer width of @p bits bits, with
 *       IdType describing the input element type -- TODO confirm.
 */
template <DGLDeviceType XPU, typename IdType>
27
28
IdArray AsNumBits(IdArray arr, uint8_t bits);

29
/**
 * @brief Declares the three BinaryElewise overloads; none is defined in this
 *        header. All are parameterized on device type, element type and an
 *        operator type Op, and all return an IdArray.
 * @note Presumably Op is applied element by element, with a scalar operand
 *       paired against every element of the array operand -- TODO confirm.
 */
/** Array (lhs) with array (rhs). */
template <DGLDeviceType XPU, typename IdType, typename Op>
30
31
IdArray BinaryElewise(IdArray lhs, IdArray rhs);

32
/** Array (lhs) with scalar (rhs). */
template <DGLDeviceType XPU, typename IdType, typename Op>
33
34
IdArray BinaryElewise(IdArray lhs, IdType rhs);

35
/** Scalar (lhs) with array (rhs). */
template <DGLDeviceType XPU, typename IdType, typename Op>
36
37
IdArray BinaryElewise(IdType lhs, IdArray rhs);

38
/**
 * @brief Declares the device-specific implementation of UnaryElewise,
 *        parameterized on an operator type Op; takes one IdArray and returns
 *        an IdArray.
 * @note Presumably applies Op to every element of @p array -- TODO confirm.
 */
template <DGLDeviceType XPU, typename IdType, typename Op>
39
40
IdArray UnaryElewise(IdArray array);

41
/**
 * @brief Declares the two IndexSelect overloads; neither is defined in this
 *        header.
 * @note Presumably both read elements of @p array at the given position(s),
 *       with DType being the element type of @p array -- TODO confirm.
 */
/** Array-index overload: takes an IdArray of indices, returns an NDArray. */
template <DGLDeviceType XPU, typename DType, typename IdType>
42
NDArray IndexSelect(NDArray array, IdArray index);
43

44
/** Single-index overload: takes one int64_t index, returns one DType value. */
template <DGLDeviceType XPU, typename DType>
45
DType IndexSelect(NDArray array, int64_t index);
46

47
/**
 * @brief Declares two NonZero function templates; neither is defined in this
 *        header. Both return an IdArray.
 * @note Presumably the result lists the positions of the non-zero entries of
 *       the input -- TODO confirm.
 * @note NOTE(review): both templates have the same template-parameter list
 *       shape (device type + one type parameter) and differ only in the
 *       function parameter type (BoolArray vs NDArray). If BoolArray is an
 *       alias of NDArray, these declare the same template twice -- confirm
 *       that both are intended.
 */
/** Variant taking a BoolArray; its type parameter is named DType. */
template <DGLDeviceType XPU, typename DType>
Jinjing Zhou's avatar
Jinjing Zhou committed
48
49
IdArray NonZero(BoolArray bool_arr);

50
51
52
/** Variant taking an NDArray; its type parameter is named IdType. */
template <DGLDeviceType XPU, typename IdType>
IdArray NonZero(NDArray array);

53
/**
 * @brief Declares the device-specific implementation of Sort: takes an
 *        IdArray plus an int @p num_bits and returns a pair of IdArrays.
 * @note The type parameter is named DType even though the argument is an
 *       IdArray. Presumably the pair is (sorted values, original positions)
 *       and @p num_bits bounds the significant bits of the keys -- TODO
 *       confirm against the definition.
 */
template <DGLDeviceType XPU, typename DType>
54
std::pair<IdArray, IdArray> Sort(IdArray array, int num_bits);
55

56
/**
 * @brief Declares the device-specific implementation of Scatter: takes a data
 *        NDArray and an IdArray of indices and returns a new NDArray.
 * @note Presumably DType is the element type of @p array and IdType that of
 *       @p indices; the exact placement rule is not visible here -- TODO
 *       confirm.
 */
template <DGLDeviceType XPU, typename DType, typename IdType>
57
58
NDArray Scatter(NDArray array, IdArray indices);

59
/**
 * @brief Declares the device-specific implementation of Scatter_: takes an
 *        index array, a value array and an output array, and returns void.
 * @note Since nothing is returned, the result is presumably written into
 *       @p out at the positions given by @p index -- TODO confirm.
 */
template <DGLDeviceType XPU, typename DType, typename IdType>
60
61
void Scatter_(IdArray index, NDArray value, NDArray out);

62
/**
 * @brief Declares the device-specific implementation of Repeat: takes a data
 *        NDArray and an IdArray @p repeats and returns an NDArray.
 * @note Presumably element i of @p array is emitted repeats[i] times -- TODO
 *       confirm.
 */
template <DGLDeviceType XPU, typename DType, typename IdType>
63
64
NDArray Repeat(NDArray array, IdArray repeats);

65
/**
 * @brief Declares the device-specific implementation of Relabel_: takes a
 *        vector of IdArrays by const reference and returns an IdArray.
 * @note The trailing underscore presumably marks in-place modification of the
 *       arrays' contents, with the returned array describing the id mapping
 *       -- TODO confirm against the definition.
 */
template <DGLDeviceType XPU, typename IdType>
66
67
IdArray Relabel_(const std::vector<IdArray>& arrays);

68
/**
 * @brief Declares the device-specific implementation of Concat: takes a
 *        vector of IdArrays by const reference and returns a single NDArray.
 * @note Presumably the inputs are joined in vector order -- TODO confirm.
 */
template <DGLDeviceType XPU, typename IdType>
69
70
NDArray Concat(const std::vector<IdArray>& arrays);

71
/**
 * @brief Declares the device-specific implementation of Pack: takes an
 *        NDArray and a padding value of type DType and returns a tuple of
 *        (NDArray, IdArray, IdArray).
 * @note Presumably entries equal to @p pad_value are dropped and the two
 *       IdArrays carry per-row lengths and offsets -- TODO confirm.
 */
template <DGLDeviceType XPU, typename DType>
72
73
std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, DType pad_value);

74
/**
 * @brief Declares the device-specific implementation of ConcatSlices: takes
 *        an NDArray and an IdArray @p lengths and returns a pair of
 *        (NDArray, IdArray).
 * @note Presumably a leading slice of each row, sized by @p lengths, is
 *       concatenated and the IdArray holds the resulting offsets -- TODO
 *       confirm.
 */
template <DGLDeviceType XPU, typename DType, typename IdType>
75
76
std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths);

77
/**
 * @brief Declares the device-specific implementation of CumSum: takes an
 *        IdArray and a bool @p prepend_zero and returns an IdArray.
 * @note Presumably the running sum of @p array, with a leading zero entry
 *       added when @p prepend_zero is true -- TODO confirm.
 */
template <DGLDeviceType XPU, typename IdType>
78
79
IdArray CumSum(IdArray array, bool prepend_zero);

80
81
// sparse arrays

82
/**
 * @brief Declares the two CSRIsNonZero overloads; neither is defined in this
 *        header.
 * @note Presumably each answers whether entry (row, col) is present in
 *       @p csr -- TODO confirm.
 */
/** Scalar overload: one (row, col) pair in, one bool out. */
template <DGLDeviceType XPU, typename IdType>
83
84
bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col);

85
/** Batched overload: NDArrays of rows and columns in, one NDArray out. */
template <DGLDeviceType XPU, typename IdType>
86
87
runtime::NDArray CSRIsNonZero(
    CSRMatrix csr, runtime::NDArray row, runtime::NDArray col);
88

89
/**
 * @brief Declares the device-specific implementation of CSRHasDuplicate:
 *        takes a CSRMatrix and returns a bool.
 * @note Presumably true when some (row, col) entry occurs more than once --
 *       TODO confirm.
 */
template <DGLDeviceType XPU, typename IdType>
90
91
bool CSRHasDuplicate(CSRMatrix csr);

92
/**
 * @brief Declares the two CSRGetRowNNZ overloads; neither is defined in this
 *        header.
 * @note Presumably each reports the number of stored entries per requested
 *       row -- TODO confirm.
 */
/** Scalar overload: one row id in, one int64_t out. */
template <DGLDeviceType XPU, typename IdType>
93
94
int64_t CSRGetRowNNZ(CSRMatrix csr, int64_t row);

95
/** Batched overload: an NDArray of row ids in, an NDArray out. */
template <DGLDeviceType XPU, typename IdType>
96
97
runtime::NDArray CSRGetRowNNZ(CSRMatrix csr, runtime::NDArray row);

98
/**
 * @brief Declares the device-specific implementation of
 *        CSRGetRowColumnIndices: takes a CSRMatrix and one row id and returns
 *        an NDArray.
 * @note Presumably the column indices stored for @p row -- TODO confirm.
 */
template <DGLDeviceType XPU, typename IdType>
99
100
runtime::NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row);

101
/**
 * @brief Declares the device-specific implementation of CSRGetRowData: takes
 *        a CSRMatrix and one row id and returns an NDArray.
 * @note Presumably the data entries stored for @p row -- TODO confirm.
 */
template <DGLDeviceType XPU, typename IdType>
102
103
runtime::NDArray CSRGetRowData(CSRMatrix csr, int64_t row);

104
/**
 * @brief Declares the device-specific implementation of CSRIsSorted: takes a
 *        CSRMatrix and returns a bool.
 * @note Presumably true when the column indices within every row are in
 *       sorted order -- TODO confirm.
 */
template <DGLDeviceType XPU, typename IdType>
105
106
bool CSRIsSorted(CSRMatrix csr);

107
/**
 * @brief Declares the six-argument CSRGetData; no definition appears in this
 *        header. The two shorter CSRGetData overloads defined in this header
 *        forward to it.
 * @tparam DType Type of @p filler.
 * @param return_eids Bool flag; the overloads in this header pass false when
 *        weights are supplied and true when they are not.
 * @note Presumably looks up the (rows[i], cols[i]) entries, returning either
 *       their ids or the matching @p weights values, and uses @p filler for
 *       entries that are absent -- TODO confirm against the definition.
 */
template <DGLDeviceType XPU, typename IdType, typename DType>
108
runtime::NDArray CSRGetData(
109
110
    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols,
    bool return_eids, runtime::NDArray weights, DType filler);
111

112
template <DGLDeviceType XPU, typename IdType, typename DType>
113
114
115
runtime::NDArray CSRGetData(
    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols,
    runtime::NDArray weights, DType filler) {
116
117
  return CSRGetData<XPU, IdType, DType>(
      csr, rows, cols, false, weights, filler);
118
119
}

120
template <DGLDeviceType XPU, typename IdType>
121
NDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) {
122
123
  return CSRGetData<XPU, IdType, IdType>(
      csr, rows, cols, true, NullArray(rows->dtype), -1);
124
}
125

126
/**
 * @brief Declares the device-specific implementation of CSRGetDataAndIndices:
 *        takes a CSRMatrix plus NDArrays of rows and columns and returns a
 *        vector of NDArrays.
 * @note The number and meaning of the returned arrays are not visible here;
 *       presumably they hold row ids, column ids and data of the matching
 *       entries -- TODO confirm.
 */
template <DGLDeviceType XPU, typename IdType>
127
128
129
std::vector<runtime::NDArray> CSRGetDataAndIndices(
    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);

130
/**
 * @brief Declares the device-specific implementation of CSRTranspose: takes a
 *        CSRMatrix by value and returns a CSRMatrix.
 * @note Presumably the result is the transposed matrix in CSR form -- TODO
 *       confirm.
 */
template <DGLDeviceType XPU, typename IdType>
131
132
133
CSRMatrix CSRTranspose(CSRMatrix csr);

/**
 * @brief Convert CSR to COO.
 *
 * Declaration only; takes a CSRMatrix by value and returns a COOMatrix.
 */
134
template <DGLDeviceType XPU, typename IdType>
135
136
137
COOMatrix CSRToCOO(CSRMatrix csr);

/**
 * @brief Convert CSR to COO using the data array as the order.
 *
 * Declaration only; takes a CSRMatrix by value and returns a COOMatrix.
 * @note Presumably the COO entries are arranged according to the values in
 *       the CSR data array -- TODO confirm the exact ordering rule.
 */
138
template <DGLDeviceType XPU, typename IdType>
139
140
COOMatrix CSRToCOODataAsOrder(CSRMatrix csr);

141
/**
 * @brief Declares the two CSRSliceRows overloads; neither is defined in this
 *        header. Both return a CSRMatrix.
 */
/**
 * Range overload: rows are selected by @p start and @p end.
 * NOTE(review): whether @p end is inclusive is not visible here -- confirm.
 */
template <DGLDeviceType XPU, typename IdType>
142
143
CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end);

144
/** Array overload: rows are selected by the NDArray @p rows. */
template <DGLDeviceType XPU, typename IdType>
145
146
CSRMatrix CSRSliceRows(CSRMatrix csr, runtime::NDArray rows);

147
/**
 * @brief Declares the device-specific implementation of CSRSliceMatrix: takes
 *        a CSRMatrix plus NDArrays of rows and columns and returns a
 *        CSRMatrix.
 * @note Presumably the sub-matrix restricted to the given rows and columns --
 *       TODO confirm.
 */
template <DGLDeviceType XPU, typename IdType>
148
149
CSRMatrix CSRSliceMatrix(
    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
150

151
/**
 * @brief Declares the device-specific implementation of CSRSort_: takes the
 *        matrix by pointer and returns void, so any result has to be written
 *        through @p csr.
 * @note Presumably sorts the column indices of each row in place -- TODO
 *       confirm.
 */
template <DGLDeviceType XPU, typename IdType>
152
153
void CSRSort_(CSRMatrix* csr);

154
/**
 * @brief Declares the device-specific implementation of CSRSortByTag: takes a
 *        CSRMatrix by const reference, a tag array and a tag count, and
 *        returns a pair of (CSRMatrix, NDArray).
 * @tparam TagType Presumably the element type of @p tag_array -- TODO confirm.
 * @note Presumably entries in each row are grouped by tag and the NDArray
 *       records where each tag group starts -- TODO confirm.
 */
template <DGLDeviceType XPU, typename IdType, typename TagType>
155
std::pair<CSRMatrix, NDArray> CSRSortByTag(
156
    const CSRMatrix& csr, IdArray tag_array, int64_t num_tags);
157

158
/**
 * @brief Declares the device-specific implementation of CSRReorder: takes a
 *        CSRMatrix plus NDArrays of new row ids and new column ids and
 *        returns a CSRMatrix.
 * @note Presumably rows and columns are renumbered according to the two id
 *       arrays -- TODO confirm.
 */
template <DGLDeviceType XPU, typename IdType>
159
160
CSRMatrix CSRReorder(
    CSRMatrix csr, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);
Da Zheng's avatar
Da Zheng committed
161

162
/**
 * @brief Declares the device-specific implementation of COOReorder: takes a
 *        COOMatrix plus NDArrays of new row ids and new column ids and
 *        returns a COOMatrix.
 * @note Presumably rows and columns are renumbered according to the two id
 *       arrays -- TODO confirm.
 */
template <DGLDeviceType XPU, typename IdType>
163
164
COOMatrix COOReorder(
    COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);
Da Zheng's avatar
Da Zheng committed
165

166
/**
 * @brief Declares the device-specific implementation of CSRRemove: takes a
 *        CSRMatrix and an IdArray @p entries and returns a CSRMatrix.
 * @note Presumably the result is @p csr without the listed entries; how
 *       entries are identified is not visible here -- TODO confirm.
 */
template <DGLDeviceType XPU, typename IdType>
167
168
CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries);

169
170
/**
 * @brief Declares the device-specific implementation of CSRLaborSampling;
 *        returns a pair of (COOMatrix, FloatArray).
 * @tparam FloatType Presumably the element type of @p prob -- TODO confirm.
 * @note The meaning of @p importance_sampling, @p random_seed,
 *       @p seed2_contribution and @p NIDs is not visible from this
 *       declaration; see the definition before relying on them.
 */
template <DGLDeviceType XPU, typename IdType, typename FloatType>
std::pair<COOMatrix, FloatArray> CSRLaborSampling(
171
172
    CSRMatrix mat, IdArray rows, int64_t num_samples, FloatArray prob,
    int importance_sampling, IdArray random_seed, float seed2_contribution,
173
174
    IdArray NIDs);

175
/**
 * @brief Declares the device-specific implementation of CSRRowWiseSampling;
 *        returns a COOMatrix.
 * @tparam DType Presumably the element type of @p prob_or_mask (the
 *         probability data); the earlier comment here called this parameter
 *         FloatType, which is not the name used in the template -- TODO
 *         confirm.
 * @note Presumably samples up to @p num_samples entries from each row listed
 *       in @p rows, with or without replacement per @p replace -- TODO
 *       confirm.
 */
176
template <DGLDeviceType XPU, typename IdType, typename DType>
177
COOMatrix CSRRowWiseSampling(
178
179
    CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask,
    bool replace);
180

181
/**
 * @brief Declares the device-specific implementation of
 *        CSRRowWisePerEtypeSampling; returns a COOMatrix.
 * @tparam DType Presumably the element type of the @p prob_or_mask arrays
 *         (the probability data); the earlier comment here called this
 *         parameter FloatType, which is not the name used in the template --
 *         TODO confirm.
 * @note @p eid2etype_offset, @p num_samples and @p prob_or_mask are vectors,
 *       presumably indexed by edge type; the role of
 *       @p rowwise_etype_sorted is not visible here -- TODO confirm.
 */
182
template <DGLDeviceType XPU, typename IdType, typename DType>
183
COOMatrix CSRRowWisePerEtypeSampling(
184
    CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
185
    const std::vector<int64_t>& num_samples,
186
187
    const std::vector<NDArray>& prob_or_mask, bool replace,
    bool rowwise_etype_sorted);
188

189
/**
 * @brief Declares the device-specific implementation of
 *        CSRRowWiseSamplingUniform; returns a COOMatrix. Unlike
 *        CSRRowWiseSampling it takes no probability or mask argument.
 * @note Presumably samples up to @p num_samples entries uniformly from each
 *       row listed in @p rows -- TODO confirm.
 */
template <DGLDeviceType XPU, typename IdType>
190
191
192
COOMatrix CSRRowWiseSamplingUniform(
    CSRMatrix mat, IdArray rows, int64_t num_samples, bool replace);

193
/**
 * @brief Declares the device-specific implementation of
 *        CSRRowWisePerEtypeSamplingUniform; returns a COOMatrix. Unlike
 *        CSRRowWisePerEtypeSampling it takes no probability or mask argument.
 * @note @p eid2etype_offset and @p num_samples are vectors, presumably
 *       indexed by edge type; the role of @p rowwise_etype_sorted is not
 *       visible here -- TODO confirm.
 */
template <DGLDeviceType XPU, typename IdType>
194
COOMatrix CSRRowWisePerEtypeSamplingUniform(
195
196
197
    CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
    const std::vector<int64_t>& num_samples, bool replace,
    bool rowwise_etype_sorted);
198

199
/**
 * @brief Declares the device-specific implementation of CSRRowWiseTopk;
 *        returns a COOMatrix.
 * @tparam DType Presumably the element type of @p weight (the weight data);
 *         the earlier comment here called this parameter FloatType, which is
 *         not the name used in the template -- TODO confirm.
 * @note Presumably keeps the @p k entries per row with the smallest
 *       (@p ascending true) or largest weights -- TODO confirm.
 */
200
template <DGLDeviceType XPU, typename IdType, typename DType>
201
COOMatrix CSRRowWiseTopk(
202
    CSRMatrix mat, IdArray rows, int64_t k, NDArray weight, bool ascending);
203

204
/**
 * @brief Declares the device-specific implementation of
 *        CSRRowWiseSamplingBiased; returns a COOMatrix.
 * @tparam FloatType Presumably the element type of @p bias -- TODO confirm.
 * @note How @p tag_offset and @p bias interact is not visible from this
 *       declaration; presumably @p bias gives one weight per tag and
 *       @p tag_offset delimits the tag groups within each row -- TODO
 *       confirm.
 */
template <DGLDeviceType XPU, typename IdType, typename FloatType>
205
COOMatrix CSRRowWiseSamplingBiased(
206
207
    CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray tag_offset,
    FloatArray bias, bool replace);
208

209
/**
 * @brief Declares the device-specific implementation of
 *        CSRGlobalUniformNegativeSampling: takes the matrix by const
 *        reference and returns a pair of IdArrays.
 * @note Presumably the pair holds the row and column ids of sampled
 *       (row, col) positions that are not entries of @p csr; the exact roles
 *       of @p num_trials and @p redundancy are not visible here -- TODO
 *       confirm.
 */
template <DGLDeviceType XPU, typename IdType>
210
std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
211
212
    const CSRMatrix& csr, int64_t num_samples, int num_trials,
    bool exclude_self_loops, bool replace, double redundancy);
213

214
/**
 * @brief Union of CSR matrices.
 *
 * Declaration only; takes a vector of CSRMatrix by const reference and
 * returns a single CSRMatrix.
 */
215
template <DGLDeviceType XPU, typename IdType>
216
217
CSRMatrix UnionCsr(const std::vector<CSRMatrix>& csrs);

218
/**
 * @brief Declares the device-specific implementation of CSRToSimple: takes a
 *        CSRMatrix and returns a tuple of (CSRMatrix, IdArray, IdArray).
 * @note Presumably duplicate entries are merged and the two IdArrays carry
 *       per-entry counts and a mapping from original entries -- TODO confirm.
 */
template <DGLDeviceType XPU, typename IdType>
219
std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple(CSRMatrix csr);
220

221
////////////////////////////////////////////////////////////////////////////////
Da Zheng's avatar
Da Zheng committed
222

223
/**
 * @brief Declares the two COOIsNonZero overloads; neither is defined in this
 *        header.
 * @note Presumably each answers whether entry (row, col) is present in
 *       @p coo -- TODO confirm.
 */
/** Scalar overload: one (row, col) pair in, one bool out. */
template <DGLDeviceType XPU, typename IdType>
224
225
bool COOIsNonZero(COOMatrix coo, int64_t row, int64_t col);

226
/** Batched overload: NDArrays of rows and columns in, one NDArray out. */
template <DGLDeviceType XPU, typename IdType>
227
228
runtime::NDArray COOIsNonZero(
    COOMatrix coo, runtime::NDArray row, runtime::NDArray col);
229

230
/**
 * @brief Declares the device-specific implementation of COOHasDuplicate:
 *        takes a COOMatrix and returns a bool.
 * @note Presumably true when some (row, col) entry occurs more than once --
 *       TODO confirm.
 */
template <DGLDeviceType XPU, typename IdType>
231
232
bool COOHasDuplicate(COOMatrix coo);

233
/**
 * @brief Declares the two COOGetRowNNZ overloads; neither is defined in this
 *        header.
 * @note Presumably each reports the number of stored entries per requested
 *       row -- TODO confirm.
 */
/** Scalar overload: one row id in, one int64_t out. */
template <DGLDeviceType XPU, typename IdType>
234
235
int64_t COOGetRowNNZ(COOMatrix coo, int64_t row);

236
/** Batched overload: an NDArray of row ids in, an NDArray out. */
template <DGLDeviceType XPU, typename IdType>
237
238
runtime::NDArray COOGetRowNNZ(COOMatrix coo, runtime::NDArray row);

239
/**
 * @brief Declares the device-specific implementation of
 *        COOGetRowDataAndIndices: takes a COOMatrix and one row id and
 *        returns a pair of NDArrays.
 * @note Going by the function name the pair is presumably (data, column
 *       indices) for @p row; the order is not visible here -- TODO confirm.
 */
template <DGLDeviceType XPU, typename IdType>
240
241
std::pair<runtime::NDArray, runtime::NDArray> COOGetRowDataAndIndices(
    COOMatrix coo, int64_t row);
242

243
/**
 * @brief Declares the device-specific implementation of COOGetDataAndIndices:
 *        takes a COOMatrix plus NDArrays of rows and columns and returns a
 *        vector of NDArrays.
 * @note The number and meaning of the returned arrays are not visible here;
 *       presumably they hold row ids, column ids and data of the matching
 *       entries -- TODO confirm.
 */
template <DGLDeviceType XPU, typename IdType>
244
245
246
std::vector<runtime::NDArray> COOGetDataAndIndices(
    COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);

247
/**
 * @brief Declares the device-specific implementation of COOGetData: takes a
 *        COOMatrix plus NDArrays of rows and columns and returns an NDArray.
 * @note Presumably the data value of each (rows[i], cols[i]) entry; what is
 *       returned for absent entries is not visible here -- TODO confirm.
 */
template <DGLDeviceType XPU, typename IdType>
248
249
runtime::NDArray COOGetData(
    COOMatrix mat, runtime::NDArray rows, runtime::NDArray cols);
250

251
/**
 * @brief Declares the device-specific implementation of COOTranspose: takes a
 *        COOMatrix by value and returns a COOMatrix.
 * @note Presumably the result is the transposed matrix -- TODO confirm.
 */
template <DGLDeviceType XPU, typename IdType>
252
253
COOMatrix COOTranspose(COOMatrix coo);

254
/**
 * @brief Declares the device-specific implementation of COOToCSR: takes a
 *        COOMatrix by value and returns a CSRMatrix.
 * @note Presumably the same matrix converted from COO to CSR form -- TODO
 *       confirm.
 */
template <DGLDeviceType XPU, typename IdType>
255
256
CSRMatrix COOToCSR(COOMatrix coo);

257
/**
 * @brief Declares the two COOSliceRows overloads; neither is defined in this
 *        header. Both return a COOMatrix.
 */
/**
 * Range overload: rows are selected by @p start and @p end.
 * NOTE(review): whether @p end is inclusive is not visible here -- confirm.
 */
template <DGLDeviceType XPU, typename IdType>
258
259
COOMatrix COOSliceRows(COOMatrix coo, int64_t start, int64_t end);

260
/** Array overload: rows are selected by the NDArray @p rows. */
template <DGLDeviceType XPU, typename IdType>
261
262
COOMatrix COOSliceRows(COOMatrix coo, runtime::NDArray rows);

263
/**
 * @brief Declares the device-specific implementation of COOSliceMatrix: takes
 *        a COOMatrix plus NDArrays of rows and columns and returns a
 *        COOMatrix.
 * @note Presumably the sub-matrix restricted to the given rows and columns --
 *       TODO confirm.
 */
template <DGLDeviceType XPU, typename IdType>
264
265
COOMatrix COOSliceMatrix(
    COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);
266

267
/**
 * @brief Declares the device-specific implementation of COOCoalesce: takes a
 *        COOMatrix and returns a pair of (COOMatrix, IdArray).
 * @note Presumably duplicate entries are merged and the IdArray carries a
 *       per-entry multiplicity -- TODO confirm.
 */
template <DGLDeviceType XPU, typename IdType>
268
269
std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo);

270
/**
 * @brief Declares the device-specific implementation of DisjointUnionCoo:
 *        takes a vector of COOMatrix by const reference and returns a single
 *        COOMatrix.
 * @note Presumably the inputs are combined with their row/column id ranges
 *       kept disjoint -- TODO confirm.
 */
template <DGLDeviceType XPU, typename IdType>
271
272
COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos);

273
/**
 * @brief Declares the device-specific implementation of COOSort_: takes the
 *        matrix by pointer and returns void, so any result has to be written
 *        through @p mat.
 * @note Presumably sorts entries by row in place, and additionally by column
 *       when @p sort_column is true -- TODO confirm.
 */
template <DGLDeviceType XPU, typename IdType>
274
275
void COOSort_(COOMatrix* mat, bool sort_column);

276
/**
 * @brief Declares the device-specific implementation of COOIsSorted: takes a
 *        COOMatrix and returns a pair of bools.
 * @note Presumably the pair is (rows sorted, columns sorted) -- TODO confirm
 *       the order and exact meaning against the definition.
 */
template <DGLDeviceType XPU, typename IdType>
277
std::pair<bool, bool> COOIsSorted(COOMatrix coo);
278

279
/**
 * @brief Declares the device-specific implementation of COORemove: takes a
 *        COOMatrix and an IdArray @p entries and returns a COOMatrix.
 * @note Presumably the result is @p coo without the listed entries; how
 *       entries are identified is not visible here -- TODO confirm.
 */
template <DGLDeviceType XPU, typename IdType>
280
281
COOMatrix COORemove(COOMatrix coo, IdArray entries);

282
283
/**
 * @brief Declares the device-specific implementation of COOLaborSampling;
 *        returns a pair of (COOMatrix, FloatArray).
 * @tparam FloatType Presumably the element type of @p prob -- TODO confirm.
 * @note The meaning of @p importance_sampling, @p random_seed,
 *       @p seed2_contribution and @p NIDs is not visible from this
 *       declaration; see the definition before relying on them.
 */
template <DGLDeviceType XPU, typename IdType, typename FloatType>
std::pair<COOMatrix, FloatArray> COOLaborSampling(
284
285
    COOMatrix mat, IdArray rows, int64_t num_samples, FloatArray prob,
    int importance_sampling, IdArray random_seed, float seed2_contribution,
286
287
    IdArray NIDs);

288
/**
 * @brief Declares the device-specific implementation of COORowWiseSampling;
 *        returns a COOMatrix.
 * @tparam DType Presumably the element type of @p prob_or_mask (the
 *         probability data); the earlier comment here called this parameter
 *         FloatType, which is not the name used in the template -- TODO
 *         confirm.
 * @note Presumably samples up to @p num_samples entries from each row listed
 *       in @p rows, with or without replacement per @p replace -- TODO
 *       confirm.
 */
289
template <DGLDeviceType XPU, typename IdType, typename DType>
290
COOMatrix COORowWiseSampling(
291
292
    COOMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask,
    bool replace);
293

294
/**
 * @brief Declares the device-specific implementation of
 *        COORowWisePerEtypeSampling; returns a COOMatrix.
 * @tparam DType Presumably the element type of the @p prob_or_mask arrays
 *         (the probability data); the earlier comment here called this
 *         parameter FloatType, which is not the name used in the template --
 *         TODO confirm.
 * @note @p eid2etype_offset, @p num_samples and @p prob_or_mask are vectors,
 *       presumably indexed by edge type -- TODO confirm. Unlike the CSR
 *       counterpart declared in this header, there is no
 *       `rowwise_etype_sorted` parameter.
 */
295
template <DGLDeviceType XPU, typename IdType, typename DType>
296
COOMatrix COORowWisePerEtypeSampling(
297
    COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
298
299
    const std::vector<int64_t>& num_samples,
    const std::vector<NDArray>& prob_or_mask, bool replace);
300

301
/**
 * @brief Declares the device-specific implementation of
 *        COORowWiseSamplingUniform; returns a COOMatrix. Unlike
 *        COORowWiseSampling it takes no probability or mask argument.
 * @note Presumably samples up to @p num_samples entries uniformly from each
 *       row listed in @p rows -- TODO confirm.
 */
template <DGLDeviceType XPU, typename IdType>
302
303
304
COOMatrix COORowWiseSamplingUniform(
    COOMatrix mat, IdArray rows, int64_t num_samples, bool replace);

305
/**
 * @brief Declares the device-specific implementation of
 *        COORowWisePerEtypeSamplingUniform; returns a COOMatrix. Unlike
 *        COORowWisePerEtypeSampling it takes no probability or mask argument.
 * @note @p eid2etype_offset and @p num_samples are vectors, presumably
 *       indexed by edge type -- TODO confirm.
 */
template <DGLDeviceType XPU, typename IdType>
306
COOMatrix COORowWisePerEtypeSamplingUniform(
307
    COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
308
    const std::vector<int64_t>& num_samples, bool replace);
309

310
/**
 * @brief Declares the device-specific implementation of COORowWiseTopk;
 *        returns a COOMatrix.
 * @tparam FloatType The type of the weight data.
 * @note Presumably keeps the @p k entries per row with the smallest
 *       (@p ascending true) or largest weights -- TODO confirm.
 */
311
template <DGLDeviceType XPU, typename IdType, typename FloatType>
312
313
COOMatrix COORowWiseTopk(
    COOMatrix mat, IdArray rows, int64_t k, FloatArray weight, bool ascending);
314

315
316
///////////////////////// Graph Traverse routines //////////////////////////

317
/**
 * @brief Declares the device-specific implementation of BFSNodesFrontiers:
 *        takes a CSRMatrix by const reference and an IdArray @p source and
 *        returns a Frontiers value.
 * @note Presumably a breadth-first traversal from the @p source nodes that
 *       reports the nodes of each frontier -- TODO confirm.
 */
template <DGLDeviceType XPU, typename IdType>
318
319
Frontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source);

320
/**
 * @brief Declares the device-specific implementation of BFSEdgesFrontiers:
 *        takes a CSRMatrix by const reference and an IdArray @p source and
 *        returns a Frontiers value.
 * @note Presumably a breadth-first traversal from the @p source nodes that
 *       reports the edges of each frontier -- TODO confirm.
 */
template <DGLDeviceType XPU, typename IdType>
321
322
Frontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source);

323
/**
 * @brief Declares the device-specific implementation of
 *        TopologicalNodesFrontiers: takes a CSRMatrix by const reference and
 *        returns a Frontiers value. No source argument is taken.
 * @note Presumably groups nodes into frontiers following a topological order
 *       of the graph -- TODO confirm.
 */
template <DGLDeviceType XPU, typename IdType>
324
325
Frontiers TopologicalNodesFrontiers(const CSRMatrix& csr);

326
/**
 * @brief Declares the device-specific implementation of DGLDFSEdges: takes a
 *        CSRMatrix by const reference and an IdArray @p source and returns a
 *        Frontiers value.
 * @note Presumably a depth-first traversal from the @p source nodes that
 *       reports the visited edges -- TODO confirm.
 */
template <DGLDeviceType XPU, typename IdType>
327
328
Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source);

329
/**
 * @brief Declares the device-specific implementation of DGLDFSLabeledEdges:
 *        takes a CSRMatrix by const reference, an IdArray @p source and three
 *        bool flags, and returns a Frontiers value.
 * @note Presumably a depth-first traversal in which @p has_reverse_edge and
 *       @p has_nontree_edge choose which edge kinds are reported and
 *       @p return_labels requests per-edge labels -- TODO confirm against the
 *       definition.
 */
template <DGLDeviceType XPU, typename IdType>
330
331
332
Frontiers DGLDFSLabeledEdges(
    const CSRMatrix& csr, IdArray source, const bool has_reverse_edge,
    const bool has_nontree_edge, const bool return_labels);
333

334
/**
 * @brief Declares the device-specific implementation of COOLineGraph: takes a
 *        COOMatrix by const reference and a bool @p backtracking and returns
 *        a COOMatrix.
 * @note Presumably builds the line graph of @p coo, with @p backtracking
 *       controlling whether an edge is linked to its reverse edge -- TODO
 *       confirm.
 */
template <DGLDeviceType XPU, typename IdType>
335
COOMatrix COOLineGraph(const COOMatrix& coo, bool backtracking);
336

337
338
339
340
341
}  // namespace impl
}  // namespace aten
}  // namespace dgl

#endif  // DGL_ARRAY_ARRAY_OP_H_