// csr_get_data.hip (last committed by sangwzh)
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
/**
 *  Copyright (c) 2021 by Contributors
 * @file array/cuda/csr_get_data.cu
 * @brief Retrieve entries of a CSR matrix
 */
#include <dgl/array.h>

#include "../../../include/dgl/array.h"

#include <algorithm>
#include <numeric>
#include <unordered_set>
#include <vector>

#include "../../runtime/cuda/cuda_common.h"
#include "utils.h"

namespace dgl {

using runtime::NDArray;

namespace aten {
namespace impl {

26
template <DGLDeviceType XPU, typename IdType, typename DType>
27
NDArray CSRGetData(
28
29
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
    NDArray weights, DType filler) {
30
31
32
33
  const int64_t rowlen = rows->shape[0];
  const int64_t collen = cols->shape[0];

  CHECK((rowlen == collen) || (rowlen == 1) || (collen == 1))
34
      << "Invalid row and col id array.";
35
36
37
38

  const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;
  const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;

sangwzh's avatar
sangwzh committed
39
  const int64_t rstlen = ::max(rowlen, collen);
40
  IdArray rst = NDArray::Empty({rstlen}, weights->dtype, rows->ctx);
41
  if (rstlen == 0) return rst;
42

sangwzh's avatar
sangwzh committed
43
  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
44
45
46
  const int nt = cuda::FindNumThreads(rstlen);
  const int nb = (rstlen + nt - 1) / nt;
  if (return_eids)
47
48
    BUG_IF_FAIL(DGLDataTypeTraits<DType>::dtype == rows->dtype)
        << "DType does not match row's dtype.";
49

50
51
52
53
54
55
56
  const IdType* indptr_data =
      static_cast<IdType*>(cuda::GetDevicePointer(csr.indptr));
  const IdType* indices_data =
      static_cast<IdType*>(cuda::GetDevicePointer(csr.indices));
  const IdType* data_data =
      CSRHasData(csr) ? static_cast<IdType*>(cuda::GetDevicePointer(csr.data))
                      : nullptr;
57

58
  // TODO(minjie): use binary search for sorted csr
59
60
61
62
63
  CUDA_KERNEL_CALL(
      cuda::_LinearSearchKernel, nb, nt, 0, stream, indptr_data, indices_data,
      data_data, rows.Ptr<IdType>(), cols.Ptr<IdType>(), row_stride, col_stride,
      rstlen, return_eids ? nullptr : weights.Ptr<DType>(), filler,
      rst.Ptr<DType>());
64
65
66
  return rst;
}

67
template NDArray CSRGetData<kDGLCUDA, int32_t, __half>(
68
69
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
    NDArray weights, __half filler);
70
template NDArray CSRGetData<kDGLCUDA, int64_t, __half>(
71
72
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
    NDArray weights, __half filler);
73
#if BF16_ENABLED
sangwzh's avatar
sangwzh committed
74
template NDArray CSRGetData<kDGLCUDA, int32_t, __hip_bfloat16>(
75
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
sangwzh's avatar
sangwzh committed
76
77
    NDArray weights, __hip_bfloat16 filler);
template NDArray CSRGetData<kDGLCUDA, int64_t, __hip_bfloat16>(
78
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
sangwzh's avatar
sangwzh committed
79
    NDArray weights, __hip_bfloat16 filler);
80
#endif  // BF16_ENABLED
81
template NDArray CSRGetData<kDGLCUDA, int32_t, float>(
82
83
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
    NDArray weights, float filler);
84
template NDArray CSRGetData<kDGLCUDA, int64_t, float>(
85
86
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
    NDArray weights, float filler);
87
template NDArray CSRGetData<kDGLCUDA, int32_t, double>(
88
89
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
    NDArray weights, double filler);
90
template NDArray CSRGetData<kDGLCUDA, int64_t, double>(
91
92
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
    NDArray weights, double filler);
93
94

// For CSRGetData<XPU, IdType>(CSRMatrix, NDArray, NDArray)
95
template NDArray CSRGetData<kDGLCUDA, int32_t, int32_t>(
96
97
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
    NDArray weights, int32_t filler);
98
template NDArray CSRGetData<kDGLCUDA, int64_t, int64_t>(
99
100
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
    NDArray weights, int64_t filler);
101
102
103
104

}  // namespace impl
}  // namespace aten
}  // namespace dgl