// !!! This is a file automatically generated by hipify!!!
/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cuda/cusparse_dispatcher.cuh
 * @brief Templates to dispatch into different cuSPARSE routines (hipSPARSE
 * routines in this hipified build) based on the type argument.
 */
#ifndef DGL_ARRAY_CUDA_CUSPARSE_DISPATCHER_CUH_
#define DGL_ARRAY_CUDA_CUSPARSE_DISPATCHER_CUH_

#include <hipsparse/hipsparse.h>
#include <dgl/runtime/c_runtime_api.h>

#include "bf16.cuh"
#include "fp16.cuh"

namespace dgl {
namespace aten {

20
/** @brief cusparseXcsrgemm dispatcher */
21
22
23
template <typename DType>
struct CSRGEMM {
  template <typename... Args>
sangwzh's avatar
sangwzh committed
24
  static inline hipsparseStatus_t bufferSizeExt(Args... args) {
25
    BUG_IF_FAIL(false) << "This piece of code should not be reached.";
sangwzh's avatar
sangwzh committed
26
    return static_cast<hipsparseStatus_t>(0);
27
28
29
  }

  template <typename... Args>
sangwzh's avatar
sangwzh committed
30
31
  static inline hipsparseStatus_t nnz(Args... args) {
    return hipsparseXcsrgemm2Nnz(args...);
32
33
34
  }

  template <typename... Args>
sangwzh's avatar
sangwzh committed
35
  static inline hipsparseStatus_t compute(Args... args) {
36
    BUG_IF_FAIL(false) << "This piece of code should not be reached.";
sangwzh's avatar
sangwzh committed
37
    return static_cast<hipsparseStatus_t>(0);
38
39
40
  }
};

41
42
43
template <>
struct CSRGEMM<__half> {
  template <typename... Args>
sangwzh's avatar
sangwzh committed
44
  static inline hipsparseStatus_t bufferSizeExt(Args... args) {
45
46
    // TODO(ndickson): There is no cusparseHcsrgemm2_bufferSizeExt, so a
    // different implementation would be required.
47
    LOG(FATAL) << "CSRGEMM::bufferSizeExt does not support dtype half (FP16).";
sangwzh's avatar
sangwzh committed
48
    return static_cast<hipsparseStatus_t>(0);
49
50
51
  }

  template <typename... Args>
sangwzh's avatar
sangwzh committed
52
53
  static inline hipsparseStatus_t nnz(Args... args) {
    return hipsparseXcsrgemm2Nnz(args...);
54
55
56
  }

  template <typename... Args>
sangwzh's avatar
sangwzh committed
57
  static inline hipsparseStatus_t compute(Args... args) {
58
59
60
    // TODO(ndickson): There is no cusparseHcsrgemm2, so a different
    // implementation would be required.
    LOG(FATAL) << "CSRGEMM::compute does not support dtype half (FP16).";
sangwzh's avatar
sangwzh committed
61
    return static_cast<hipsparseStatus_t>(0);
62
63
  }
};
64
65
66

#if BF16_ENABLED
/**
 * @brief BF16 csrgemm dispatcher (compiled only when BF16_ENABLED).
 *
 * Only nnz() is functional (it uses the type-independent
 * hipsparseXcsrgemm2Nnz). bufferSizeExt() and compute() raise LOG(FATAL).
 */
template <>
struct CSRGEMM<__hip_bfloat16> {
  template <typename... Args>
  static inline hipsparseStatus_t bufferSizeExt(Args...) {
    // TODO(ndickson): There is no cusparseHcsrgemm2_bufferSizeExt, so a
    // different implementation would be required.
    LOG(FATAL)
        << "CSRGEMM::bufferSizeExt does not support dtype bfloat16 (BF16).";
    // Placeholder return so the function has a value on every path.
    return static_cast<hipsparseStatus_t>(0);
  }

  /** @brief Forwards all arguments unchanged to hipsparseXcsrgemm2Nnz. */
  template <typename... Args>
  static inline hipsparseStatus_t nnz(Args... args) {
    return hipsparseXcsrgemm2Nnz(args...);
  }

  template <typename... Args>
  static inline hipsparseStatus_t compute(Args...) {
    // TODO(ndickson): There is no cusparseHcsrgemm2, so a different
    // implementation would be required.
    LOG(FATAL) << "CSRGEMM::compute does not support dtype bfloat16 (BF16).";
    // Placeholder return so the function has a value on every path.
    return static_cast<hipsparseStatus_t>(0);
  }
};
#endif  // BF16_ENABLED

template <>
struct CSRGEMM<float> {
  template <typename... Args>
sangwzh's avatar
sangwzh committed
95
96
  static inline hipsparseStatus_t bufferSizeExt(Args... args) {
    return hipsparseScsrgemm2_bufferSizeExt(args...);
97
98
99
  }

  template <typename... Args>
sangwzh's avatar
sangwzh committed
100
101
  static inline hipsparseStatus_t nnz(Args... args) {
    return hipsparseXcsrgemm2Nnz(args...);
102
103
104
  }

  template <typename... Args>
sangwzh's avatar
sangwzh committed
105
106
  static inline hipsparseStatus_t compute(Args... args) {
    return hipsparseScsrgemm2(args...);
107
108
109
110
111
112
  }
};

template <>
struct CSRGEMM<double> {
  template <typename... Args>
sangwzh's avatar
sangwzh committed
113
114
  static inline hipsparseStatus_t bufferSizeExt(Args... args) {
    return hipsparseDcsrgemm2_bufferSizeExt(args...);
115
116
117
  }

  template <typename... Args>
sangwzh's avatar
sangwzh committed
118
119
  static inline hipsparseStatus_t nnz(Args... args) {
    return hipsparseXcsrgemm2Nnz(args...);
120
121
122
  }

