"src/array/vscode:/vscode.git/clone" did not exist on "31a81438fcc0965da6e0f4aded098a553a3bbf8e"
/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cpu/rowwise_sampling.cc
 * @brief rowwise sampling
 */
#include <dgl/random.h>

#include <algorithm>
#include <numeric>
#include <vector>

#include "./rowwise_pick.h"

namespace dgl {
namespace aten {
namespace impl {
namespace {
// Equivalent to numpy expression: array[idx[off:off + len]]
template <typename IdxType, typename FloatType>
18
19
inline FloatArray DoubleSlice(
    FloatArray array, const IdxType* idx_data, IdxType off, IdxType len) {
20
21
22
23
24
25
26
27
28
29
30
31
  const FloatType* array_data = static_cast<FloatType*>(array->data);
  FloatArray ret = FloatArray::Empty({len}, array->dtype, array->ctx);
  FloatType* ret_data = static_cast<FloatType*>(ret->data);
  for (int64_t j = 0; j < len; ++j) {
    if (idx_data)
      ret_data[j] = array_data[idx_data[off + j]];
    else
      ret_data[j] = array_data[off + j];
  }
  return ret;
}

32
33
34
// Builds the callback that decides how many neighbors to pick from one row
// when sampling with per-edge weights (probabilities or a 0/1 mask).
//
// Only edges whose entry in `prob_or_mask` is strictly positive count as
// candidates. The entry is looked up by edge id: data[i] when an explicit
// data array is present, otherwise the storage position i itself.
//   - With replacement: a row with at least one candidate yields
//     `max_num_picks` picks, and a row with no candidate yields 0.
//   - Without replacement: the count is min(max_num_picks, #candidates).
// `max_num_picks` is `num_samples`, or the row length `len` when
// num_samples == -1.
//
// `prob_or_mask` is captured by value, so the callback does not depend on the
// caller's handle outliving it.
template <typename IdxType, typename DType>
inline NumPicksFn<IdxType> GetSamplingNumPicksFn(
    int64_t num_samples, NDArray prob_or_mask, bool replace) {
  NumPicksFn<IdxType> num_picks_fn = [prob_or_mask, num_samples, replace](
                                         IdxType rowid, IdxType off,
                                         IdxType len, const IdxType* col,
                                         const IdxType* data) {
    const int64_t max_num_picks = (num_samples == -1) ? len : num_samples;
    const DType* prob_or_mask_data = prob_or_mask.Ptr<DType>();
    // Count the edges of this row ([off, off + len)) with a positive weight.
    IdxType nnz = 0;
    for (IdxType i = off; i < off + len; ++i) {
      const IdxType eid = data ? data[i] : i;
      if (prob_or_mask_data[eid] > 0) {
        ++nnz;
      }
    }

    if (replace) {
      return static_cast<IdxType>(nnz == 0 ? 0 : max_num_picks);
    } else {
      return std::min(static_cast<IdxType>(max_num_picks), nnz);
    }
  };
  return num_picks_fn;
}

// Builds the callback that draws `num_picks` neighbors from one row, weighted
// by `prob_or_mask`.
//
// The row's weights are first gathered into a contiguous array with
// DoubleSlice (indexed through `data` when it is non-null). That array is
// handed to RandomEngine::Choice, which writes `num_picks` indices into
// `out_idx`. Those indices are relative to the gathered slice, so each is
// shifted by `off` to become a position in the matrix storage.
//
// NOTE(review): a temporary array is allocated for every row; acceptable for
// now, but a candidate for reuse if this shows up in profiles.
template <typename IdxType, typename DType>
inline PickFn<IdxType> GetSamplingPickFn(
    int64_t num_samples, NDArray prob_or_mask, bool replace) {
  PickFn<IdxType> pick_fn = [prob_or_mask, num_samples, replace](
                                IdxType rowid, IdxType off, IdxType len,
                                IdxType num_picks, const IdxType* col,
                                const IdxType* data, IdxType* out_idx) {
    NDArray prob_or_mask_selected =
        DoubleSlice<IdxType, DType>(prob_or_mask, data, off, len);
    RandomEngine::ThreadLocal()->Choice<IdxType, DType>(
        num_picks, prob_or_mask_selected, out_idx, replace);
    // Convert slice-relative indices into storage positions.
    for (int64_t j = 0; j < num_picks; ++j) {
      out_idx[j] += off;
    }
  };
  return pick_fn;
}

template <typename IdxType, typename FloatType>
77
78
79
inline EtypeRangePickFn<IdxType> GetSamplingRangePickFn(
    const std::vector<int64_t>& num_samples,
    const std::vector<FloatArray>& prob, bool replace) {
80
81
82
83
84
85
86
87
88
89
90
91
92
93
  EtypeRangePickFn<IdxType> pick_fn =
      [prob, num_samples, replace](
          IdxType off, IdxType et_offset, IdxType cur_et, IdxType et_len,
          const std::vector<IdxType>& et_idx,
          const std::vector<IdxType>& et_eid, const IdxType* eid,
          IdxType* out_idx) {
        const FloatArray& p = prob[cur_et];
        const FloatType* p_data = IsNullArray(p) ? nullptr : p.Ptr<FloatType>();
        FloatArray probs = FloatArray::Empty({et_len}, p->dtype, p->ctx);
        FloatType* probs_data = probs.Ptr<FloatType>();
        for (int64_t j = 0; j < et_len; ++j) {
          const IdxType cur_eid = et_eid[et_idx[et_offset + j]];
          probs_data[j] = p_data ? p_data[cur_eid] : static_cast<FloatType>(1.);
        }
94

95
96
97
        RandomEngine::ThreadLocal()->Choice<IdxType, FloatType>(
            num_samples[cur_et], probs, out_idx, replace);
      };
98
99
100
  return pick_fn;
}

101
102
103
// Builds the callback that decides how many neighbors to pick from one row
// under uniform sampling. Every edge in the row is a candidate.
//   - With replacement: `max_num_picks` picks for a non-empty row, 0 for an
//     empty one.
//   - Without replacement: min(max_num_picks, len).
// `max_num_picks` is `num_samples`, or the row length `len` when
// num_samples == -1.
template <typename IdxType>
inline NumPicksFn<IdxType> GetSamplingUniformNumPicksFn(
    int64_t num_samples, bool replace) {
  NumPicksFn<IdxType> num_picks_fn = [num_samples, replace](
                                         IdxType rowid, IdxType off,
                                         IdxType len, const IdxType* col,
                                         const IdxType* data) {
    const int64_t max_num_picks = (num_samples == -1) ? len : num_samples;
    if (replace) {
      return static_cast<IdxType>(len == 0 ? 0 : max_num_picks);
    } else {
      return std::min(static_cast<IdxType>(max_num_picks), len);
    }
  };
  return num_picks_fn;
}

// Builds the callback that draws `num_picks` neighbors uniformly from one row.
//
// RandomEngine::UniformChoice writes `num_picks` indices drawn from a
// population of size `len` into `out_idx`. Each index is then shifted by `off`
// so it refers to a position in the matrix storage rather than within the row.
template <typename IdxType>
inline PickFn<IdxType> GetSamplingUniformPickFn(
    int64_t num_samples, bool replace) {
  PickFn<IdxType> pick_fn = [num_samples, replace](
                                IdxType rowid, IdxType off, IdxType len,
                                IdxType num_picks, const IdxType* col,
                                const IdxType* data, IdxType* out_idx) {
    RandomEngine::ThreadLocal()->UniformChoice<IdxType>(
        num_picks, len, out_idx, replace);
    // Convert row-relative indices into storage positions.
    for (int64_t j = 0; j < num_picks; ++j) {
      out_idx[j] += off;
    }
  };
  return pick_fn;
}

template <typename IdxType>
135
inline EtypeRangePickFn<IdxType> GetSamplingUniformRangePickFn(
136
    const std::vector<int64_t>& num_samples, bool replace) {
137
138
139
140
141
142
143
144
145
  EtypeRangePickFn<IdxType> pick_fn =
      [num_samples, replace](
          IdxType off, IdxType et_offset, IdxType cur_et, IdxType et_len,
          const std::vector<IdxType>& et_idx,
          const std::vector<IdxType>& et_eid, const IdxType* data,
          IdxType* out_idx) {
        RandomEngine::ThreadLocal()->UniformChoice<IdxType>(
            num_samples[cur_et], et_len, out_idx, replace);
      };
146
147
148
  return pick_fn;
}

149
150
151
// Builds the callback that decides how many neighbors to pick from one row
// under tag-biased sampling.
//
// `split` is read as a row-major 2-D array with one row per matrix row id.
// Its row `rowid` starts at split.Ptr() + rowid * split->shape[1] and holds
// num_tags + 1 offsets, where num_tags = split->shape[1] - 1. The number of
// edges carrying tag j in the row is tag_offset[j + 1] - tag_offset[j].
// `bias` holds one weight per tag.
//
// Only edges whose tag has a strictly positive bias count as candidates.
//   - With replacement: `max_num_picks` picks if there is any candidate,
//     otherwise 0.
//   - Without replacement: min(max_num_picks, #candidates).
// `max_num_picks` is `num_samples`, or `len` when num_samples == -1.
template <typename IdxType, typename FloatType>
inline NumPicksFn<IdxType> GetSamplingBiasedNumPicksFn(
    int64_t num_samples, IdArray split, FloatArray bias, bool replace) {
  NumPicksFn<IdxType> num_picks_fn = [num_samples, split, bias, replace](
                                         IdxType rowid, IdxType off,
                                         IdxType len, const IdxType* col,
                                         const IdxType* data) {
    const int64_t max_num_picks = (num_samples == -1) ? len : num_samples;
    const int64_t num_tags = split->shape[1] - 1;
    const IdxType* tag_offset = split.Ptr<IdxType>() + rowid * split->shape[1];
    const FloatType* bias_data = bias.Ptr<FloatType>();
    // Sum the sizes of the tag groups whose bias is positive.
    IdxType nnz = 0;
    for (int64_t j = 0; j < num_tags; ++j) {
      if (bias_data[j] > 0) {
        nnz += tag_offset[j + 1] - tag_offset[j];
      }
    }

    if (replace) {
      return static_cast<IdxType>(nnz == 0 ? 0 : max_num_picks);
    } else {
      return std::min(static_cast<IdxType>(max_num_picks), nnz);
    }
  };
  return num_picks_fn;
}

// Builds the callback that draws `num_picks` neighbors from one row under
// tag-biased sampling.
//
// The row's tag offsets are read from `split` exactly as in
// GetSamplingBiasedNumPicksFn (row `rowid`, stride split->shape[1]). They are
// passed, together with the per-tag `bias`, to RandomEngine::BiasedChoice,
// which writes `num_picks` indices into `out_idx`. Each index is then shifted
// by `off` so it refers to a position in the matrix storage.
template <typename IdxType, typename FloatType>
inline PickFn<IdxType> GetSamplingBiasedPickFn(
    int64_t num_samples, IdArray split, FloatArray bias, bool replace) {
  PickFn<IdxType> pick_fn = [num_samples, split, bias, replace](
                                IdxType rowid, IdxType off, IdxType len,
                                IdxType num_picks, const IdxType* col,
                                const IdxType* data, IdxType* out_idx) {
    const IdxType* tag_offset = split.Ptr<IdxType>() + rowid * split->shape[1];
    RandomEngine::ThreadLocal()->BiasedChoice<IdxType, FloatType>(
        num_picks, tag_offset, bias, out_idx, replace);
    // Convert row-relative indices into storage positions.
    for (int64_t j = 0; j < num_picks; ++j) {
      out_idx[j] += off;
    }
  };
  return pick_fn;
}

}  // namespace

/////////////////////////////// CSR ///////////////////////////////

197
template <DGLDeviceType XPU, typename IdxType, typename DType>
198
199
200
COOMatrix CSRRowWiseSampling(
    CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask,
    bool replace) {
201
202
203
  // If num_samples is -1, select all neighbors without replacement.
  replace = (replace && num_samples != -1);
  CHECK(prob_or_mask.defined());
204
205
206
207
  auto num_picks_fn =
      GetSamplingNumPicksFn<IdxType, DType>(num_samples, prob_or_mask, replace);
  auto pick_fn =
      GetSamplingPickFn<IdxType, DType>(num_samples, prob_or_mask, replace);
208
  return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn, num_picks_fn);
209
210
}

211
template COOMatrix CSRRowWiseSampling<kDGLCPU, int32_t, float>(
212
    CSRMatrix, IdArray, int64_t, NDArray, bool);
213
template COOMatrix CSRRowWiseSampling<kDGLCPU, int64_t, float>(
214
    CSRMatrix, IdArray, int64_t, NDArray, bool);
215
template COOMatrix CSRRowWiseSampling<kDGLCPU, int32_t, double>(
216
    CSRMatrix, IdArray, int64_t, NDArray, bool);
217
template COOMatrix CSRRowWiseSampling<kDGLCPU, int64_t, double>(
218
219
220
221
222
223
224
225
226
    CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseSampling<kDGLCPU, int32_t, int8_t>(
    CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseSampling<kDGLCPU, int64_t, int8_t>(
    CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseSampling<kDGLCPU, int32_t, uint8_t>(
    CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseSampling<kDGLCPU, int64_t, uint8_t>(
    CSRMatrix, IdArray, int64_t, NDArray, bool);
227

228
229
230
template <DGLDeviceType XPU, typename IdxType, typename DType>
COOMatrix CSRRowWisePerEtypeSampling(
    CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
231
232
233
234
235
236
237
238
239
    const std::vector<int64_t>& num_samples,
    const std::vector<NDArray>& prob_or_mask, bool replace,
    bool rowwise_etype_sorted) {
  CHECK(prob_or_mask.size() == num_samples.size())
      << "the number of probability tensors does not match the number of edge "
         "types.";
  for (auto& p : prob_or_mask) CHECK(p.defined());
  auto pick_fn = GetSamplingRangePickFn<IdxType, DType>(
      num_samples, prob_or_mask, replace);
240
  return CSRRowWisePerEtypePick<IdxType, DType>(
241
242
      mat, rows, eid2etype_offset, num_samples, replace, rowwise_etype_sorted,
      pick_fn, prob_or_mask);
243
244
}

245
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int32_t, float>(
246
247
    CSRMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool, bool);
248
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int64_t, float>(
249
250
    CSRMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool, bool);
251
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int32_t, double>(
252
253
    CSRMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool, bool);
254
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int64_t, double>(
255
256
    CSRMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool, bool);
257
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int32_t, int8_t>(
258
259
    CSRMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool, bool);
260
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int64_t, int8_t>(
261
262
    CSRMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool, bool);
263
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int32_t, uint8_t>(
264
265
    CSRMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool, bool);
266
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int64_t, uint8_t>(
267
268
    CSRMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool, bool);
269

270
template <DGLDeviceType XPU, typename IdxType>
271
272
COOMatrix CSRRowWiseSamplingUniform(
    CSRMatrix mat, IdArray rows, int64_t num_samples, bool replace) {
273
274
  // If num_samples is -1, select all neighbors without replacement.
  replace = (replace && num_samples != -1);
275
276
  auto num_picks_fn =
      GetSamplingUniformNumPicksFn<IdxType>(num_samples, replace);
277
  auto pick_fn = GetSamplingUniformPickFn<IdxType>(num_samples, replace);
278
  return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn, num_picks_fn);
279
280
}

281
template COOMatrix CSRRowWiseSamplingUniform<kDGLCPU, int32_t>(
282
    CSRMatrix, IdArray, int64_t, bool);
283
template COOMatrix CSRRowWiseSamplingUniform<kDGLCPU, int64_t>(
284
285
    CSRMatrix, IdArray, int64_t, bool);

286
template <DGLDeviceType XPU, typename IdxType>
287
288
COOMatrix CSRRowWisePerEtypeSamplingUniform(
    CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
289
290
    const std::vector<int64_t>& num_samples, bool replace,
    bool rowwise_etype_sorted) {
291
  auto pick_fn = GetSamplingUniformRangePickFn<IdxType>(num_samples, replace);
292
  return CSRRowWisePerEtypePick<IdxType, float>(
293
294
      mat, rows, eid2etype_offset, num_samples, replace, rowwise_etype_sorted,
      pick_fn, {});
295
296
}

297
template COOMatrix CSRRowWisePerEtypeSamplingUniform<kDGLCPU, int32_t>(
298
299
    CSRMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, bool, bool);
300
template COOMatrix CSRRowWisePerEtypeSamplingUniform<kDGLCPU, int64_t>(
301
302
    CSRMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, bool, bool);
303

304
template <DGLDeviceType XPU, typename IdxType, typename FloatType>
305
COOMatrix CSRRowWiseSamplingBiased(
306
307
    CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray tag_offset,
    FloatArray bias, bool replace) {
308
309
310
311
  // If num_samples is -1, select all neighbors without replacement.
  replace = (replace && num_samples != -1);
  auto num_picks_fn = GetSamplingBiasedNumPicksFn<IdxType, FloatType>(
      num_samples, tag_offset, bias, replace);
312
313
  auto pick_fn = GetSamplingBiasedPickFn<IdxType, FloatType>(
      num_samples, tag_offset, bias, replace);
314
  return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn, num_picks_fn);
315
316
}

317
template COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int32_t, float>(
318
    CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);
319

320
template COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int64_t, float>(
321
    CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);
322

323
template COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int32_t, double>(
324
    CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);
325

326
template COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int64_t, double>(
327
    CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);
328

329
330
/////////////////////////////// COO ///////////////////////////////

331
template <DGLDeviceType XPU, typename IdxType, typename DType>
332
333
334
COOMatrix COORowWiseSampling(
    COOMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask,
    bool replace) {
335
336
337
  // If num_samples is -1, select all neighbors without replacement.
  replace = (replace && num_samples != -1);
  CHECK(prob_or_mask.defined());
338
339
340
341
  auto num_picks_fn =
      GetSamplingNumPicksFn<IdxType, DType>(num_samples, prob_or_mask, replace);
  auto pick_fn =
      GetSamplingPickFn<IdxType, DType>(num_samples, prob_or_mask, replace);
342
  return COORowWisePick(mat, rows, num_samples, replace, pick_fn, num_picks_fn);
343
344
}

345
template COOMatrix COORowWiseSampling<kDGLCPU, int32_t, float>(
346
    COOMatrix, IdArray, int64_t, NDArray, bool);
347
template COOMatrix COORowWiseSampling<kDGLCPU, int64_t, float>(
348
    COOMatrix, IdArray, int64_t, NDArray, bool);
349
template COOMatrix COORowWiseSampling<kDGLCPU, int32_t, double>(
350
    COOMatrix, IdArray, int64_t, NDArray, bool);
351
template COOMatrix COORowWiseSampling<kDGLCPU, int64_t, double>(
352
353
354
355
356
357
358
359
360
    COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseSampling<kDGLCPU, int32_t, int8_t>(
    COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseSampling<kDGLCPU, int64_t, int8_t>(
    COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseSampling<kDGLCPU, int32_t, uint8_t>(
    COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseSampling<kDGLCPU, int64_t, uint8_t>(
    COOMatrix, IdArray, int64_t, NDArray, bool);
361

362
363
364
template <DGLDeviceType XPU, typename IdxType, typename DType>
COOMatrix COORowWisePerEtypeSampling(
    COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
365
366
367
368
369
370
371
372
    const std::vector<int64_t>& num_samples,
    const std::vector<NDArray>& prob_or_mask, bool replace) {
  CHECK(prob_or_mask.size() == num_samples.size())
      << "the number of probability tensors do not match the number of edge "
         "types.";
  for (auto& p : prob_or_mask) CHECK(p.defined());
  auto pick_fn = GetSamplingRangePickFn<IdxType, DType>(
      num_samples, prob_or_mask, replace);
373
374
  return COORowWisePerEtypePick<IdxType, DType>(
      mat, rows, eid2etype_offset, num_samples, replace, pick_fn, prob_or_mask);
375
376
}

377
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int32_t, float>(
378
379
    COOMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool);
380
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int64_t, float>(
381
382
    COOMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool);
383
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int32_t, double>(
384
385
    COOMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool);
386
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int64_t, double>(
387
388
    COOMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool);
389
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int32_t, int8_t>(
390
391
    COOMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool);
392
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int64_t, int8_t>(
393
394
    COOMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool);
395
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int32_t, uint8_t>(
396
397
    COOMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool);
398
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int64_t, uint8_t>(
399
400
    COOMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool);
401

402
template <DGLDeviceType XPU, typename IdxType>
403
404
COOMatrix COORowWiseSamplingUniform(
    COOMatrix mat, IdArray rows, int64_t num_samples, bool replace) {
405
406
  // If num_samples is -1, select all neighbors without replacement.
  replace = (replace && num_samples != -1);
407
408
  auto num_picks_fn =
      GetSamplingUniformNumPicksFn<IdxType>(num_samples, replace);
409
  auto pick_fn = GetSamplingUniformPickFn<IdxType>(num_samples, replace);
410
  return COORowWisePick(mat, rows, num_samples, replace, pick_fn, num_picks_fn);
411
412
}

413
template COOMatrix COORowWiseSamplingUniform<kDGLCPU, int32_t>(
414
    COOMatrix, IdArray, int64_t, bool);
415
template COOMatrix COORowWiseSamplingUniform<kDGLCPU, int64_t>(
416
417
    COOMatrix, IdArray, int64_t, bool);

418
template <DGLDeviceType XPU, typename IdxType>
419
420
421
COOMatrix COORowWisePerEtypeSamplingUniform(
    COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
    const std::vector<int64_t>& num_samples, bool replace) {
422
  auto pick_fn = GetSamplingUniformRangePickFn<IdxType>(num_samples, replace);
423
424
  return COORowWisePerEtypePick<IdxType, float>(
      mat, rows, eid2etype_offset, num_samples, replace, pick_fn, {});
425
426
}

427
template COOMatrix COORowWisePerEtypeSamplingUniform<kDGLCPU, int32_t>(
428
429
    COOMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, bool);
430
template COOMatrix COORowWisePerEtypeSamplingUniform<kDGLCPU, int64_t>(
431
432
    COOMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, bool);
433

434
435
436
}  // namespace impl
}  // namespace aten
}  // namespace dgl