/**
 *  Copyright (c) 2019 by Contributors
 * @file array/array_op.h
 * @brief Array operator templates
 */
#ifndef DGL_ARRAY_ARRAY_OP_H_
#define DGL_ARRAY_ARRAY_OP_H_

#include <dgl/array.h>
#include <dgl/graph_traversal.h>

#include <tuple>
#include <utility>
#include <vector>

namespace dgl {
namespace aten {
namespace impl {

20
21
template <DGLDeviceType XPU, typename IdType>
IdArray Full(IdType val, int64_t length, DGLContext ctx);
22

23
24
template <DGLDeviceType XPU, typename IdType>
IdArray Range(IdType low, IdType high, DGLContext ctx);
25

26
template <DGLDeviceType XPU, typename IdType>
27
28
IdArray AsNumBits(IdArray arr, uint8_t bits);

29
template <DGLDeviceType XPU, typename IdType, typename Op>
30
31
IdArray BinaryElewise(IdArray lhs, IdArray rhs);

32
template <DGLDeviceType XPU, typename IdType, typename Op>
33
34
IdArray BinaryElewise(IdArray lhs, IdType rhs);

35
template <DGLDeviceType XPU, typename IdType, typename Op>
36
37
IdArray BinaryElewise(IdType lhs, IdArray rhs);

38
template <DGLDeviceType XPU, typename IdType, typename Op>
39
40
IdArray UnaryElewise(IdArray array);

41
template <DGLDeviceType XPU, typename DType, typename IdType>
42
NDArray IndexSelect(NDArray array, IdArray index);
43

44
template <DGLDeviceType XPU, typename DType>
45
DType IndexSelect(NDArray array, int64_t index);
46

47
template <DGLDeviceType XPU, typename DType>
Jinjing Zhou's avatar
Jinjing Zhou committed
48
49
IdArray NonZero(BoolArray bool_arr);

50
51
52
template <DGLDeviceType XPU, typename IdType>
IdArray NonZero(NDArray array);

53
template <DGLDeviceType XPU, typename DType>
54
std::pair<IdArray, IdArray> Sort(IdArray array, int num_bits);
55

56
template <DGLDeviceType XPU, typename DType, typename IdType>
57
58
NDArray Scatter(NDArray array, IdArray indices);

59
template <DGLDeviceType XPU, typename DType, typename IdType>
60
61
void Scatter_(IdArray index, NDArray value, NDArray out);

62
template <DGLDeviceType XPU, typename DType, typename IdType>
63
64
NDArray Repeat(NDArray array, IdArray repeats);

65
template <DGLDeviceType XPU, typename IdType>
66
67
IdArray Relabel_(const std::vector<IdArray>& arrays);

68
template <DGLDeviceType XPU, typename IdType>
69
70
NDArray Concat(const std::vector<IdArray>& arrays);

71
template <DGLDeviceType XPU, typename DType>
72
73
std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, DType pad_value);

74
template <DGLDeviceType XPU, typename DType, typename IdType>
75
76
std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths);

77
template <DGLDeviceType XPU, typename IdType>
78
79
IdArray CumSum(IdArray array, bool prepend_zero);

80
81
// sparse arrays

82
template <DGLDeviceType XPU, typename IdType>
83
84
bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col);

85
template <DGLDeviceType XPU, typename IdType>
86
87
runtime::NDArray CSRIsNonZero(
    CSRMatrix csr, runtime::NDArray row, runtime::NDArray col);
88

89
template <DGLDeviceType XPU, typename IdType>
90
91
bool CSRHasDuplicate(CSRMatrix csr);

92
template <DGLDeviceType XPU, typename IdType>
93
94
int64_t CSRGetRowNNZ(CSRMatrix csr, int64_t row);

95
template <DGLDeviceType XPU, typename IdType>
96
97
runtime::NDArray CSRGetRowNNZ(CSRMatrix csr, runtime::NDArray row);

98
template <DGLDeviceType XPU, typename IdType>
99
100
runtime::NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row);

101
template <DGLDeviceType XPU, typename IdType>
102
103
runtime::NDArray CSRGetRowData(CSRMatrix csr, int64_t row);

104
template <DGLDeviceType XPU, typename IdType>
105
106
bool CSRIsSorted(CSRMatrix csr);

107
template <DGLDeviceType XPU, typename IdType, typename DType>
108
runtime::NDArray CSRGetData(
109
110
    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols,
    bool return_eids, runtime::NDArray weights, DType filler);
111

