array_index_select_uvm.cuh 1.48 KB
Newer Older
1
2
3
4
5
6
/*!
 *  Copyright (c) 2021 by Contributors
 * \file array/cuda/uvm/array_index_select_uvm.cuh
 * \brief Array index select GPU kernel implementation
 */

7
8
#ifndef DGL_ARRAY_CUDA_UVM_ARRAY_INDEX_SELECT_UVM_CUH_
#define DGL_ARRAY_CUDA_UVM_ARRAY_INDEX_SELECT_UVM_CUH_
9

10
11
#define CACHE_LINE_SIZE 128

12
13
14
15
namespace dgl {
namespace aten {
namespace impl {

/*!
 * \brief Cross-device, cache-line-aligned version of IndexSelectMultiKernel.
 *
 * Gathers rows of a row-major matrix: for every r in [0, length),
 *   out[r * num_feat + c] = array[index[r] * num_feat + c],  c in [0, num_feat).
 *
 * Memory access over PCIe is more sensitive to data-access alignment (cache
 * lines) than on-device access, so this variant needs its own kernel. The
 * thread-to-column mapping is shifted so that threadIdx.x == 0 corresponds
 * to the CACHE_LINE_SIZE-byte boundary at or before the start of each source
 * row. Lanes whose shifted column lies before the row start (col < 0) do no
 * work for that slot and simply advance by blockDim.x.
 *
 * Launch layout (as used by the kernel):
 *   - threadIdx.x walks the columns of one row in steps of blockDim.x;
 *   - threadIdx.y selects a row within the block;
 *   - rows are visited in a grid-stride loop with stride
 *     blockDim.y * gridDim.x, so any grid size covers all `length` rows.
 * No shared memory is used.
 *
 * \param array    Source matrix, arr_len rows of num_feat elements each.
 * \param num_feat Number of elements per row.
 * \param index    Row ids to gather (length entries); each must lie in
 *                 [0, arr_len), which is checked with a device assert.
 * \param length   Number of rows to gather, i.e. rows of out.
 * \param arr_len  Number of rows in array; only used by the device assert.
 * \param out      Destination matrix, length rows of num_feat elements.
 */
template <typename DType, typename IdType>
__global__ void IndexSelectMultiKernelAligned(
        const DType* const array,
        const int64_t num_feat,
        const IdType* const index,
        const int64_t length,
        const int64_t arr_len,
        DType* const out) {
  // Widen before multiplying: blockIdx/blockDim/gridDim are 32-bit unsigned,
  // so their product would otherwise be computed (and could wrap) in 32 bits.
  int64_t out_row =
      static_cast<int64_t>(blockIdx.x) * blockDim.y + threadIdx.y;
  const int64_t stride = static_cast<int64_t>(blockDim.y) * gridDim.x;

  while (out_row < length) {
    const int64_t in_row = static_cast<int64_t>(index[out_row]);
    assert(in_row >= 0 && in_row < arr_len);

    // Row base pointers do not change inside the column loop.
    const DType* const in_ptr = array + in_row * num_feat;
    DType* const out_ptr = out + out_row * num_feat;

    // Number of DType elements between the preceding cache-line boundary
    // and the start of this source row.
    const int64_t idx_offset = static_cast<int64_t>(
        (reinterpret_cast<uint64_t>(in_ptr) % CACHE_LINE_SIZE) /
        sizeof(DType));

    // Thread x handles column (x - idx_offset), then every blockDim.x-th
    // column after it; each column is covered by exactly one lane.
    for (int64_t col = static_cast<int64_t>(threadIdx.x) - idx_offset;
         col < num_feat; col += blockDim.x) {
      if (col >= 0) {
        out_ptr[col] = in_ptr[col];
      }
    }
    out_row += stride;
  }
}

}  // namespace impl
}  // namespace aten
}  // namespace dgl

52
#endif  // DGL_ARRAY_CUDA_UVM_ARRAY_INDEX_SELECT_UVM_CUH_