// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
/**
 *  Copyright (c) 2021 by Contributors
 * @file graph/sampling/frequency_hashmap.cu
 * @brief frequency hashmap - used to select the top-k most frequent edges of
 * each node
 */

#include <algorithm>
#include <hipcub/hipcub.hpp>  // NOLINT
#include <tuple>
#include <utility>

#include "../../../array/cuda/atomic.cuh"
#include "../../../runtime/cuda/cuda_common.h"
#include "frequency_hashmap.cuh"

namespace dgl {

namespace sampling {

namespace impl {

namespace {

int64_t _table_size(const int64_t num, const int64_t scale) {
  /**
   * Calculate the number of buckets in the hashtable. To guarantee we can
   * fill the hashtable in the worst case, we must use a number of buckets which
   * is a power of two.
   * https://en.wikipedia.org/wiki/Quadratic_probing#Limitations
   */
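  // e.g., num = 1024, scale = 2: next_pow2 = 1 << 10 = 1024, so the table has
  // 1024 << 2 = 4096 buckets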
  const int64_t next_pow2 = 1 << static_cast<int64_t>(1 + std::log2(num >> 1));
  return next_pow2 << scale;
}

template <typename IdxType, int BLOCK_SIZE, int TILE_SIZE>
__global__ void _init_edge_table(void *edge_hashmap, int64_t edges_len) {
  using EdgeItem = typename DeviceEdgeHashmap<IdxType>::EdgeItem;
  auto edge_hashmap_t = static_cast<EdgeItem *>(edge_hashmap);
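  // src == -1 marks an empty bucket (InsertEdge CASes against this sentinel);
  // cnt starts at zero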
  int64_t start_idx = (blockIdx.x * TILE_SIZE) + threadIdx.x;
  int64_t last_idx = start_idx + TILE_SIZE;
#pragma unroll(4)
  for (int64_t idx = start_idx; idx < last_idx; idx += BLOCK_SIZE) {
    if (idx < edges_len) {
      EdgeItem *edge = (edge_hashmap_t + idx);
      edge->src = static_cast<IdxType>(-1);
      edge->cnt = static_cast<IdxType>(0);
    }
  }
}

template <typename IdxType, int BLOCK_SIZE, int TILE_SIZE>
__global__ void _count_frequency(
    const IdxType *src_data, const int64_t num_edges,
    const int64_t num_edges_per_node, IdxType *edge_blocks_prefix,
    bool *is_first_position, DeviceEdgeHashmap<IdxType> device_edge_hashmap) {
  int64_t start_idx = (blockIdx.x * TILE_SIZE) + threadIdx.x;
  int64_t last_idx = start_idx + TILE_SIZE;

  IdxType count = 0;
  for (int64_t idx = start_idx; idx < last_idx; idx += BLOCK_SIZE) {
    if (idx < num_edges) {
      IdxType src = src_data[idx];
      if (src == static_cast<IdxType>(-1)) {
        continue;
      }
      IdxType dst_idx = (idx / num_edges_per_node);
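      // InsertEdge returns the edge's previous count: 0 means this thread is
      // the first to see this (src, dst) pair, so it counts one unique edge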
      if (device_edge_hashmap.InsertEdge(src, dst_idx) == 0) {
        is_first_position[idx] = true;
        ++count;
      }
    }
  }

  using BlockReduce = typename hipcub::BlockReduce<IdxType, BLOCK_SIZE>;
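  // reduce the per-thread unique-edge counts into one total per block; the
  // per-block totals are ExclusiveSum'ed later to build edge_blocks_prefix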
  __shared__ typename BlockReduce::TempStorage temp_space;

  count = BlockReduce(temp_space).Sum(count);
  if (threadIdx.x == 0) {
    edge_blocks_prefix[blockIdx.x] = count;
    if (blockIdx.x == 0) {
      edge_blocks_prefix[gridDim.x] = 0;
    }
  }
}

/**
 * This structure is used with hipcub's block-level prefix scan (BlockScan) in
 * order to keep a running sum as items are iteratively processed.
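 * A single instance is passed to successive BlockScan::ExclusiveSum calls in
 * _compact_frequency, so the running offset carries across the iterations of
 * one block's tile.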
 */
template <typename T>
struct BlockPrefixCallbackOp {
  T _running_total;

  __device__ BlockPrefixCallbackOp(const T running_total)
      : _running_total(running_total) {}

  __device__ T operator()(const T block_aggregate) {
    const T old_prefix = _running_total;
    _running_total += block_aggregate;
    return old_prefix;
  }
};

template <typename IdxType, typename Idx64Type, int BLOCK_SIZE, int TILE_SIZE>
__global__ void _compact_frequency(
    const IdxType *src_data, const IdxType *dst_data, const int64_t num_edges,
    const int64_t num_edges_per_node, const IdxType *edge_blocks_prefix,
    const bool *is_first_position, IdxType *num_unique_each_node,
    IdxType *unique_src_edges, Idx64Type *unique_frequency,
    DeviceEdgeHashmap<IdxType> device_edge_hashmap) {
  int64_t start_idx = (blockIdx.x * TILE_SIZE) + threadIdx.x;
  int64_t last_idx = start_idx + TILE_SIZE;
  const IdxType block_offset = edge_blocks_prefix[blockIdx.x];

  using BlockScan = typename hipcub::BlockScan<IdxType, BLOCK_SIZE>;
  __shared__ typename BlockScan::TempStorage temp_space;
  BlockPrefixCallbackOp<IdxType> prefix_op(0);

  for (int64_t idx = start_idx; idx < last_idx; idx += BLOCK_SIZE) {
    IdxType flag = 0;
    if (idx < num_edges) {
      IdxType src = src_data[idx];
      IdxType dst_idx = (idx / num_edges_per_node);
      if (idx % num_edges_per_node == 0) {
        num_unique_each_node[dst_idx] =
            device_edge_hashmap.GetDstCount(dst_idx);
      }
      if (is_first_position[idx] == true) {
        flag = 1;
      }
      BlockScan(temp_space).ExclusiveSum(flag, flag, prefix_op);
      __syncthreads();
      if (is_first_position[idx] == true) {
        const IdxType pos = (block_offset + flag);
        unique_src_edges[pos] = src;
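        // for 32-bit indices, (num_dst_nodes - dst_idx) is packed into the
        // high 32 bits of the key so that a single descending radix sort
        // groups keys by dst node and sorts by frequency within each node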
        if (sizeof(IdxType) != sizeof(Idx64Type) &&
            sizeof(IdxType) == 4) {  // IdxType is 32-bit
          unique_frequency[pos] =
              ((static_cast<Idx64Type>(num_edges / num_edges_per_node - dst_idx)
                << 32) |
               device_edge_hashmap.GetEdgeCount(src, dst_idx));
        } else {
          unique_frequency[pos] =
              device_edge_hashmap.GetEdgeCount(src, dst_idx);
        }
      }
    }
  }
}

template <typename IdxType, int BLOCK_SIZE, int TILE_SIZE>
__global__ void _get_pick_num(
    IdxType *num_unique_each_node, const int64_t num_pick,
    const int64_t num_dst_nodes) {
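  // clamp each node's unique-edge count to at most num_pick; the ExclusiveSum
  // over the clamped counts later provides the output offsets for _pick_data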
  int64_t start_idx = (blockIdx.x * TILE_SIZE) + threadIdx.x;
  int64_t last_idx = start_idx + TILE_SIZE;
#pragma unroll(4)
  for (int64_t idx = start_idx; idx < last_idx; idx += BLOCK_SIZE) {
    if (idx < num_dst_nodes) {
      IdxType &num_unique = num_unique_each_node[idx];
      num_unique = min(num_unique, static_cast<IdxType>(num_pick));
    }
  }
}

template <typename IdxType, typename Idx64Type, int BLOCK_SIZE, int TILE_SIZE>
__global__ void _pick_data(
    const Idx64Type *unique_frequency, const IdxType *unique_src_edges,
    const IdxType *unique_input_offsets, const IdxType *dst_data,
    const int64_t num_edges_per_node, const int64_t num_dst_nodes,
    const int64_t num_edges, const IdxType *unique_output_offsets,
    IdxType *output_src, IdxType *output_dst, IdxType *output_frequency) {
  int64_t start_idx = (blockIdx.x * TILE_SIZE) + threadIdx.x;
  int64_t last_idx = start_idx + TILE_SIZE;

  for (int64_t idx = start_idx; idx < last_idx; idx += BLOCK_SIZE) {
    if (idx < num_dst_nodes) {
      const int64_t dst_pos = (idx * num_edges_per_node);
      assert(dst_pos < num_edges);
      const IdxType dst = dst_data[dst_pos];
      const IdxType last_output_offset = unique_output_offsets[idx + 1];
      assert(
          (last_output_offset - unique_output_offsets[idx]) <=
          (unique_input_offsets[idx + 1] - unique_input_offsets[idx]));
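      // within each node's segment, unique_src_edges is already sorted by
      // descending frequency, so copying the first few entries yields the
      // node's top-k most frequent source nodes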
      for (IdxType output_idx = unique_output_offsets[idx],
                   input_idx = unique_input_offsets[idx];
           output_idx < last_output_offset; ++output_idx, ++input_idx) {
        output_src[output_idx] = unique_src_edges[input_idx];
        output_dst[output_idx] = dst;
        output_frequency[output_idx] =
            static_cast<IdxType>(unique_frequency[input_idx]);
      }
    }
  }
}

}  // namespace

// Insert an edge into dst_idx's sub-table and return the edge's previous cnt;
// 0 means this is the first occurrence of (src, dst_idx).
template <typename IdxType>
inline __device__ IdxType DeviceEdgeHashmap<IdxType>::InsertEdge(
    const IdxType &src, const IdxType &dst_idx) {
  IdxType start_off = dst_idx * _num_items_each_dst;
  IdxType pos = EdgeHash(src);
  IdxType delta = 1;
  IdxType old_cnt = static_cast<IdxType>(-1);
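  // probe this dst node's sub-table; on a collision the probe offset grows by
  // an increasing delta (1, 2, 3, ...), the quadratic-style probing assumed by
  // the power-of-two sizing in _table_size()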
  while (true) {
    IdxType old_src = dgl::aten::cuda::AtomicCAS(
        &_edge_hashmap[start_off + pos].src, static_cast<IdxType>(-1), src);
    if (old_src == static_cast<IdxType>(-1) || old_src == src) {
      // the bucket now holds this edge (either newly claimed or already ours)
      old_cnt = dgl::aten::cuda::AtomicAdd(
          &_edge_hashmap[start_off + pos].cnt, static_cast<IdxType>(1));
      if (old_src == static_cast<IdxType>(-1)) {
        assert(dst_idx < _num_dst);
        dgl::aten::cuda::AtomicAdd(
            &_dst_unique_edges[dst_idx], static_cast<IdxType>(1));
      }
      break;
    }
    pos = EdgeHash(pos + delta);
    delta += 1;
  }
  return old_cnt;
}

template <typename IdxType>
inline __device__ IdxType
DeviceEdgeHashmap<IdxType>::GetDstCount(const IdxType &dst_idx) {
  return _dst_unique_edges[dst_idx];
}

template <typename IdxType>
inline __device__ IdxType DeviceEdgeHashmap<IdxType>::GetEdgeCount(
    const IdxType &src, const IdxType &dst_idx) {
  IdxType start_off = dst_idx * _num_items_each_dst;
  IdxType pos = EdgeHash(src);
  IdxType delta = 1;
  while (_edge_hashmap[start_off + pos].src != src) {
    pos = EdgeHash(pos + delta);
    delta += 1;
  }
  return _edge_hashmap[start_off + pos].cnt;
}

template <typename IdxType>
FrequencyHashmap<IdxType>::FrequencyHashmap(
    int64_t num_dst, int64_t num_items_each_dst, DGLContext ctx,
    hipStream_t stream, int64_t edge_table_scale) {
  _ctx = ctx;
  _stream = stream;
  num_items_each_dst = _table_size(num_items_each_dst, edge_table_scale);
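  // each dst node owns a contiguous sub-table of num_items_each_dst buckets
  // inside edge_hashmap (see the start_off computation in InsertEdge)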
  auto device = dgl::runtime::DeviceAPI::Get(_ctx);
  auto dst_unique_edges = static_cast<IdxType *>(
      device->AllocWorkspace(_ctx, (num_dst) * sizeof(IdxType)));
  auto edge_hashmap = static_cast<EdgeItem *>(device->AllocWorkspace(
      _ctx, (num_dst * num_items_each_dst) * sizeof(EdgeItem)));
  constexpr int BLOCK_SIZE = 256;
  constexpr int TILE_SIZE = BLOCK_SIZE * 8;
  dim3 block(BLOCK_SIZE);
  dim3 grid((num_dst * num_items_each_dst + TILE_SIZE - 1) / TILE_SIZE);
  CUDA_CALL(hipMemset(dst_unique_edges, 0, (num_dst) * sizeof(IdxType)));
  CUDA_KERNEL_CALL(
      (_init_edge_table<IdxType, BLOCK_SIZE, TILE_SIZE>), grid, block, 0,
      _stream, edge_hashmap, (num_dst * num_items_each_dst));
  _device_edge_hashmap = new DeviceEdgeHashmap<IdxType>(
      num_dst, num_items_each_dst, dst_unique_edges, edge_hashmap);
  _dst_unique_edges = dst_unique_edges;
  _edge_hashmap = edge_hashmap;
}

template <typename IdxType>
FrequencyHashmap<IdxType>::~FrequencyHashmap() {
  auto device = dgl::runtime::DeviceAPI::Get(_ctx);
  delete _device_edge_hashmap;
  _device_edge_hashmap = nullptr;
  device->FreeWorkspace(_ctx, _dst_unique_edges);
  _dst_unique_edges = nullptr;
  device->FreeWorkspace(_ctx, _edge_hashmap);
  _edge_hashmap = nullptr;
}

template <typename IdxType>
std::tuple<IdArray, IdArray, IdArray> FrequencyHashmap<IdxType>::Topk(
    const IdxType *src_data, const IdxType *dst_data, DGLDataType dtype,
    const int64_t num_edges, const int64_t num_edges_per_node,
    const int64_t num_pick) {
  using Idx64Type = int64_t;
  const int64_t num_dst_nodes = (num_edges / num_edges_per_node);
  constexpr int BLOCK_SIZE = 256;
  // XXX: an empirical value; gives the best performance on GV100
  constexpr int TILE_SIZE = BLOCK_SIZE * 32;
  const dim3 block(BLOCK_SIZE);
  const dim3 edges_grid((num_edges + TILE_SIZE - 1) / TILE_SIZE);
  auto device = dgl::runtime::DeviceAPI::Get(_ctx);
  const IdxType num_edge_blocks = static_cast<IdxType>(edges_grid.x);
  IdxType num_unique_edges = 0;

  // marks whether the edge at this position is the first insertion of its
  // (src, dst) pair into _edge_hashmap
  bool *is_first_position = static_cast<bool *>(
      device->AllocWorkspace(_ctx, sizeof(bool) * (num_edges)));
  CUDA_CALL(hipMemset(is_first_position, 0, sizeof(bool) * (num_edges)));
  // allocate double space so ExclusiveSum can write into an alternate buffer
  auto edge_blocks_prefix_data = static_cast<IdxType *>(device->AllocWorkspace(
      _ctx, 2 * sizeof(IdxType) * (num_edge_blocks + 1)));
  IdxType *edge_blocks_prefix = edge_blocks_prefix_data;
  IdxType *edge_blocks_prefix_alternate =
      (edge_blocks_prefix_data + (num_edge_blocks + 1));
  // allocate triple space: the per-node counts, their ExclusiveSum, and
  // unique_output_offsets
  auto num_unique_each_node_data = static_cast<IdxType *>(
      device->AllocWorkspace(_ctx, 3 * sizeof(IdxType) * (num_dst_nodes + 1)));
  IdxType *num_unique_each_node = num_unique_each_node_data;
  IdxType *num_unique_each_node_alternate =
      (num_unique_each_node_data + (num_dst_nodes + 1));
  IdxType *unique_output_offsets =
      (num_unique_each_node_data + 2 * (num_dst_nodes + 1));

  // 1. Scan all edges, counting the unique edges per block and the unique
  // edges of each dst node
  CUDA_KERNEL_CALL(
      (_count_frequency<IdxType, BLOCK_SIZE, TILE_SIZE>), edges_grid, block, 0,
      _stream, src_data, num_edges, num_edges_per_node, edge_blocks_prefix,
      is_first_position, *_device_edge_hashmap);

  // 2. Compact the unique edges frequency
  // 2.1 ExclusiveSum the edge_blocks_prefix
  void *d_temp_storage = nullptr;
  size_t temp_storage_bytes = 0;
  CUDA_CALL(hipcub::DeviceScan::ExclusiveSum(
      d_temp_storage, temp_storage_bytes, edge_blocks_prefix,
      edge_blocks_prefix_alternate, num_edge_blocks + 1, _stream));
  d_temp_storage = device->AllocWorkspace(_ctx, temp_storage_bytes);
  CUDA_CALL(hipcub::DeviceScan::ExclusiveSum(
      d_temp_storage, temp_storage_bytes, edge_blocks_prefix,
      edge_blocks_prefix_alternate, num_edge_blocks + 1, _stream));
  device->FreeWorkspace(_ctx, d_temp_storage);
  std::swap(edge_blocks_prefix, edge_blocks_prefix_alternate);
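  // copy the total number of unique edges back to the host; the StreamSync
  // below guarantees the value is ready before the buffers are sized from it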
  device->CopyDataFromTo(
      &edge_blocks_prefix[num_edge_blocks], 0, &num_unique_edges, 0,
      sizeof(num_unique_edges), _ctx, DGLContext{kDGLCPU, 0}, dtype);
  device->StreamSync(_ctx, _stream);
  // 2.2 Allocate the data of unique edges and frequency
  // double space to use SegmentedRadixSort
  auto unique_src_edges_data = static_cast<IdxType *>(
      device->AllocWorkspace(_ctx, 2 * sizeof(IdxType) * (num_unique_edges)));
  IdxType *unique_src_edges = unique_src_edges_data;
  IdxType *unique_src_edges_alternate =
      unique_src_edges_data + num_unique_edges;
  // double space to use SegmentedRadixSort
  auto unique_frequency_data = static_cast<Idx64Type *>(
      device->AllocWorkspace(_ctx, 2 * sizeof(Idx64Type) * (num_unique_edges)));
  Idx64Type *unique_frequency = unique_frequency_data;
  Idx64Type *unique_frequency_alternate =
      unique_frequency_data + num_unique_edges;
  // 2.3 Compact the unique edges and their frequency
  CUDA_KERNEL_CALL(
      (_compact_frequency<IdxType, Idx64Type, BLOCK_SIZE, TILE_SIZE>),
      edges_grid, block, 0, _stream, src_data, dst_data, num_edges,
      num_edges_per_node, edge_blocks_prefix, is_first_position,
      num_unique_each_node, unique_src_edges, unique_frequency,
      *_device_edge_hashmap);

  // 3. SegmentedRadixSort the unique edges and unique_frequency
  // 3.1 ExclusiveSum the num_unique_each_node
  d_temp_storage = nullptr;
  temp_storage_bytes = 0;
  CUDA_CALL(hipcub::DeviceScan::ExclusiveSum(
      d_temp_storage, temp_storage_bytes, num_unique_each_node,
      num_unique_each_node_alternate, num_dst_nodes + 1, _stream));
  d_temp_storage = device->AllocWorkspace(_ctx, temp_storage_bytes);
  CUDA_CALL(hipcub::DeviceScan::ExclusiveSum(
      d_temp_storage, temp_storage_bytes, num_unique_each_node,
      num_unique_each_node_alternate, num_dst_nodes + 1, _stream));
  device->FreeWorkspace(_ctx, d_temp_storage);
  // 3.2 SegmentedRadixSort the unique_src_edges and unique_frequency
  // Create a set of DoubleBuffers to wrap pairs of device pointers
  hipcub::DoubleBuffer<Idx64Type> d_unique_frequency(
      unique_frequency, unique_frequency_alternate);
  hipcub::DoubleBuffer<IdxType> d_unique_src_edges(
      unique_src_edges, unique_src_edges_alternate);
  // Determine temporary device storage requirements
  d_temp_storage = nullptr;
  temp_storage_bytes = 0;
  // DeviceRadixSort is faster than DeviceSegmentedRadixSort, especially when
  // num_dst_nodes is large (around 10000 or more)
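  // for 32-bit IdxType the dst-node index is already packed into the high bits
  // of the 64-bit keys (see _compact_frequency), so one global descending sort
  // produces per-node segments ordered by frequency; 64-bit indices carry no
  // such packing, so a segmented sort over per-node ranges is required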
  if (dtype.bits == 32) {
    CUDA_CALL(hipcub::DeviceRadixSort::SortPairsDescending(
        d_temp_storage, temp_storage_bytes, d_unique_frequency,
        d_unique_src_edges, num_unique_edges, 0, sizeof(Idx64Type) * 8,
        _stream));
  } else {
    CUDA_CALL(hipcub::DeviceSegmentedRadixSort::SortPairsDescending(
        d_temp_storage, temp_storage_bytes, d_unique_frequency,
        d_unique_src_edges, num_unique_edges, num_dst_nodes,
        num_unique_each_node_alternate, num_unique_each_node_alternate + 1, 0,
        sizeof(Idx64Type) * 8, _stream));
  }
  d_temp_storage = device->AllocWorkspace(_ctx, temp_storage_bytes);
  if (dtype.bits == 32) {
    CUDA_CALL(hipcub::DeviceRadixSort::SortPairsDescending(
        d_temp_storage, temp_storage_bytes, d_unique_frequency,
        d_unique_src_edges, num_unique_edges, 0, sizeof(Idx64Type) * 8,
        _stream));
  } else {
    CUDA_CALL(hipcub::DeviceSegmentedRadixSort::SortPairsDescending(
        d_temp_storage, temp_storage_bytes, d_unique_frequency,
        d_unique_src_edges, num_unique_edges, num_dst_nodes,
        num_unique_each_node_alternate, num_unique_each_node_alternate + 1, 0,
        sizeof(Idx64Type) * 8, _stream));
  }
  device->FreeWorkspace(_ctx, d_temp_storage);

  // 4. Get the final pick number for each dst node
  // 4.1 Clamp num_unique_each_node to min(num_pick, num_unique_each_node)
  constexpr int NODE_TILE_SIZE = BLOCK_SIZE * 2;
  const dim3 nodes_grid((num_dst_nodes + NODE_TILE_SIZE - 1) / NODE_TILE_SIZE);
  CUDA_KERNEL_CALL(
      (_get_pick_num<IdxType, BLOCK_SIZE, NODE_TILE_SIZE>), nodes_grid, block,
      0, _stream, num_unique_each_node, num_pick, num_dst_nodes);
  // 4.2 ExclusiveSum the clamped num_unique_each_node into
  // unique_output_offsets
  d_temp_storage = nullptr;
  temp_storage_bytes = 0;
  CUDA_CALL(hipcub::DeviceScan::ExclusiveSum(
      d_temp_storage, temp_storage_bytes, num_unique_each_node,
      unique_output_offsets, num_dst_nodes + 1, _stream));
  d_temp_storage = device->AllocWorkspace(_ctx, temp_storage_bytes);
  CUDA_CALL(hipcub::DeviceScan::ExclusiveSum(
      d_temp_storage, temp_storage_bytes, num_unique_each_node,
      unique_output_offsets, num_dst_nodes + 1, _stream));
  device->FreeWorkspace(_ctx, d_temp_storage);

  // 5. Pick the data to result
  IdxType num_output = 0;
  device->CopyDataFromTo(
      &unique_output_offsets[num_dst_nodes], 0, &num_output, 0,
      sizeof(num_output), _ctx, DGLContext{kDGLCPU, 0}, dtype);
  device->StreamSync(_ctx, _stream);

  IdArray res_src =
      IdArray::Empty({static_cast<int64_t>(num_output)}, dtype, _ctx);
  IdArray res_dst =
      IdArray::Empty({static_cast<int64_t>(num_output)}, dtype, _ctx);
  IdArray res_cnt =
      IdArray::Empty({static_cast<int64_t>(num_output)}, dtype, _ctx);
  CUDA_KERNEL_CALL(
      (_pick_data<IdxType, Idx64Type, BLOCK_SIZE, NODE_TILE_SIZE>), nodes_grid,
      block, 0, _stream, d_unique_frequency.Current(),
      d_unique_src_edges.Current(), num_unique_each_node_alternate, dst_data,
      num_edges_per_node, num_dst_nodes, num_edges, unique_output_offsets,
      res_src.Ptr<IdxType>(), res_dst.Ptr<IdxType>(), res_cnt.Ptr<IdxType>());

  device->FreeWorkspace(_ctx, is_first_position);
  device->FreeWorkspace(_ctx, edge_blocks_prefix_data);
  device->FreeWorkspace(_ctx, num_unique_each_node_data);
  device->FreeWorkspace(_ctx, unique_src_edges_data);
  device->FreeWorkspace(_ctx, unique_frequency_data);

  return std::make_tuple(res_src, res_dst, res_cnt);
}

template class FrequencyHashmap<int64_t>;
template class FrequencyHashmap<int32_t>;

};  // namespace impl

};  // namespace sampling

};  // namespace dgl