112
template <DGLDeviceType XPU, typename IdType, typename DType>
113
114
115
runtime::NDArray CSRGetData(
    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols,
    runtime::NDArray weights, DType filler) {
116
117
  return CSRGetData<XPU, IdType, DType>(
      csr, rows, cols, false, weights, filler);
118
119
}

120
template <DGLDeviceType XPU, typename IdType>
121
NDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) {
122
123
  return CSRGetData<XPU, IdType, IdType>(
      csr, rows, cols, true, NullArray(rows->dtype), -1);
124
}
125

126
template <DGLDeviceType XPU, typename IdType>
127
128
129
std::vector<runtime::NDArray> CSRGetDataAndIndices(
    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);

130
template <DGLDeviceType XPU, typename IdType>
131
132
133
CSRMatrix CSRTranspose(CSRMatrix csr);

// Convert CSR to COO
134
template <DGLDeviceType XPU, typename IdType>
135
136
137
COOMatrix CSRToCOO(CSRMatrix csr);

// Convert CSR to COO using data array as order
138
template <DGLDeviceType XPU, typename IdType>
139
140
COOMatrix CSRToCOODataAsOrder(CSRMatrix csr);

141
template <DGLDeviceType XPU, typename IdType>
142
143
CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end);

144
template <DGLDeviceType XPU, typename IdType>
145
146
CSRMatrix CSRSliceRows(CSRMatrix csr, runtime::NDArray rows);

147
template <DGLDeviceType XPU, typename IdType>
148
149
CSRMatrix CSRSliceMatrix(
    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
150

151
template <DGLDeviceType XPU, typename IdType>
152
153
void CSRSort_(CSRMatrix* csr);

154
template <DGLDeviceType XPU, typename IdType, typename TagType>
155
std::pair<CSRMatrix, NDArray> CSRSortByTag(
156
    const CSRMatrix& csr, IdArray tag_array, int64_t num_tags);
157

158
template <DGLDeviceType XPU, typename IdType>
159
160
CSRMatrix CSRReorder(
    CSRMatrix csr, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);
Da Zheng's avatar
Da Zheng committed
161

162
template <DGLDeviceType XPU, typename IdType>
163
164
COOMatrix COOReorder(
    COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);
Da Zheng's avatar
Da Zheng committed
165

166
template <DGLDeviceType XPU, typename IdType>
167
168
CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries);

169
// FloatType is the type of probability data.
170
template <DGLDeviceType XPU, typename IdType, typename DType>
171
COOMatrix CSRRowWiseSampling(
172
173
    CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask,
    bool replace);
174

175
// FloatType is the type of probability data.
176
template <DGLDeviceType XPU, typename IdType, typename DType>
177
COOMatrix CSRRowWisePerEtypeSampling(
178
    CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
179
    const std::vector<int64_t>& num_samples,
180
181
    const std::vector<NDArray>& prob_or_mask, bool replace,
    bool rowwise_etype_sorted);
182

183
template <DGLDeviceType XPU, typename IdType>
184
185
186
COOMatrix CSRRowWiseSamplingUniform(
    CSRMatrix mat, IdArray rows, int64_t num_samples, bool replace);

187
template <DGLDeviceType XPU, typename IdType>
188
COOMatrix CSRRowWisePerEtypeSamplingUniform(
189
190
191
    CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
    const std::vector<int64_t>& num_samples, bool replace,
    bool rowwise_etype_sorted);
192

193
// FloatType is the type of weight data.
194
template <DGLDeviceType XPU, typename IdType, typename DType>
195
COOMatrix CSRRowWiseTopk(
196
    CSRMatrix mat, IdArray rows, int64_t k, NDArray weight, bool ascending);
197

198
template <DGLDeviceType XPU, typename IdType, typename FloatType>
199
COOMatrix CSRRowWiseSamplingBiased(
200
201
    CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray tag_offset,
    FloatArray bias, bool replace);
202

203
template <DGLDeviceType XPU, typename IdType>
204
std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
205
206
    const CSRMatrix& csr, int64_t num_samples, int num_trials,
    bool exclude_self_loops, bool replace, double redundancy);
207

208
// Union CSRMatrixes
209
template <DGLDeviceType XPU, typename IdType>
210
211
CSRMatrix UnionCsr(const std::vector<CSRMatrix>& csrs);

212
template <DGLDeviceType XPU, typename IdType>
213
std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple(CSRMatrix csr);
214

215
////////////////////////////////////////////////////////////////////////////////
Da Zheng's avatar
Da Zheng committed
216

217
template <DGLDeviceType XPU, typename IdType>
218
219
bool COOIsNonZero(COOMatrix coo, int64_t row, int64_t col);

