"script/build_dgl.sh" did not exist on "2bbca12ad53d7b2b1fdd4e527b2219a1f1a59f0e"
csr_transpose.cc 3.49 KB
Newer Older
sangwzh's avatar
sangwzh committed
1
2
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cuda/csr_transpose.cc
 * @brief CSR transpose (convert to CSC)
 */
#include <dgl/array.h>
#include "../../runtime/cuda/cuda_common.h"

namespace dgl {

using runtime::NDArray;

namespace aten {
namespace impl {

19
template <DGLDeviceType XPU, typename IdType>
20
CSRMatrix CSRTranspose(CSRMatrix csr) {
21
22
23
24
25
  LOG(FATAL) << "Unreachable codes";
  return {};
}

template <>
26
CSRMatrix CSRTranspose<kDGLCUDA, int32_t>(CSRMatrix csr) {
27
#if CUDART_VERSION < 12000
28
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
sangwzh's avatar
sangwzh committed
29
  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
30
31
  // allocate cusparse handle if needed
  if (!thr_entry->cusparse_handle) {
sangwzh's avatar
sangwzh committed
32
    CUSPARSE_CALL(hipsparseCreate(&(thr_entry->cusparse_handle)));
33
  }
sangwzh's avatar
sangwzh committed
34
  CUSPARSE_CALL(hipsparseSetStream(thr_entry->cusparse_handle, stream));
35
36
37
38
39

  NDArray indptr = csr.indptr, indices = csr.indices, data = csr.data;
  const int64_t nnz = indices->shape[0];
  const auto& ctx = indptr->ctx;
  const auto bits = indptr->dtype.bits;
40
  if (aten::IsNullArray(data)) data = aten::Range(0, nnz, bits, ctx);
41
42
43
44
  const int32_t* indptr_ptr = static_cast<int32_t*>(indptr->data);
  const int32_t* indices_ptr = static_cast<int32_t*>(indices->data);
  const void* data_ptr = data->data;

45
46
  // (BarclayII) csr2csc doesn't seem to clear the content of cscColPtr if nnz
  // == 0. We need to do it ourselves.
47
  NDArray t_indptr = aten::Full(0, csr.num_cols + 1, bits, ctx);
48
49
50
51
52
53
  NDArray t_indices = aten::NewIdArray(nnz, ctx, bits);
  NDArray t_data = aten::NewIdArray(nnz, ctx, bits);
  int32_t* t_indptr_ptr = static_cast<int32_t*>(t_indptr->data);
  int32_t* t_indices_ptr = static_cast<int32_t*>(t_indices->data);
  void* t_data_ptr = t_data->data;

sangwzh's avatar
sangwzh committed
54
#if DTKRT_VERSION >= 10010
55
56
57
  auto device = runtime::DeviceAPI::Get(csr.indptr->ctx);
  // workspace
  size_t workspace_size;
sangwzh's avatar
sangwzh committed
58
  CUSPARSE_CALL(hipsparseCsr2cscEx2_bufferSize(
59
60
      thr_entry->cusparse_handle, csr.num_rows, csr.num_cols, nnz, data_ptr,
      indptr_ptr, indices_ptr, t_data_ptr, t_indptr_ptr, t_indices_ptr,
sangwzh's avatar
sangwzh committed
61
62
      HIP_R_32F, HIPSPARSE_ACTION_NUMERIC, HIPSPARSE_INDEX_BASE_ZERO,
      HIPSPARSE_CSR2CSC_ALG1,  // see cusparse doc for reference
63
64
      &workspace_size));
  void* workspace = device->AllocWorkspace(ctx, workspace_size);
sangwzh's avatar
sangwzh committed
65
  CUSPARSE_CALL(hipsparseCsr2cscEx2(
66
67
      thr_entry->cusparse_handle, csr.num_rows, csr.num_cols, nnz, data_ptr,
      indptr_ptr, indices_ptr, t_data_ptr, t_indptr_ptr, t_indices_ptr,
sangwzh's avatar
sangwzh committed
68
69
      HIP_R_32F, HIPSPARSE_ACTION_NUMERIC, HIPSPARSE_INDEX_BASE_ZERO,
      HIPSPARSE_CSR2CSC_ALG1,  // see cusparse doc for reference
70
71
72
      workspace));
  device->FreeWorkspace(ctx, workspace);
#else
sangwzh's avatar
sangwzh committed
73
  CUSPARSE_CALL(hipsparseScsr2csc(
74
      thr_entry->cusparse_handle, csr.num_rows, csr.num_cols, nnz,
75
76
      static_cast<const float*>(data_ptr), indptr_ptr, indices_ptr,
      static_cast<float*>(t_data_ptr), t_indices_ptr, t_indptr_ptr,
sangwzh's avatar
sangwzh committed
77
      HIPSPARSE_ACTION_NUMERIC, HIPSPARSE_INDEX_BASE_ZERO));
78
79
#endif

80
81
  return CSRMatrix(
      csr.num_cols, csr.num_rows, t_indptr, t_indices, t_data, false);
82
83
84
#else
  return COOToCSR(COOTranspose(CSRToCOO(csr, false)));
#endif
85
86
}

87
template <>
88
CSRMatrix CSRTranspose<kDGLCUDA, int64_t>(CSRMatrix csr) {
89
90
91
  return COOToCSR(COOTranspose(CSRToCOO(csr, false)));
}

// Explicit instantiation directives for the two index widths specialized
// above (int32_t and int64_t) on kDGLCUDA.
template CSRMatrix CSRTranspose<kDGLCUDA, int32_t>(CSRMatrix csr);
template CSRMatrix CSRTranspose<kDGLCUDA, int64_t>(CSRMatrix csr);

}  // namespace impl
}  // namespace aten
}  // namespace dgl