// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cuda/gather_mm.cu
 * @brief GatherMM C APIs and definitions.
 */
#include <dgl/array.h>

#include <algorithm>  // std::swap

#include "atomic.cuh"
#include "functor.cuh"
#include "utils.h"

namespace dgl {
using namespace cuda;
namespace aten {

namespace {

/** @brief Call hipBLAS GEMM API for dense matmul operation; specialized below
 * for half, bfloat16, float, and double.
 */
template <typename DType>
hipblasStatus_t cublasGemm(
    hipblasHandle_t handle, hipblasOperation_t transa, hipblasOperation_t transb,
    int m, int n, int k, const DType* alpha, const DType* A, int lda,
    const DType* B, int ldb, const DType* beta, DType* C, int ldc) {
  LOG(INFO) << "Not supported dtype";
  return HIPBLAS_STATUS_EXECUTION_FAILED;
}

template <>
hipblasStatus_t cublasGemm<__half>(
    hipblasHandle_t handle, hipblasOperation_t transa, hipblasOperation_t transb,
    int m, int n, int k, const __half* alpha, const __half* A, int lda,
    const __half* B, int ldb, const __half* beta, __half* C, int ldc) {
  return hipblasHgemm(
      handle, transa, transb, m, n, k, (hipblasHalf*)alpha, (hipblasHalf*)A,
      lda, (hipblasHalf*)B, ldb, (hipblasHalf*)beta, (hipblasHalf*)C, ldc);
}

#if BF16_ENABLED
template <>
hipblasStatus_t cublasGemm<__hip_bfloat16>(
    hipblasHandle_t handle, hipblasOperation_t transa, hipblasOperation_t transb,
    int m, int n, int k, const __hip_bfloat16* alpha, const __hip_bfloat16* A,
    int lda, const __hip_bfloat16* B, int ldb, const __hip_bfloat16* beta,
    __hip_bfloat16* C, int ldc) {
  float alpha_float = __bfloat162float(*alpha);
  float beta_float = __bfloat162float(*beta);
  return hipblasGemmEx(
      handle, transa, transb, m, n, k, &alpha_float, A, HIPBLAS_R_16B, lda, B,
      HIPBLAS_R_16B, ldb, &beta_float, C, HIPBLAS_R_16B, ldc, HIPBLAS_R_32F,
      HIPBLAS_GEMM_DEFAULT);
}
#endif  // BF16_ENABLED

template <>
hipblasStatus_t cublasGemm<float>(
    hipblasHandle_t handle, hipblasOperation_t transa, hipblasOperation_t transb,
    int m, int n, int k, const float* alpha, const float* A, int lda,
    const float* B, int ldb, const float* beta, float* C, int ldc) {
  return hipblasSgemm(
      handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
}

template <>
hipblasStatus_t cublasGemm<double>(
    hipblasHandle_t handle, hipblasOperation_t transa, hipblasOperation_t transb,
    int m, int n, int k, const double* alpha, const double* A, int lda,
    const double* B, int ldb, const double* beta, double* C, int ldc) {
  return hipblasDgemm(
      handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
}
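
// Layout note: hipBLAS follows the BLAS column-major convention while DGL
// NDArrays are row-major.  The callers below therefore pass the operands
// swapped (B first, A second) so that the column-major GEMM
//     C^T = B^T * A^T
// writes exactly the row-major C = A * B.  A minimal sketch of the mapping
// for row-major A (m x k), B (k x n), C (m x n) with no transposes:
//     cublasGemm<DType>(handle, HIPBLAS_OP_N, HIPBLAS_OP_N,
//                       /*m=*/n, /*n=*/m, /*k=*/k, &alpha,
//                       B, /*ldb=*/n, A, /*lda=*/k, &beta, C, /*ldc=*/n);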

}  // namespace

namespace cuda {

/**
 * @note Each row of A multiplies a segment of matrix B of dimension
 * in_len * out_len. One warp is assigned to process one row of A. Each warp
 * sequentially multiplies one element of A with a row of B to compute a
 * partial result of the output. A is loaded into shared memory in a coalesced
 * way. The output is accumulated in registers; B should benefit from the L2
 * cache.
 */
template <typename Idx, typename DType>
__global__ void GatherMMScatterKernel(
    const DType* __restrict__ A, const DType* __restrict__ B,
    DType* __restrict__ C, const Idx* __restrict__ idx_a,
    const Idx* __restrict__ idx_b, const Idx* __restrict__ idx_c,
    const int64_t num_rows, const int64_t in_len, const int64_t out_len) {
  unsigned int tId = threadIdx.x;
  unsigned int laneId = tId & 31;
  unsigned int gId = (blockIdx.x * blockDim.x + threadIdx.x);
  unsigned int warpId = gId >> 5;
  unsigned int row = warpId;
  if (row < num_rows) {
    const unsigned int local_row =
        row & 3;  // hardcoded for TB size 128 (4 warps)
    const Idx cur_rowA = (idx_a) ? idx_a[row] : row;
    const Idx cur_rowB = (idx_b) ? idx_b[row] : row;
    const Idx cur_rowC = (idx_c) ? idx_c[row] : row;
    const Idx B_offset = cur_rowB * in_len * out_len;
    const int sh_a_tile = 64;
    __shared__ DType sh_A[4 * sh_a_tile];
    int a_tile = sh_a_tile;
    for (unsigned int k_start = 0; k_start < in_len; k_start += 64) {
      if ((in_len - k_start) < a_tile) a_tile = in_len - k_start;
      // Load A in shared mem in a coalesced way
      for (unsigned int l = laneId; l < a_tile; l += 32)
        sh_A[local_row * sh_a_tile + l] = A[cur_rowA * in_len + (k_start + l)];
      // __syncwarp();

      for (unsigned int outloop = 0; outloop < out_len; outloop += 32) {
        DType out_reg = static_cast<DType>(0.0f);  // thread private
        const unsigned int l = laneId;
        if ((outloop + l) < out_len) {  // guard tail when out_len % 32 != 0
          // iterate over elements of a row of A
          for (unsigned int i = 0; i < a_tile; i++) {
            const DType a_val = sh_A[local_row * sh_a_tile + i];
            // iterate over elements of a row of B in parallel
            out_reg +=
                a_val * B[B_offset + ((i + k_start) * out_len + (outloop + l))];
          }
          if (idx_c) {
            AtomicAdd(C + cur_rowC * out_len + (outloop + l), out_reg);
          } else {
            C[cur_rowC * out_len + (outloop + l)] += out_reg;
          }
        }
      }
    }
  }
}
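
// Worked example of the tiling above (illustrative values, not from the
// source): with in_len = 100 and out_len = 48, each warp makes two passes
// over the columns of A (k_start = 0 with a_tile = 64, then k_start = 64 with
// a_tile = 36) and, for each pass, two passes over the output columns
// (outloop = 0 covers columns 0..31; outloop = 32 covers columns 32..47, with
// lanes 16..31 masked off by the tail guard).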

/**
 * @note The output matrix is accumulated via atomic operations. The rest of
 * the strategy is the same as in GatherMMScatterKernel: one warp is assigned
 * to process one row of A, and each warp sequentially multiplies one element
 * of A with a row of B to compute a partial result of the output. A is loaded
 * into shared memory in a coalesced way; B should benefit from the L2 cache.
 */
template <typename Idx, typename DType>
__global__ void GatherMMScatterKernel2(
    const DType* __restrict__ A, const DType* __restrict__ B,
    DType* __restrict__ C, const Idx* __restrict__ idx_a,
    const Idx* __restrict__ idx_b, const Idx* __restrict__ idx_c,
    const int64_t num_rows, const int64_t in_len, const int64_t out_len) {
  unsigned int tId = threadIdx.x;
  unsigned int laneId = tId & 31;
  unsigned int gId = (blockIdx.x * blockDim.x + threadIdx.x);
  unsigned int warpId = gId >> 5;
  unsigned int row = warpId;
  if (row < num_rows) {
    const unsigned int local_row =
        row & 3;  // hardcoded for TB size 128 (4 warps)
    const Idx row_a = (idx_a) ? idx_a[row] : row;
    const Idx row_b = (idx_b) ? idx_b[row] : row;
    const Idx row_c = (idx_c) ? idx_c[row] : row;
    const Idx C_offset = row_c * in_len * out_len;
    const int sh_a_tile = 64;
    __shared__ DType sh_A[4 * sh_a_tile];
    int a_tile = sh_a_tile;
    for (unsigned int k_start = 0; k_start < in_len; k_start += 64) {
      if ((in_len - k_start) < a_tile) a_tile = in_len - k_start;
      /* Load A in shared mem in a coalesced way */
      for (unsigned int l = laneId; l < a_tile; l += 32)
        sh_A[local_row * sh_a_tile + l] = A[row_a * in_len + (k_start + l)];
      // __syncwarp();

      for (unsigned int outloop = 0; outloop < out_len; outloop += 32) {
        DType out_reg = static_cast<DType>(0.0f);  // thread private
        const unsigned int l = laneId;
        if ((outloop + l) < out_len) {  // guard tail when out_len % 32 != 0
          const DType b_val = B[row_b * out_len + (outloop + l)];
          /* iterate over elements of a row of A */
          for (unsigned int i = 0; i < a_tile; i++) {
            const DType a_val = sh_A[local_row * sh_a_tile + i];
            const Idx C_idx =
                C_offset + ((i + k_start) * out_len + (outloop + l));
            AtomicAdd(C + C_idx, a_val * b_val);
          }
        }
      }
    }
  }
}
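
// In effect, for each gathered row i this kernel accumulates the outer
// product A[idx_a[i]]^T * B[idx_b[i]] (an in_len x out_len block) into the
// slice C[idx_c[i]].  AtomicAdd is required because several warps may scatter
// into the same idx_c slice; this is how the per-relation weight gradient is
// assembled (see GatherMMScatter below).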

}  // namespace cuda

/**
 * @brief Implementation of the SegmentMM operator. The input matrix A is
 *        expected to be sorted according to relation type.
 * @param A The input dense matrix of dimension m x k
 * @param B The input dense 3-D tensor of dimension R x k x n (one k x n
 *        weight slice per relation)
 * @param C The output dense matrix of dimension m x n
 * @param seglen_A The input vector of size R. Each element
 *        is the length of segments of input ``A``
 * @param a_trans Whether matrix A is to be transposed
 * @param b_trans Whether matrix B is to be transposed
 */
template <int XPU, typename IdType, typename DType>
void SegmentMM(
    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
    bool a_trans, bool b_trans) {
  auto device = runtime::DeviceAPI::Get(A->ctx);
  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
  const DType* A_data = A.Ptr<DType>();
  const DType* B_data = B.Ptr<DType>();
  const IdType* seglen_A_data = seglen_A.Ptr<IdType>();
  DType* C_data = C.Ptr<DType>();
  int64_t A_offset = 0, B_offset = 0, C_offset = 0;
  int64_t m, n, k;
  int64_t num_rel = seglen_A.NumElements();
  DType alpha = 1., beta = 0.;

  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  if (!thr_entry->cublas_handle)
    CUBLAS_CALL(hipblasCreate(&(thr_entry->cublas_handle)));
  CUBLAS_CALL(hipblasSetStream(thr_entry->cublas_handle, stream));

  IdType m_offset = 0;
  for (IdType etype = 0; etype < num_rel; ++etype) {
    m = seglen_A_data[etype];  // rows of A
    CHECK_LE(m_offset + m, A->shape[0])
        << "Segment index out of bound of A->shape[0].";
    n = B->shape[2];  // cols of B
    k = B->shape[1];  // cols of A == rows of B
    int ldb = n, lda = k, ldc = n;
    hipblasOperation_t transB = HIPBLAS_OP_N;
    hipblasOperation_t transA = HIPBLAS_OP_N;
    if (b_trans) {
      transB = HIPBLAS_OP_T;
      ldb = n, lda = n, ldc = k;
      std::swap(n, k);
    }
    CUBLAS_CALL(cublasGemm<DType>(
        thr_entry->cublas_handle, transB, transA, n, m, k, &alpha,
        B_data + B_offset, ldb, A_data + A_offset, lda, &beta,
        C_data + C_offset, ldc));
    A_offset += m * k;
    B_offset += k * n;
    C_offset += m * n;
    m_offset += m;
  }
}
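
// Worked example for SegmentMM (illustrative values, not from the source),
// assuming b_trans is false: with seglen_A = {3, 5}, k = 16 and n = 8, A is
// 8 x 16, B is 2 x 16 x 8 and C is 8 x 8.  The first GEMM multiplies rows
// 0..2 of A by B[0] into rows 0..2 of C; the offsets then advance by
// m*k = 48 (A), k*n = 128 (B) and m*n = 24 (C), and the second GEMM
// multiplies rows 3..7 of A by B[1].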

template <int XPU, typename IdType, typename DType>
void SegmentMMBackwardB(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen) {
  auto device = runtime::DeviceAPI::Get(A->ctx);
  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
  const DType* A_data = A.Ptr<DType>();
  const DType* dC_data = dC.Ptr<DType>();
  const IdType* seglen_data = seglen.Ptr<IdType>();
  DType* dB_data = dB.Ptr<DType>();
  int64_t A_offset = 0, dC_offset = 0, dB_offset = 0;
  int64_t m, n, k;
  int64_t num_rel = seglen.NumElements();
  DType alpha = 1., beta = 0.;

  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  if (!thr_entry->cublas_handle)
    CUBLAS_CALL(hipblasCreate(&(thr_entry->cublas_handle)));
  CUBLAS_CALL(hipblasSetStream(thr_entry->cublas_handle, stream));

  IdType k_offset = 0;
  for (IdType etype = 0; etype < num_rel; ++etype) {
    m = dC->shape[1];
    n = A->shape[1];
    k = seglen_data[etype];
    CHECK_LE(k_offset + k, A->shape[0])
        << "Segement index out of bound of A->shape[0].";
    int lddC = m, ldA = n, lddB = m;
    hipblasOperation_t trans_dC = HIPBLAS_OP_N;
    hipblasOperation_t trans_A = HIPBLAS_OP_T;
    CUBLAS_CALL(cublasGemm<DType>(
        thr_entry->cublas_handle, trans_dC, trans_A, m, n, k, &alpha,
        dC_data + dC_offset, lddC, A_data + A_offset, ldA, &beta,
        dB_data + dB_offset, lddB));
    dC_offset += m * k;
    A_offset += n * k;
    dB_offset += m * n;
    k_offset += k;
  }
}
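
// Reading of the GEMM above (using the row-major-via-column-major convention
// noted earlier): for each relation segment it computes the weight gradient
// dB[etype] = A_seg^T * dC_seg, where A_seg (k x n) and dC_seg (k x m) are
// the blocks of A and dC belonging to that segment, k = seglen_data[etype],
// n = A->shape[1] and m = dC->shape[1]; the transpose of A_seg is realized by
// trans_A = HIPBLAS_OP_T rather than by materializing A^T.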

/**
 * @brief Implementation of the GatherMM operator. The input matrix A is
 *        expected to be sorted according to relation type.
 * @param A The input dense matrix of dimension m x k
 * @param B The input dense 3-D tensor of dimension R x k x n (one k x n
 *        weight slice per relation)
 * @param C The output dense matrix of dimension m x n
 * @param idx_a The index vector used to gather rows of the left-hand
 *        operand A
 * @param idx_b The index vector used to gather weight slices of the
 *        right-hand operand B
 */
template <int XPU, typename IdType, typename DType>
void GatherMM(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b) {
  auto device = runtime::DeviceAPI::Get(A->ctx);
  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
  int64_t out_len = B->shape[2];  // cols of B
  int64_t in_len = A->shape[1];   // cols of A
  const int64_t tot_num_rows = A->shape[0];
  const int ntx = 128;
  const int warp_size = 32;
  const int nbx = ((tot_num_rows * warp_size + ntx - 1) / ntx);
  const dim3 nblks(nbx);
  const dim3 nthrs(ntx);
  CUDA_KERNEL_CALL(
      (cuda::GatherMMScatterKernel<IdType, DType>), nblks, nthrs, 0, stream,
      A.Ptr<DType>(), B.Ptr<DType>(), C.Ptr<DType>(), idx_a.Ptr<IdType>(),
      idx_b.Ptr<IdType>(), nullptr, tot_num_rows, in_len, out_len);
}
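
// The launch configuration above assigns one 32-thread logical warp per
// gathered row: 128 threads per block gives 4 warps per block, so, e.g.,
// tot_num_rows = 1000 needs 1000 * 32 = 32000 threads, i.e.
// nbx = (32000 + 127) / 128 = 250 blocks (illustrative numbers only).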

/**
 * @brief Implementation of the GatherMMScatter operator. The input matrix A is
 *        expected to be sorted according to relation type.
 * @param A The input dense matrix of dimension m x k
 * @param B The input dense matrix of dimension k x n, or a 3-D tensor of
 *        per-relation k x n weight slices
 * @param C The output dense matrix of dimension m x n
 * @param idx_a The index vector used to gather rows of the left-hand
 *        operand A
 * @param idx_b The index vector used to gather the right-hand operand B
 * @param idx_c The index vector used to scatter rows (or slices) of the
 *        output C
 */
template <int XPU, typename IdType, typename DType>
void GatherMMScatter(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c) {
  auto device = runtime::DeviceAPI::Get(A->ctx);
  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
  const IdType* idx_c_data = idx_c.Ptr<IdType>();
  int64_t out_len = (B->ndim == 2) ? B->shape[1] : B->shape[2];  // cols of B
  int64_t in_len = A->shape[1];                                  // cols of A
  int64_t tot_num_rows = A->shape[0];
  const int ntx = 128;
  const int warp_size = 32;
  const int nbx = ((tot_num_rows * warp_size + ntx - 1) / ntx);
  const dim3 nblks(nbx);
  const dim3 nthrs(ntx);
  if (B->ndim == 3) {
    CUDA_KERNEL_CALL(
        (cuda::GatherMMScatterKernel<IdType, DType>), nblks, nthrs, 0, stream,
        A.Ptr<DType>(), B.Ptr<DType>(), C.Ptr<DType>(), idx_a.Ptr<IdType>(),
        idx_b.Ptr<IdType>(), idx_c.Ptr<IdType>(), tot_num_rows, in_len,
        out_len);
  } else {
    // Custom kernel for W_grad[idx_c[i]] = H^T[i] * C.grad[i]
    // This kernel accesses rows of A in a transposed way w/o explicitly
    // converting A
    CUDA_KERNEL_CALL(
        (cuda::GatherMMScatterKernel2<IdType, DType>), nblks, nthrs, 0, stream,
        A.Ptr<DType>(), B.Ptr<DType>(), C.Ptr<DType>(), idx_a.Ptr<IdType>(),
        idx_b.Ptr<IdType>(), idx_c.Ptr<IdType>(), tot_num_rows, in_len,
        out_len);
  }
}
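
// Note: both kernels only accumulate into C (plain += or AtomicAdd), so the
// caller is expected to pass a zero-initialized output buffer.  With a 2-D B
// the result written here is the per-relation weight gradient
// W_grad[idx_c[i]] += A[idx_a[i]]^T * B[idx_b[i]] mentioned in the comment
// above.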

template void GatherMM<kDGLCUDA, int32_t, __half>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b);
template void GatherMM<kDGLCUDA, int64_t, __half>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b);
#if BF16_ENABLED
template void GatherMM<kDGLCUDA, int32_t, __hip_bfloat16>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b);
template void GatherMM<kDGLCUDA, int64_t, __hip_bfloat16>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b);
#endif  // BF16_ENABLED
template void GatherMM<kDGLCUDA, int32_t, float>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b);
template void GatherMM<kDGLCUDA, int64_t, float>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b);
template void GatherMM<kDGLCUDA, int32_t, double>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b);
template void GatherMM<kDGLCUDA, int64_t, double>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b);

template void GatherMMScatter<kDGLCUDA, int32_t, __half>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCUDA, int64_t, __half>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c);
#if BF16_ENABLED
template void GatherMMScatter<kDGLCUDA, int32_t, __hip_bfloat16>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCUDA, int64_t, __hip_bfloat16>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c);
#endif  // BF16_ENABLED
template void GatherMMScatter<kDGLCUDA, int32_t, float>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCUDA, int64_t, float>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCUDA, int32_t, double>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCUDA, int64_t, double>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c);

template void SegmentMM<kDGLCUDA, int32_t, __half>(
    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
    bool a_trans, bool b_trans);
template void SegmentMM<kDGLCUDA, int64_t, __half>(
    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
    bool a_trans, bool b_trans);
#if BF16_ENABLED
template void SegmentMM<kDGLCUDA, int32_t, __hip_bfloat16>(
    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
    bool a_trans, bool b_trans);
template void SegmentMM<kDGLCUDA, int64_t, __hip_bfloat16>(
    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
    bool a_trans, bool b_trans);
#endif  // BF16_ENABLED
template void SegmentMM<kDGLCUDA, int32_t, float>(
    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
    bool a_trans, bool b_trans);
template void SegmentMM<kDGLCUDA, int64_t, float>(
    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
    bool a_trans, bool b_trans);
template void SegmentMM<kDGLCUDA, int32_t, double>(
    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
    bool a_trans, bool b_trans);
template void SegmentMM<kDGLCUDA, int64_t, double>(
    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
    bool a_trans, bool b_trans);

template void SegmentMMBackwardB<kDGLCUDA, int32_t, __half>(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCUDA, int64_t, __half>(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
#if BF16_ENABLED
template void SegmentMMBackwardB<kDGLCUDA, int32_t, __hip_bfloat16>(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCUDA, int64_t, __hip_bfloat16>(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
#endif  // BF16_ENABLED
template void SegmentMMBackwardB<kDGLCUDA, int32_t, float>(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCUDA, int64_t, float>(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCUDA, int32_t, double>(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCUDA, int64_t, double>(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);

}  // namespace aten
}  // namespace dgl