csr_get_data.cu 3.65 KB
Newer Older
/**
 *  Copyright (c) 2021 by Contributors
 * @file array/cuda/csr_get_data.cu
 * @brief Retrieve entries of a CSR matrix
 */
#include <dgl/array.h>
7

8
#include <numeric>
9
10
11
#include <unordered_set>
#include <vector>

12
13
14
15
16
17
18
19
20
21
#include "../../runtime/cuda/cuda_common.h"
#include "./utils.h"

namespace dgl {

using runtime::NDArray;

namespace aten {
namespace impl {

22
/**
 * @brief Look up entries of a CSR matrix at the given (row, col) positions by
 *        launching cuda::_LinearSearchKernel on the current CUDA stream.
 *
 * `rows` and `cols` must either have the same length, or one of them must have
 * length 1. A length-1 array is handed to the kernel with stride 0 (and the
 * other with stride 1), so its single id is reused for every output position.
 *
 * @tparam XPU Device type tag; not referenced inside the body.
 * @tparam IdType Element type used to access csr.indptr, csr.indices, csr.data,
 *         `rows` and `cols`.
 * @tparam DType Element type used to access `weights` and the result, and the
 *         type of `filler`.
 *
 * @param csr The CSR matrix to query.
 * @param rows Row ids to look up.
 * @param cols Column ids to look up.
 * @param return_eids When true, nullptr is passed to the kernel in place of
 *        the weight pointer and DType is required to match the dtype of
 *        `rows`.
 * @param weights Its dtype determines the dtype of the result. Its data
 *        pointer is only taken when `return_eids` is false.
 * @param filler Forwarded unchanged to the kernel. Presumably the value stored
 *        for positions that have no matching entry -- confirm against
 *        cuda::_LinearSearchKernel in ./utils.h.
 *
 * @return An array of length max(len(rows), len(cols)) allocated on rows->ctx
 *         with dtype weights->dtype. Returned without launching the kernel
 *         when that length is 0.
 */
template <DGLDeviceType XPU, typename IdType, typename DType>
NDArray CSRGetData(
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
    NDArray weights, DType filler) {
  const int64_t rowlen = rows->shape[0];
  const int64_t collen = cols->shape[0];

  CHECK((rowlen == collen) || (rowlen == 1) || (collen == 1))
      << "Invalid row and col id array.";

  // A stride of 0 makes the single-element side act as a broadcast operand.
  // When both arrays have length 1, both strides stay 1.
  const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;
  const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;

  const int64_t rstlen = std::max(rowlen, collen);
  // Declared as IdArray, but the element dtype is taken from `weights`.
  IdArray rst = NDArray::Empty({rstlen}, weights->dtype, rows->ctx);
  // Nothing to search for; also avoids configuring a launch with zero blocks.
  if (rstlen == 0) return rst;

  cudaStream_t stream = runtime::getCurrentCUDAStream();
  // One thread per output element: nb is ceil(rstlen / nt).
  const int nt = cuda::FindNumThreads(rstlen);
  const int nb = (rstlen + nt - 1) / nt;
  if (return_eids) {
    // Braced so the macro's expansion can never pair with an outer `else`.
    BUG_IF_FAIL(DGLDataTypeTraits<DType>::dtype == rows->dtype)
        << "DType does not match row's dtype.";
  }

  const IdType* indptr_data =
      static_cast<IdType*>(cuda::GetDevicePointer(csr.indptr));
  const IdType* indices_data =
      static_cast<IdType*>(cuda::GetDevicePointer(csr.indices));
  // A null data pointer is passed when the matrix carries no explicit data
  // array.
  const IdType* data_data =
      CSRHasData(csr) ? static_cast<IdType*>(cuda::GetDevicePointer(csr.data))
                      : nullptr;

  // TODO(minjie): use binary search for sorted csr
  CUDA_KERNEL_CALL(
      cuda::_LinearSearchKernel, nb, nt, 0, stream, indptr_data, indices_data,
      data_data, rows.Ptr<IdType>(), cols.Ptr<IdType>(), row_stride, col_stride,
      rstlen, return_eids ? nullptr : weights.Ptr<DType>(), filler,
      rst.Ptr<DType>());
  return rst;
}

63
// Explicit instantiations of CSRGetData for the CUDA device, one per
// (IdType, DType) pair listed below.

// Half-precision values.
template NDArray CSRGetData<kDGLCUDA, int32_t, __half>(
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
    NDArray weights, __half filler);
template NDArray CSRGetData<kDGLCUDA, int64_t, __half>(
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
    NDArray weights, __half filler);
// bfloat16 values; only instantiated when BF16_ENABLED evaluates to non-zero.
#if BF16_ENABLED
template NDArray CSRGetData<kDGLCUDA, int32_t, __nv_bfloat16>(
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
    NDArray weights, __nv_bfloat16 filler);
template NDArray CSRGetData<kDGLCUDA, int64_t, __nv_bfloat16>(
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
    NDArray weights, __nv_bfloat16 filler);
#endif  // BF16_ENABLED
// Single- and double-precision values.
template NDArray CSRGetData<kDGLCUDA, int32_t, float>(
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
    NDArray weights, float filler);
template NDArray CSRGetData<kDGLCUDA, int64_t, float>(
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
    NDArray weights, float filler);
template NDArray CSRGetData<kDGLCUDA, int32_t, double>(
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
    NDArray weights, double filler);
template NDArray CSRGetData<kDGLCUDA, int64_t, double>(
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
    NDArray weights, double filler);

// For CSRGetData<XPU, IdType>(CSRMatrix, NDArray, NDArray)
// In these two instantiations DType is the same type as IdType.
template NDArray CSRGetData<kDGLCUDA, int32_t, int32_t>(
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
    NDArray weights, int32_t filler);
template NDArray CSRGetData<kDGLCUDA, int64_t, int64_t>(
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
    NDArray weights, int64_t filler);

}  // namespace impl
}  // namespace aten
}  // namespace dgl