/**
 *  Copyright (c) 2019-2022 by Contributors
 * @file array/array_op.h
 * @brief Array operator templates (dense arrays, CSR/COO sparse matrices and
 *        graph traversal), declared per device type.
 */
#ifndef DGL_ARRAY_ARRAY_OP_H_
#define DGL_ARRAY_ARRAY_OP_H_

#include <dgl/array.h>
#include <dgl/graph_traversal.h>

#include <cstdint>
#include <tuple>
#include <utility>
#include <vector>

namespace dgl {
namespace aten {
namespace impl {

20
21
template <DGLDeviceType XPU, typename IdType>
IdArray Full(IdType val, int64_t length, DGLContext ctx);
22

23
24
template <DGLDeviceType XPU, typename IdType>
IdArray Range(IdType low, IdType high, DGLContext ctx);
25

26
template <DGLDeviceType XPU, typename IdType>
27
28
IdArray AsNumBits(IdArray arr, uint8_t bits);

29
template <DGLDeviceType XPU, typename IdType, typename Op>
30
31
IdArray BinaryElewise(IdArray lhs, IdArray rhs);

32
template <DGLDeviceType XPU, typename IdType, typename Op>
33
34
IdArray BinaryElewise(IdArray lhs, IdType rhs);

35
template <DGLDeviceType XPU, typename IdType, typename Op>
36
37
IdArray BinaryElewise(IdType lhs, IdArray rhs);

38
template <DGLDeviceType XPU, typename IdType, typename Op>
39
40
IdArray UnaryElewise(IdArray array);

41
template <DGLDeviceType XPU, typename DType, typename IdType>
42
NDArray IndexSelect(NDArray array, IdArray index);
43

44
template <DGLDeviceType XPU, typename DType>
45
DType IndexSelect(NDArray array, int64_t index);
46

47
template <DGLDeviceType XPU, typename DType>
Jinjing Zhou's avatar
Jinjing Zhou committed
48
49
IdArray NonZero(BoolArray bool_arr);

50
51
52
template <DGLDeviceType XPU, typename IdType>
IdArray NonZero(NDArray array);

53
template <DGLDeviceType XPU, typename DType>
54
std::pair<IdArray, IdArray> Sort(IdArray array, int num_bits);
55

56
template <DGLDeviceType XPU, typename DType, typename IdType>
57
58
NDArray Scatter(NDArray array, IdArray indices);

59
template <DGLDeviceType XPU, typename DType, typename IdType>
60
61
void Scatter_(IdArray index, NDArray value, NDArray out);

62
template <DGLDeviceType XPU, typename DType, typename IdType>
63
64
NDArray Repeat(NDArray array, IdArray repeats);

65
template <DGLDeviceType XPU, typename IdType>
66
67
IdArray Relabel_(const std::vector<IdArray>& arrays);

68
template <DGLDeviceType XPU, typename IdType>
69
70
NDArray Concat(const std::vector<IdArray>& arrays);

71
template <DGLDeviceType XPU, typename DType>
72
73
std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, DType pad_value);

74
template <DGLDeviceType XPU, typename DType, typename IdType>
75
76
std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths);

77
template <DGLDeviceType XPU, typename IdType>
78
79
IdArray CumSum(IdArray array, bool prepend_zero);

80
81
// sparse arrays

82
template <DGLDeviceType XPU, typename IdType>
83
84
bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col);

85
template <DGLDeviceType XPU, typename IdType>
86
87
runtime::NDArray CSRIsNonZero(
    CSRMatrix csr, runtime::NDArray row, runtime::NDArray col);
88

89
template <DGLDeviceType XPU, typename IdType>
90
91
bool CSRHasDuplicate(CSRMatrix csr);

92
template <DGLDeviceType XPU, typename IdType>
93
94
int64_t CSRGetRowNNZ(CSRMatrix csr, int64_t row);

95
template <DGLDeviceType XPU, typename IdType>
96
97
runtime::NDArray CSRGetRowNNZ(CSRMatrix csr, runtime::NDArray row);

98
template <DGLDeviceType XPU, typename IdType>
99
100
runtime::NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row);

101
template <DGLDeviceType XPU, typename IdType>
102
103
runtime::NDArray CSRGetRowData(CSRMatrix csr, int64_t row);

104
template <DGLDeviceType XPU, typename IdType>
105
106
bool CSRIsSorted(CSRMatrix csr);

107
template <DGLDeviceType XPU, typename IdType, typename DType>
108
runtime::NDArray CSRGetData(
109
110
    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols,
    bool return_eids, runtime::NDArray weights, DType filler);
111

112
template <DGLDeviceType XPU, typename IdType, typename DType>
113
114
115
runtime::NDArray CSRGetData(
    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols,
    runtime::NDArray weights, DType filler) {
116
117
  return CSRGetData<XPU, IdType, DType>(
      csr, rows, cols, false, weights, filler);
118
119
}