  template <typename... Args>
sangwzh's avatar
sangwzh committed
123
124
  static inline hipsparseStatus_t compute(Args... args) {
    return hipsparseDcsrgemm2(args...);
125
126
127
  }
};

128
/** @brief cusparseXcsrgeam dispatcher */
129
130
131
template <typename DType>
struct CSRGEAM {
  template <typename... Args>
sangwzh's avatar
sangwzh committed
132
  static inline hipsparseStatus_t bufferSizeExt(Args... args) {
133
    BUG_IF_FAIL(false) << "This piece of code should not be reached.";
sangwzh's avatar
sangwzh committed
134
    return static_cast<hipsparseStatus_t>(0);
135
136
137
  }

  template <typename... Args>
sangwzh's avatar
sangwzh committed
138
139
  static inline hipsparseStatus_t nnz(Args... args) {
    return hipsparseXcsrgeam2Nnz(args...);
140
141
142
  }

  template <typename... Args>
sangwzh's avatar
sangwzh committed
143
  static inline hipsparseStatus_t compute(Args... args) {
144
    BUG_IF_FAIL(false) << "This piece of code should not be reached.";
sangwzh's avatar
sangwzh committed
145
    return static_cast<hipsparseStatus_t>(0);
146
147
148
  }
};

149
150
151
template <>
struct CSRGEAM<__half> {
  template <typename... Args>
sangwzh's avatar
sangwzh committed
152
  static inline hipsparseStatus_t bufferSizeExt(Args... args) {
153
154
    // TODO(ndickson): There is no cusparseHcsrgeam2_bufferSizeExt, so a
    // different implementation would be required.
155
    LOG(FATAL) << "CSRGEAM::bufferSizeExt does not support dtype half (FP16).";
sangwzh's avatar
sangwzh committed
156
    return static_cast<hipsparseStatus_t>(0);
157
158
159
  }

  template <typename... Args>
sangwzh's avatar
sangwzh committed
160
161
  static inline hipsparseStatus_t nnz(Args... args) {
    return hipsparseXcsrgeam2Nnz(args...);
162
163
164
  }

  template <typename... Args>
sangwzh's avatar
sangwzh committed
165
  static inline hipsparseStatus_t compute(Args... args) {
166
167
168
    // TODO(ndickson): There is no cusparseHcsrgeam2, so a different
    // implementation would be required.
    LOG(FATAL) << "CSRGEAM::compute does not support dtype half (FP16).";
sangwzh's avatar
sangwzh committed
169
    return static_cast<hipsparseStatus_t>(0);
170
171
  }
};
172
173
174

#if BF16_ENABLED
/**
 * @brief BF16 csrgeam dispatcher (compiled only when BF16_ENABLED).
 *
 * Only nnz() is functional (it uses the type-independent
 * hipsparseXcsrgeam2Nnz). bufferSizeExt() and compute() raise LOG(FATAL).
 */
template <>
struct CSRGEAM<__hip_bfloat16> {
  template <typename... Args>
  static inline hipsparseStatus_t bufferSizeExt(Args...) {
    // TODO(ndickson): There is no cusparseHcsrgeam2_bufferSizeExt, so a
    // different implementation would be required.
    LOG(FATAL)
        << "CSRGEAM::bufferSizeExt does not support dtype bfloat16 (BF16).";
    // Placeholder return so the function has a value on every path.
    return static_cast<hipsparseStatus_t>(0);
  }

  /** @brief Forwards all arguments unchanged to hipsparseXcsrgeam2Nnz. */
  template <typename... Args>
  static inline hipsparseStatus_t nnz(Args... args) {
    return hipsparseXcsrgeam2Nnz(args...);
  }

  template <typename... Args>
  static inline hipsparseStatus_t compute(Args...) {
    // TODO(ndickson): There is no cusparseHcsrgeam2, so a different
    // implementation would be required.
    LOG(FATAL) << "CSRGEAM::compute does not support dtype bfloat16 (BF16).";
    // Placeholder return so the function has a value on every path.
    return static_cast<hipsparseStatus_t>(0);
  }
};
#endif  // BF16_ENABLED

template <>
struct CSRGEAM<float> {
  template <typename... Args>
sangwzh's avatar
sangwzh committed
203
204
  static inline hipsparseStatus_t bufferSizeExt(Args... args) {
    return hipsparseScsrgeam2_bufferSizeExt(args...);
205
206
207
  }

  template <typename... Args>
sangwzh's avatar
sangwzh committed
208
209
  static inline hipsparseStatus_t nnz(Args... args) {
    return hipsparseXcsrgeam2Nnz(args...);
210
211
212
  }

  template <typename... Args>
sangwzh's avatar
sangwzh committed
213
214
  static inline hipsparseStatus_t compute(Args... args) {
    return hipsparseScsrgeam2(args...);
215
216
217
218
219
220
  }
};

template <>
struct CSRGEAM<double> {
  template <typename... Args>
sangwzh's avatar
sangwzh committed
221
222
  static inline hipsparseStatus_t bufferSizeExt(Args... args) {
    return hipsparseDcsrgeam2_bufferSizeExt(args...);
223
224
225
  }

  template <typename... Args>
sangwzh's avatar
sangwzh committed
226
227
  static inline hipsparseStatus_t nnz(Args... args) {
    return hipsparseXcsrgeam2Nnz(args...);
228
229
230
  }

  template <typename... Args>
sangwzh's avatar
sangwzh committed
231
232
  static inline hipsparseStatus_t compute(Args... args) {
    return hipsparseDcsrgeam2(args...);
233
234
235
236
237
238
239
  }
};

};  // namespace aten
};  // namespace dgl

#endif  // DGL_ARRAY_CUDA_CUSPARSE_DISPATCHER_CUH_