/**
 *  Copyright (c) 2019-2022 by Contributors
 * @file array/array_op.h
 * @brief Array operator templates
 */
#ifndef DGL_ARRAY_ARRAY_OP_H_
#define DGL_ARRAY_ARRAY_OP_H_

#include <dgl/array.h>
#include <dgl/graph_traversal.h>

#include <tuple>
#include <utility>
#include <vector>

namespace dgl {
namespace aten {
namespace impl {

20
21
template <DGLDeviceType XPU, typename IdType>
IdArray Full(IdType val, int64_t length, DGLContext ctx);
22

23
24
template <DGLDeviceType XPU, typename IdType>
IdArray Range(IdType low, IdType high, DGLContext ctx);
25

26
template <DGLDeviceType XPU, typename IdType>
27
28
IdArray AsNumBits(IdArray arr, uint8_t bits);

29
template <DGLDeviceType XPU, typename IdType, typename Op>
30
31
IdArray BinaryElewise(IdArray lhs, IdArray rhs);

32
template <DGLDeviceType XPU, typename IdType, typename Op>
33
34
IdArray BinaryElewise(IdArray lhs, IdType rhs);

35
template <DGLDeviceType XPU, typename IdType, typename Op>
36
37
IdArray BinaryElewise(IdType lhs, IdArray rhs);

38
template <DGLDeviceType XPU, typename IdType, typename Op>
39
40
IdArray UnaryElewise(IdArray array);

41
template <DGLDeviceType XPU, typename DType, typename IdType>
42
NDArray IndexSelect(NDArray array, IdArray index);
43

44
template <DGLDeviceType XPU, typename DType>
45
DType IndexSelect(NDArray array, int64_t index);
46

47
template <DGLDeviceType XPU, typename DType>
Jinjing Zhou's avatar
Jinjing Zhou committed
48
49
IdArray NonZero(BoolArray bool_arr);

50
51
52
template <DGLDeviceType XPU, typename IdType>
IdArray NonZero(NDArray array);

53
template <DGLDeviceType XPU, typename DType>
54
std::pair<IdArray, IdArray> Sort(IdArray array, int num_bits);
55

56
template <DGLDeviceType XPU, typename DType, typename IdType>
57
58
NDArray Scatter(NDArray array, IdArray indices);

59
template <DGLDeviceType XPU, typename DType, typename IdType>
60
61
void Scatter_(IdArray index, NDArray value, NDArray out);

62
template <DGLDeviceType XPU, typename DType, typename IdType>
63
64
NDArray Repeat(NDArray array, IdArray repeats);

65
template <DGLDeviceType XPU, typename IdType>
66
67
IdArray Relabel_(const std::vector<IdArray>& arrays);

68
template <DGLDeviceType XPU, typename IdType>
69
70
NDArray Concat(const std::vector<IdArray>& arrays);

71
template <DGLDeviceType XPU, typename DType>
72
73
std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, DType pad_value);

74
template <DGLDeviceType XPU, typename DType, typename IdType>
75
76
std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths);

77
template <DGLDeviceType XPU, typename IdType>
78
79
IdArray CumSum(IdArray array, bool prepend_zero);

80
81
// sparse arrays

82
template <DGLDeviceType XPU, typename IdType>
83
84
bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col);

85
template <DGLDeviceType XPU, typename IdType>
86
87
runtime::NDArray CSRIsNonZero(
    CSRMatrix csr, runtime::NDArray row, runtime::NDArray col);
88

89
template <DGLDeviceType XPU, typename IdType>
90
91
bool CSRHasDuplicate(CSRMatrix csr);

92
template <DGLDeviceType XPU, typename IdType>
93
94
int64_t CSRGetRowNNZ(CSRMatrix csr, int64_t row);

95
template <DGLDeviceType XPU, typename IdType>
96
97
runtime::NDArray CSRGetRowNNZ(CSRMatrix csr, runtime::NDArray row);

98
template <DGLDeviceType XPU, typename IdType>
99
100
runtime::NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row);

101
template <DGLDeviceType XPU, typename IdType>
102
103
runtime::NDArray CSRGetRowData(CSRMatrix csr, int64_t row);

104
template <DGLDeviceType XPU, typename IdType>
105
106
bool CSRIsSorted(CSRMatrix csr);

107
template <DGLDeviceType XPU, typename IdType, typename DType>
108
runtime::NDArray CSRGetData(
109
110
    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols,
    bool return_eids, runtime::NDArray weights, DType filler);
111

112
template <DGLDeviceType XPU, typename IdType, typename DType>
113
114
115
runtime::NDArray CSRGetData(
    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols,
    runtime::NDArray weights, DType filler) {
116
117
  return CSRGetData<XPU, IdType, DType>(
      csr, rows, cols, false, weights, filler);
118
119
}

120
template <DGLDeviceType XPU, typename IdType>
121
NDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) {
122
123
  return CSRGetData<XPU, IdType, IdType>(
      csr, rows, cols, true, NullArray(rows->dtype), -1);
124
}
125

126
template <DGLDeviceType XPU, typename IdType>
127
128
129
std::vector<runtime::NDArray> CSRGetDataAndIndices(
    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);