120
template <DGLDeviceType XPU, typename IdType>
121
NDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) {
122
123
  return CSRGetData<XPU, IdType, IdType>(
      csr, rows, cols, true, NullArray(rows->dtype), -1);
124
}
125

126
template <DGLDeviceType XPU, typename IdType>
127
128
129
std::vector<runtime::NDArray> CSRGetDataAndIndices(
    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);

130
template <DGLDeviceType XPU, typename IdType>
131
132
133
CSRMatrix CSRTranspose(CSRMatrix csr);

// Convert CSR to COO
134
template <DGLDeviceType XPU, typename IdType>
135
136
137
COOMatrix CSRToCOO(CSRMatrix csr);

// Convert CSR to COO using data array as order
138
template <DGLDeviceType XPU, typename IdType>
139
140
COOMatrix CSRToCOODataAsOrder(CSRMatrix csr);

141
template <DGLDeviceType XPU, typename IdType>
142
143
CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end);

144
template <DGLDeviceType XPU, typename IdType>
145
146
CSRMatrix CSRSliceRows(CSRMatrix csr, runtime::NDArray rows);

147
template <DGLDeviceType XPU, typename IdType>
148
149
CSRMatrix CSRSliceMatrix(
    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
150

151
template <DGLDeviceType XPU, typename IdType>
152
153
void CSRSort_(CSRMatrix* csr);

154
template <DGLDeviceType XPU, typename IdType, typename TagType>
155
std::pair<CSRMatrix, NDArray> CSRSortByTag(
156
    const CSRMatrix& csr, IdArray tag_array, int64_t num_tags);
157

158
template <DGLDeviceType XPU, typename IdType>
159
160
CSRMatrix CSRReorder(
    CSRMatrix csr, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);
Da Zheng's avatar
Da Zheng committed
161

162
template <DGLDeviceType XPU, typename IdType>
163
164
COOMatrix COOReorder(
    COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);
Da Zheng's avatar
Da Zheng committed
165

166
template <DGLDeviceType XPU, typename IdType>
167
168
CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries);

169
170
171
172
173
174
175
176
177
178
template <DGLDeviceType XPU, typename IdType, typename FloatType>
std::pair<COOMatrix, FloatArray> CSRLaborSampling(
    CSRMatrix mat,
    IdArray rows,
    int64_t num_samples,
    FloatArray prob,
    int importance_sampling,
    IdArray random_seed,
    IdArray NIDs);

179
// FloatType is the type of probability data.
180
template <DGLDeviceType XPU, typename IdType, typename DType>
181
COOMatrix CSRRowWiseSampling(
182
183
    CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask,
    bool replace);
184

185
// FloatType is the type of probability data.
186
template <DGLDeviceType XPU, typename IdType, typename DType>
187
COOMatrix CSRRowWisePerEtypeSampling(
188
    CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
189
    const std::vector<int64_t>& num_samples,
190
191
    const std::vector<NDArray>& prob_or_mask, bool replace,
    bool rowwise_etype_sorted);
192

193
template <DGLDeviceType XPU, typename IdType>
194
195
196
COOMatrix CSRRowWiseSamplingUniform(
    CSRMatrix mat, IdArray rows, int64_t num_samples, bool replace);

197
template <DGLDeviceType XPU, typename IdType>
198
COOMatrix CSRRowWisePerEtypeSamplingUniform(
199
200
201
    CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
    const std::vector<int64_t>& num_samples, bool replace,
    bool rowwise_etype_sorted);
202

203
// FloatType is the type of weight data.
204
template <DGLDeviceType XPU, typename IdType, typename DType>
205
COOMatrix CSRRowWiseTopk(
206
    CSRMatrix mat, IdArray rows, int64_t k, NDArray weight, bool ascending);
207

208
template <DGLDeviceType XPU, typename IdType, typename FloatType>
209
COOMatrix CSRRowWiseSamplingBiased(
210
211
    CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray tag_offset,
    FloatArray bias, bool replace);
212

213
template <DGLDeviceType XPU, typename IdType>
214
std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
215
216
    const CSRMatrix& csr, int64_t num_samples, int num_trials,
    bool exclude_self_loops, bool replace, double redundancy);
217

218
// Union CSRMatrixes
219
template <DGLDeviceType XPU, typename IdType>
220
221
CSRMatrix UnionCsr(const std::vector<CSRMatrix>& csrs);

222
template <DGLDeviceType XPU, typename IdType>
223
std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple(CSRMatrix csr);
224

225
////////////////////////////////////////////////////////////////////////////////
Da Zheng's avatar
Da Zheng committed
226

227
template <DGLDeviceType XPU, typename IdType>
228
229
bool COOIsNonZero(COOMatrix coo, int64_t row, int64_t col);

230
template <DGLDeviceType XPU, typename IdType>
231
232
runtime::NDArray COOIsNonZero(
    COOMatrix coo, runtime::NDArray row, runtime::NDArray col);