220
template <DGLDeviceType XPU, typename IdType>
221
222
runtime::NDArray COOIsNonZero(
    COOMatrix coo, runtime::NDArray row, runtime::NDArray col);
223

224
template <DGLDeviceType XPU, typename IdType>
225
226
bool COOHasDuplicate(COOMatrix coo);

227
template <DGLDeviceType XPU, typename IdType>
228
229
int64_t COOGetRowNNZ(COOMatrix coo, int64_t row);

230
template <DGLDeviceType XPU, typename IdType>
231
232
runtime::NDArray COOGetRowNNZ(COOMatrix coo, runtime::NDArray row);

233
template <DGLDeviceType XPU, typename IdType>
234
235
std::pair<runtime::NDArray, runtime::NDArray> COOGetRowDataAndIndices(
    COOMatrix coo, int64_t row);
236

237
template <DGLDeviceType XPU, typename IdType>
238
239
240
std::vector<runtime::NDArray> COOGetDataAndIndices(
    COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);

241
template <DGLDeviceType XPU, typename IdType>
242
243
runtime::NDArray COOGetData(
    COOMatrix mat, runtime::NDArray rows, runtime::NDArray cols);
244

245
template <DGLDeviceType XPU, typename IdType>
246
247
COOMatrix COOTranspose(COOMatrix coo);

248
template <DGLDeviceType XPU, typename IdType>
249
250
CSRMatrix COOToCSR(COOMatrix coo);

251
template <DGLDeviceType XPU, typename IdType>
252
253
COOMatrix COOSliceRows(COOMatrix coo, int64_t start, int64_t end);

254
template <DGLDeviceType XPU, typename IdType>
255
256
COOMatrix COOSliceRows(COOMatrix coo, runtime::NDArray rows);

257
template <DGLDeviceType XPU, typename IdType>
258
259
COOMatrix COOSliceMatrix(
    COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);
260

261
template <DGLDeviceType XPU, typename IdType>
262
263
std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo);

264
template <DGLDeviceType XPU, typename IdType>
265
266
COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos);

267
template <DGLDeviceType XPU, typename IdType>
268
269
void COOSort_(COOMatrix* mat, bool sort_column);

270
template <DGLDeviceType XPU, typename IdType>
271
std::pair<bool, bool> COOIsSorted(COOMatrix coo);
272

273
template <DGLDeviceType XPU, typename IdType>
274
275
COOMatrix COORemove(COOMatrix coo, IdArray entries);

276
// FloatType is the type of probability data.
277
template <DGLDeviceType XPU, typename IdType, typename DType>
278
COOMatrix COORowWiseSampling(
279
280
    COOMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask,
    bool replace);
281

282
// FloatType is the type of probability data.
283
template <DGLDeviceType XPU, typename IdType, typename DType>
284
COOMatrix COORowWisePerEtypeSampling(
285
    COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
286
287
    const std::vector<int64_t>& num_samples,
    const std::vector<NDArray>& prob_or_mask, bool replace);
288

289
template <DGLDeviceType XPU, typename IdType>
290
291
292
COOMatrix COORowWiseSamplingUniform(
    COOMatrix mat, IdArray rows, int64_t num_samples, bool replace);

293
template <DGLDeviceType XPU, typename IdType>
294
COOMatrix COORowWisePerEtypeSamplingUniform(
295
    COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
296
    const std::vector<int64_t>& num_samples, bool replace);
297

298
// FloatType is the type of weight data.
299
template <DGLDeviceType XPU, typename IdType, typename FloatType>
300
301
COOMatrix COORowWiseTopk(
    COOMatrix mat, IdArray rows, int64_t k, FloatArray weight, bool ascending);
302

303
304
///////////////////////// Graph Traverse routines //////////////////////////

305
template <DGLDeviceType XPU, typename IdType>
306
307
Frontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source);

308
template <DGLDeviceType XPU, typename IdType>
309
310
Frontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source);

311
template <DGLDeviceType XPU, typename IdType>
312
313
Frontiers TopologicalNodesFrontiers(const CSRMatrix& csr);

314
template <DGLDeviceType XPU, typename IdType>
315
316
Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source);

317
template <DGLDeviceType XPU, typename IdType>
318
319
320
Frontiers DGLDFSLabeledEdges(
    const CSRMatrix& csr, IdArray source, const bool has_reverse_edge,
    const bool has_nontree_edge, const bool return_labels);
321

322
template <DGLDeviceType XPU, typename IdType>
323
COOMatrix COOLineGraph(const COOMatrix& coo, bool backtracking);
324

325
326
327
328
329
}  // namespace impl
}  // namespace aten
}  // namespace dgl

#endif  // DGL_ARRAY_ARRAY_OP_H_