130
template <DGLDeviceType XPU, typename IdType>
131
132
133
CSRMatrix CSRTranspose(CSRMatrix csr);

// Convert CSR to COO
134
template <DGLDeviceType XPU, typename IdType>
135
136
137
COOMatrix CSRToCOO(CSRMatrix csr);

// Convert CSR to COO using data array as order
138
template <DGLDeviceType XPU, typename IdType>
139
140
COOMatrix CSRToCOODataAsOrder(CSRMatrix csr);

141
template <DGLDeviceType XPU, typename IdType>
142
143
CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end);

144
template <DGLDeviceType XPU, typename IdType>
145
146
CSRMatrix CSRSliceRows(CSRMatrix csr, runtime::NDArray rows);

147
template <DGLDeviceType XPU, typename IdType>
148
149
CSRMatrix CSRSliceMatrix(
    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
150

151
template <DGLDeviceType XPU, typename IdType>
152
153
void CSRSort_(CSRMatrix* csr);

154
template <DGLDeviceType XPU, typename IdType, typename TagType>
155
std::pair<CSRMatrix, NDArray> CSRSortByTag(
156
    const CSRMatrix& csr, IdArray tag_array, int64_t num_tags);
157

158
template <DGLDeviceType XPU, typename IdType>
159
160
CSRMatrix CSRReorder(
    CSRMatrix csr, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);
Da Zheng's avatar
Da Zheng committed
161

162
template <DGLDeviceType XPU, typename IdType>
163
164
COOMatrix COOReorder(
    COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);
Da Zheng's avatar
Da Zheng committed
165

166
template <DGLDeviceType XPU, typename IdType>
167
168
CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries);

169
170
template <DGLDeviceType XPU, typename IdType, typename FloatType>
std::pair<COOMatrix, FloatArray> CSRLaborSampling(
171
172
    CSRMatrix mat, IdArray rows, int64_t num_samples, FloatArray prob,
    int importance_sampling, IdArray random_seed, float seed2_contribution,
173
174
    IdArray NIDs);

175
// FloatType is the type of probability data.
176
template <DGLDeviceType XPU, typename IdType, typename DType>
177
COOMatrix CSRRowWiseSampling(
178
179
    CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask,
    bool replace);
180

181
182
183
184
185
186
187
188
// FloatType is the type of probability data.
template <
    DGLDeviceType XPU, typename IdxType, typename DType, bool map_seed_nodes>
std::pair<CSRMatrix, IdArray> CSRRowWiseSamplingFused(
    CSRMatrix mat, IdArray rows, IdArray seed_mapping,
    std::vector<IdxType>* new_seed_nodes, int64_t num_samples,
    NDArray prob_or_mask, bool replace);

189
// FloatType is the type of probability data.
190
template <DGLDeviceType XPU, typename IdType, typename DType>
191
COOMatrix CSRRowWisePerEtypeSampling(
192
    CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
193
    const std::vector<int64_t>& num_samples,
194
195
    const std::vector<NDArray>& prob_or_mask, bool replace,
    bool rowwise_etype_sorted);
196

197
template <DGLDeviceType XPU, typename IdType>
198
199
200
COOMatrix CSRRowWiseSamplingUniform(
    CSRMatrix mat, IdArray rows, int64_t num_samples, bool replace);

201
202
203
204
205
template <DGLDeviceType XPU, typename IdType, bool map_seed_nodes>
std::pair<CSRMatrix, IdArray> CSRRowWiseSamplingUniformFused(
    CSRMatrix mat, IdArray rows, IdArray seed_mapping,
    std::vector<IdType>* new_seed_nodes, int64_t num_samples, bool replace);

206
template <DGLDeviceType XPU, typename IdType>
207
COOMatrix CSRRowWisePerEtypeSamplingUniform(
208
209
210
    CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
    const std::vector<int64_t>& num_samples, bool replace,
    bool rowwise_etype_sorted);
211

212
// FloatType is the type of weight data.
213
template <DGLDeviceType XPU, typename IdType, typename DType>
214
COOMatrix CSRRowWiseTopk(
215
    CSRMatrix mat, IdArray rows, int64_t k, NDArray weight, bool ascending);
216

217
template <DGLDeviceType XPU, typename IdType, typename FloatType>
218
COOMatrix CSRRowWiseSamplingBiased(
219
220
    CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray tag_offset,
    FloatArray bias, bool replace);
221

222
template <DGLDeviceType XPU, typename IdType>
223
std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
224
225
    const CSRMatrix& csr, int64_t num_samples, int num_trials,
    bool exclude_self_loops, bool replace, double redundancy);
226

227
// Union CSRMatrixes
228
template <DGLDeviceType XPU, typename IdType>
229
230
CSRMatrix UnionCsr(const std::vector<CSRMatrix>& csrs);

231
template <DGLDeviceType XPU, typename IdType>
232
std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple(CSRMatrix csr);
233

234
////////////////////////////////////////////////////////////////////////////////
Da Zheng's avatar
Da Zheng committed
235

236
template <DGLDeviceType XPU, typename IdType>
237
238
bool COOIsNonZero(COOMatrix coo, int64_t row, int64_t col);

