"llama/vscode:/vscode.git/clone" did not exist on "9cfbffafc55bb07f7d627a6701df2e9c8a520f5f"
/*!
 *  Copyright (c) 2021 by Contributors
 * \file array/cuda/csr_get_data.cu
 * \brief Retrieve entries of a CSR matrix
 */
#include <dgl/array.h>
#include <vector>
#include <unordered_set>
#include <numeric>
#include "../../runtime/cuda/cuda_common.h"
#include "./utils.h"

namespace dgl {

using runtime::NDArray;

namespace aten {
namespace impl {

20
/*!
 * \brief Retrieve the entries of a CSR matrix at the given (row, col) positions.
 *
 * `rows` and `cols` must either have the same length, or one of them must have
 * length 1, in which case that single id is paired with every element of the
 * other array.
 *
 * \tparam XPU Device type this implementation is instantiated for.
 * \tparam IdType Integer type of the CSR index arrays and of `rows`/`cols`.
 * \tparam DType Element type of the returned array.
 * \param csr The CSR matrix to query.
 * \param rows Row ids to look up.
 * \param cols Column ids to look up.
 * \param return_eids If true, no weight array is passed to the kernel
 *        (presumably the kernel then emits edge ids; confirm against
 *        cuda::_LinearSearchKernel). DType must then match the dtype of `rows`.
 * \param weights Array whose dtype determines the dtype of the result; its data
 *        is passed to the kernel only when `return_eids` is false.
 * \param filler Value forwarded to the kernel (presumably written for positions
 *        that have no matching entry; confirm against cuda::_LinearSearchKernel).
 * \return An array of length max(len(rows), len(cols)) allocated on the
 *         context of `rows`.
 */
template <DGLDeviceType XPU, typename IdType, typename DType>
NDArray CSRGetData(
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, DType filler) {
  const int64_t rowlen = rows->shape[0];
  const int64_t collen = cols->shape[0];

  CHECK((rowlen == collen) || (rowlen == 1) || (collen == 1))
    << "Invalid row and col id array.";

  // A length-1 id array gets a zero stride so that its single element is
  // reused for every output position.
  const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;
  const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;

  const int64_t rstlen = std::max(rowlen, collen);
  IdArray rst = NDArray::Empty({rstlen}, weights->dtype, rows->ctx);
  // Nothing to look up: return the empty result without launching a kernel.
  if (rstlen == 0) {
    return rst;
  }

  cudaStream_t stream = runtime::getCurrentCUDAStream();

  // One thread per output element, rounded up to whole blocks.
  const int nt = cuda::FindNumThreads(rstlen);
  const int nb = static_cast<int>((rstlen + nt - 1) / nt);

  if (return_eids) {
    // The result buffer is written through a DType pointer, so DType has to
    // agree with the id dtype in this mode.
    BUG_IF_FAIL(DGLDataTypeTraits<DType>::dtype == rows->dtype) <<
      "DType does not match row's dtype.";
  }

  const IdType* indptr_data = csr.indptr.Ptr<IdType>();
  const IdType* indices_data = csr.indices.Ptr<IdType>();
  const IdType* data_data = CSRHasData(csr) ? csr.data.Ptr<IdType>() : nullptr;
  if (csr.is_pinned) {
    // The CSR arrays are flagged as pinned: ask the CUDA runtime for the
    // device-side pointers that correspond to the host buffers.
    CUDA_CALL(cudaHostGetDevicePointer(
        &indptr_data, csr.indptr.Ptr<IdType>(), 0));
    CUDA_CALL(cudaHostGetDevicePointer(
        &indices_data, csr.indices.Ptr<IdType>(), 0));
    if (CSRHasData(csr)) {
      CUDA_CALL(cudaHostGetDevicePointer(
          &data_data, csr.data.Ptr<IdType>(), 0));
    }
  }

  // TODO(minjie): use binary search for sorted csr
  CUDA_KERNEL_CALL(cuda::_LinearSearchKernel,
      nb, nt, 0, stream,
      indptr_data, indices_data, data_data,
      rows.Ptr<IdType>(), cols.Ptr<IdType>(),
      row_stride, col_stride, rstlen,
      return_eids ? nullptr : weights.Ptr<DType>(), filler, rst.Ptr<DType>());
  return rst;
}

68
// Explicit instantiations of CSRGetData for the CUDA device.

// Half-precision results are compiled only when USE_FP16 is defined.
#ifdef USE_FP16
template NDArray CSRGetData<kDGLCUDA, int32_t, __half>(
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, __half filler);
template NDArray CSRGetData<kDGLCUDA, int64_t, __half>(
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, __half filler);
#endif

// Single- and double-precision results, for 32-bit and 64-bit ids.
template NDArray CSRGetData<kDGLCUDA, int32_t, float>(
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, float filler);
template NDArray CSRGetData<kDGLCUDA, int64_t, float>(
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, float filler);
template NDArray CSRGetData<kDGLCUDA, int32_t, double>(
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, double filler);
template NDArray CSRGetData<kDGLCUDA, int64_t, double>(
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, double filler);

// For CSRGetData<XPU, IdType>(CSRMatrix, NDArray, NDArray)
// (integer results whose element type equals the id type).
template NDArray CSRGetData<kDGLCUDA, int32_t, int32_t>(
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, int32_t filler);
template NDArray CSRGetData<kDGLCUDA, int64_t, int64_t>(
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, int64_t filler);
88
89
90
91

}  // namespace impl
}  // namespace aten
}  // namespace dgl