233

234
template <DGLDeviceType XPU, typename IdType>
235
236
bool COOHasDuplicate(COOMatrix coo);

237
template <DGLDeviceType XPU, typename IdType>
238
239
int64_t COOGetRowNNZ(COOMatrix coo, int64_t row);

240
template <DGLDeviceType XPU, typename IdType>
241
242
runtime::NDArray COOGetRowNNZ(COOMatrix coo, runtime::NDArray row);

243
template <DGLDeviceType XPU, typename IdType>
244
245
std::pair<runtime::NDArray, runtime::NDArray> COOGetRowDataAndIndices(
    COOMatrix coo, int64_t row);
246

247
template <DGLDeviceType XPU, typename IdType>
248
249
250
std::vector<runtime::NDArray> COOGetDataAndIndices(
    COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);

251
template <DGLDeviceType XPU, typename IdType>
252
253
runtime::NDArray COOGetData(
    COOMatrix mat, runtime::NDArray rows, runtime::NDArray cols);
254

255
template <DGLDeviceType XPU, typename IdType>
256
257
COOMatrix COOTranspose(COOMatrix coo);

258
template <DGLDeviceType XPU, typename IdType>
259
260
CSRMatrix COOToCSR(COOMatrix coo);

261
template <DGLDeviceType XPU, typename IdType>
262
263
COOMatrix COOSliceRows(COOMatrix coo, int64_t start, int64_t end);

264
template <DGLDeviceType XPU, typename IdType>
265
266
COOMatrix COOSliceRows(COOMatrix coo, runtime::NDArray rows);

267
template <DGLDeviceType XPU, typename IdType>
268
269
COOMatrix COOSliceMatrix(
    COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);
270

271
template <DGLDeviceType XPU, typename IdType>
272
273
std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo);

274
template <DGLDeviceType XPU, typename IdType>
275
276
COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos);

277
template <DGLDeviceType XPU, typename IdType>
278
279
void COOSort_(COOMatrix* mat, bool sort_column);

280
template <DGLDeviceType XPU, typename IdType>
281
std::pair<bool, bool> COOIsSorted(COOMatrix coo);
282

283
template <DGLDeviceType XPU, typename IdType>
284
285
COOMatrix COORemove(COOMatrix coo, IdArray entries);

286
287
288
289
290
291
292
293
294
295
template <DGLDeviceType XPU, typename IdType, typename FloatType>
std::pair<COOMatrix, FloatArray> COOLaborSampling(
    COOMatrix mat,
    IdArray rows,
    int64_t num_samples,
    FloatArray prob,
    int importance_sampling,
    IdArray random_seed,
    IdArray NIDs);

296
// FloatType is the type of probability data.
297
template <DGLDeviceType XPU, typename IdType, typename DType>
298
COOMatrix COORowWiseSampling(
299
300
    COOMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask,
    bool replace);
301

302
// FloatType is the type of probability data.
303
template <DGLDeviceType XPU, typename IdType, typename DType>
304
COOMatrix COORowWisePerEtypeSampling(
305
    COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
306
307
    const std::vector<int64_t>& num_samples,
    const std::vector<NDArray>& prob_or_mask, bool replace);
308

309
template <DGLDeviceType XPU, typename IdType>
310
311
312
COOMatrix COORowWiseSamplingUniform(
    COOMatrix mat, IdArray rows, int64_t num_samples, bool replace);

313
template <DGLDeviceType XPU, typename IdType>
314
COOMatrix COORowWisePerEtypeSamplingUniform(
315
    COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
316
    const std::vector<int64_t>& num_samples, bool replace);
317

318
// FloatType is the type of weight data.
319
template <DGLDeviceType XPU, typename IdType, typename FloatType>
320
321
COOMatrix COORowWiseTopk(
    COOMatrix mat, IdArray rows, int64_t k, FloatArray weight, bool ascending);
322

323
324
///////////////////////// Graph Traverse routines //////////////////////////

325
template <DGLDeviceType XPU, typename IdType>
326
327
Frontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source);

328
template <DGLDeviceType XPU, typename IdType>
329
330
Frontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source);

331
template <DGLDeviceType XPU, typename IdType>
332
333
Frontiers TopologicalNodesFrontiers(const CSRMatrix& csr);

334
template <DGLDeviceType XPU, typename IdType>
335
336
Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source);

337
template <DGLDeviceType XPU, typename IdType>
338
339
340
Frontiers DGLDFSLabeledEdges(
    const CSRMatrix& csr, IdArray source, const bool has_reverse_edge,
    const bool has_nontree_edge, const bool return_labels);
341

342
template <DGLDeviceType XPU, typename IdType>
343
COOMatrix COOLineGraph(const COOMatrix& coo, bool backtracking);
344

345
346
347
348
349
}  // namespace impl
}  // namespace aten
}  // namespace dgl

#endif  // DGL_ARRAY_ARRAY_OP_H_