239
template <DGLDeviceType XPU, typename IdType>
240
241
runtime::NDArray COOIsNonZero(
    COOMatrix coo, runtime::NDArray row, runtime::NDArray col);
242

243
template <DGLDeviceType XPU, typename IdType>
244
245
bool COOHasDuplicate(COOMatrix coo);

246
template <DGLDeviceType XPU, typename IdType>
247
248
int64_t COOGetRowNNZ(COOMatrix coo, int64_t row);

249
template <DGLDeviceType XPU, typename IdType>
250
251
runtime::NDArray COOGetRowNNZ(COOMatrix coo, runtime::NDArray row);

252
template <DGLDeviceType XPU, typename IdType>
253
254
std::pair<runtime::NDArray, runtime::NDArray> COOGetRowDataAndIndices(
    COOMatrix coo, int64_t row);
255

256
template <DGLDeviceType XPU, typename IdType>
257
258
259
std::vector<runtime::NDArray> COOGetDataAndIndices(
    COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);

260
template <DGLDeviceType XPU, typename IdType>
261
262
runtime::NDArray COOGetData(
    COOMatrix mat, runtime::NDArray rows, runtime::NDArray cols);
263

264
template <DGLDeviceType XPU, typename IdType>
265
266
COOMatrix COOTranspose(COOMatrix coo);

267
template <DGLDeviceType XPU, typename IdType>
268
269
CSRMatrix COOToCSR(COOMatrix coo);

270
template <DGLDeviceType XPU, typename IdType>
271
272
COOMatrix COOSliceRows(COOMatrix coo, int64_t start, int64_t end);

273
template <DGLDeviceType XPU, typename IdType>
274
275
COOMatrix COOSliceRows(COOMatrix coo, runtime::NDArray rows);

276
template <DGLDeviceType XPU, typename IdType>
277
278
COOMatrix COOSliceMatrix(
    COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);
279

280
template <DGLDeviceType XPU, typename IdType>
281
282
std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo);

283
template <DGLDeviceType XPU, typename IdType>
284
285
COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos);

286
template <DGLDeviceType XPU, typename IdType>
287
288
void COOSort_(COOMatrix* mat, bool sort_column);

289
template <DGLDeviceType XPU, typename IdType>
290
std::pair<bool, bool> COOIsSorted(COOMatrix coo);
291

292
template <DGLDeviceType XPU, typename IdType>
293
294
COOMatrix COORemove(COOMatrix coo, IdArray entries);

295
296
template <DGLDeviceType XPU, typename IdType, typename FloatType>
std::pair<COOMatrix, FloatArray> COOLaborSampling(
297
298
    COOMatrix mat, IdArray rows, int64_t num_samples, FloatArray prob,
    int importance_sampling, IdArray random_seed, float seed2_contribution,
299
300
    IdArray NIDs);

301
// FloatType is the type of probability data.
302
template <DGLDeviceType XPU, typename IdType, typename DType>
303
COOMatrix COORowWiseSampling(
304
305
    COOMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask,
    bool replace);
306

307
// FloatType is the type of probability data.
308
template <DGLDeviceType XPU, typename IdType, typename DType>
309
COOMatrix COORowWisePerEtypeSampling(
310
    COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
311
312
    const std::vector<int64_t>& num_samples,
    const std::vector<NDArray>& prob_or_mask, bool replace);
313

314
template <DGLDeviceType XPU, typename IdType>
315
316
317
COOMatrix COORowWiseSamplingUniform(
    COOMatrix mat, IdArray rows, int64_t num_samples, bool replace);

318
template <DGLDeviceType XPU, typename IdType>
319
COOMatrix COORowWisePerEtypeSamplingUniform(
320
    COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
321
    const std::vector<int64_t>& num_samples, bool replace);
322

323
// FloatType is the type of weight data.
324
template <DGLDeviceType XPU, typename IdType, typename FloatType>
325
326
COOMatrix COORowWiseTopk(
    COOMatrix mat, IdArray rows, int64_t k, FloatArray weight, bool ascending);
327

328
329
///////////////////////// Graph Traverse routines //////////////////////////

330
template <DGLDeviceType XPU, typename IdType>
331
332
Frontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source);

333
template <DGLDeviceType XPU, typename IdType>
334
335
Frontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source);

336
template <DGLDeviceType XPU, typename IdType>
337
338
Frontiers TopologicalNodesFrontiers(const CSRMatrix& csr);

339
template <DGLDeviceType XPU, typename IdType>
340
341
Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source);

342
template <DGLDeviceType XPU, typename IdType>
343
344
345
Frontiers DGLDFSLabeledEdges(
    const CSRMatrix& csr, IdArray source, const bool has_reverse_edge,
    const bool has_nontree_edge, const bool return_labels);
346

347
template <DGLDeviceType XPU, typename IdType>
348
COOMatrix COOLineGraph(const COOMatrix& coo, bool backtracking);
349

350
351
352
353
354
}  // namespace impl
}  // namespace aten
}  // namespace dgl

#endif  // DGL_ARRAY_ARRAY_OP_H_