/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cuda/gather_mm.cu
 * @brief GatherMM C APIs and definitions.
 */
#include <dgl/array.h>

#include <algorithm>  // std::swap

#include "./atomic.cuh"
#include "./functor.cuh"
#include "./utils.h"

namespace dgl {
using namespace cuda;
namespace aten {

namespace {

/** @brief Call the cuBLAS GEMM API for dense matmul. Specializations are
 * provided below for __half, __nv_bfloat16 (when enabled), float, and double;
 * the generic template only reports an unsupported dtype.
 */
template <typename DType>
cublasStatus_t cublasGemm(
    cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k, const DType* alpha, const DType* A, int lda,
    const DType* B, int ldb, const DType* beta, DType* C, int ldc) {
  LOG(INFO) << "Unsupported dtype";
  return CUBLAS_STATUS_EXECUTION_FAILED;
}

template <>
cublasStatus_t cublasGemm<__half>(
    cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k, const __half* alpha, const __half* A, int lda,
    const __half* B, int ldb, const __half* beta, __half* C, int ldc) {
  return cublasHgemm(
      handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
}

#if BF16_ENABLED
template <>
cublasStatus_t cublasGemm<__nv_bfloat16>(
    cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k, const __nv_bfloat16* alpha, const __nv_bfloat16* A,
    int lda, const __nv_bfloat16* B, int ldb, const __nv_bfloat16* beta,
    __nv_bfloat16* C, int ldc) {
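  // cuBLAS has no bf16-specific *gemm entry point analogous to cublasHgemm, so
  // this specialization falls back to cublasGemmEx with bf16 storage and FP32
  // accumulation; with CUBLAS_COMPUTE_32F the alpha/beta scalars must be
  // passed as floats, hence the conversions below.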
  float alpha_float = __bfloat162float(*alpha);
  float beta_float = __bfloat162float(*beta);
  return cublasGemmEx(
      handle, transa, transb, m, n, k, &alpha_float, A, CUDA_R_16BF, lda, B,
      CUDA_R_16BF, ldb, &beta_float, C, CUDA_R_16BF, ldc, CUBLAS_COMPUTE_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
#endif  // BF16_ENABLED

template <>
cublasStatus_t cublasGemm<float>(
    cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k, const float* alpha, const float* A, int lda,
    const float* B, int ldb, const float* beta, float* C, int ldc) {
  return cublasSgemm(
      handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
}

template <>
cublasStatus_t cublasGemm<double>(
    cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k, const double* alpha, const double* A, int lda,
    const double* B, int ldb, const double* beta, double* C, int ldc) {
  return cublasDgemm(
      handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
}

}  // namespace

namespace cuda {

/**
 * @note Each row of A multiplies a segment of matrix B of dimension
 * in_len * out_len. One warp is assigned to process one row of A. Each warp
 * sequentially multiplies one element of A with a row of B to compute a
 * partial result of the output. A is loaded into shared memory in a coalesced
 * way. The output matrix is accumulated in registers. B should benefit from
 * the L2 cache.
 */
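// Conceptually, for every assigned row i in [0, num_rows) this kernel computes
//   C[c] += A[a] (1 x in_len) * B[b] (in_len x out_len),
// where a = idx_a ? idx_a[i] : i, b = idx_b ? idx_b[i] : i and
// c = idx_c ? idx_c[i] : i; the update uses AtomicAdd when idx_c is given,
// since several rows may scatter into the same output row.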
template <typename Idx, typename DType>
__global__ void GatherMMScatterKernel(
    const DType* __restrict__ A, const DType* __restrict__ B,
    DType* __restrict__ C, const Idx* __restrict__ idx_a,
    const Idx* __restrict__ idx_b, const Idx* __restrict__ idx_c,
    const int64_t num_rows, const int64_t in_len, const int64_t out_len) {
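  // One warp (32 threads) handles one row of A: laneId is the thread's lane
  // within its warp and warpId is the global warp index, which selects the row.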
  unsigned int tId = threadIdx.x;
  unsigned int laneId = tId & 31;
  unsigned int gId = (blockIdx.x * blockDim.x + threadIdx.x);
  unsigned int warpId = gId >> 5;
  unsigned int row = warpId;
  if (row < num_rows) {
    const unsigned int local_row =
        row & 3;  // hardcoded for TB size 128 (4 warps)
    const Idx cur_rowA = (idx_a) ? idx_a[row] : row;
    const Idx cur_rowB = (idx_b) ? idx_b[row] : row;
    const Idx cur_rowC = (idx_c) ? idx_c[row] : row;
    const Idx B_offset = cur_rowB * in_len * out_len;
    const int sh_a_tile = 64;
    __shared__ DType sh_A[4 * sh_a_tile];
    int a_tile = sh_a_tile;
    for (unsigned int k_start = 0; k_start < in_len; k_start += 64) {
      if ((in_len - k_start) < a_tile) a_tile = in_len - k_start;
      // Load A in shared mem in a coalesced way
      for (unsigned int l = laneId; l < a_tile; l += 32)
        sh_A[local_row * sh_a_tile + l] = A[cur_rowA * in_len + (k_start + l)];
      __syncwarp();

      for (unsigned int outloop = 0; outloop < out_len; outloop += 32) {
        DType out_reg = static_cast<DType>(0.0f);  // thread private
        const unsigned int l = laneId;
        if (l < out_len) {
          // iterate over elements of a row of A
          for (unsigned int i = 0; i < a_tile; i++) {
            const DType a_val = sh_A[local_row * sh_a_tile + i];
            // iterate over elements of a row of B in parallel
            out_reg +=
                a_val * B[B_offset + ((i + k_start) * out_len + (outloop + l))];
          }
          if (idx_c) {
            AtomicAdd(C + cur_rowC * out_len + (outloop + l), out_reg);
          } else {
            C[cur_rowC * out_len + (outloop + l)] += out_reg;
          }
        }
      }
    }
  }
}

/**
 * @note The output matrix is accumulated via atomic operations. The rest of
 * the strategy is the same as GatherMMScatterKernel above: one warp is
 * assigned to process one row of A, and each warp sequentially multiplies one
 * element of A with a row of B to compute a partial result of the output. A is
 * loaded into shared memory in a coalesced way. B should benefit from the L2
 * cache.
 */
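// Conceptually, for every assigned row i this kernel accumulates an outer
// product into an (in_len x out_len) block of C:
//   C[idx_c ? idx_c[i] : i] += A[idx_a ? idx_a[i] : i]^T (in_len x 1)
//                              * B[idx_b ? idx_b[i] : i] (1 x out_len),
// which is the access pattern needed for the per-relation weight gradient.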
template <typename Idx, typename DType>
__global__ void GatherMMScatterKernel2(
    const DType* __restrict__ A, const DType* __restrict__ B,
    DType* __restrict__ C, const Idx* __restrict__ idx_a,
    const Idx* __restrict__ idx_b, const Idx* __restrict__ idx_c,
    const int64_t num_rows, const int64_t in_len, const int64_t out_len) {
  unsigned int tId = threadIdx.x;
  unsigned int laneId = tId & 31;
  unsigned int gId = (blockIdx.x * blockDim.x + threadIdx.x);
  unsigned int warpId = gId >> 5;
  unsigned int row = warpId;
  if (row < num_rows) {
    const unsigned int local_row =
        row & 3;  // hardcoded for TB size 128 (4 warps)
    const Idx row_a = (idx_a) ? idx_a[row] : row;
    const Idx row_b = (idx_b) ? idx_b[row] : row;
    const Idx row_c = (idx_c) ? idx_c[row] : row;
    const Idx C_offset = row_c * in_len * out_len;
    const int sh_a_tile = 64;
    __shared__ DType sh_A[4 * sh_a_tile];
    int a_tile = sh_a_tile;
    for (unsigned int k_start = 0; k_start < in_len; k_start += 64) {
      if ((in_len - k_start) < a_tile) a_tile = in_len - k_start;
      /* Load A in shared mem in a coalesced way */
      for (unsigned int l = laneId; l < a_tile; l += 32)
        sh_A[local_row * sh_a_tile + l] = A[row_a * in_len + (k_start + l)];
      __syncwarp();

      for (unsigned int outloop = 0; outloop < out_len; outloop += 32) {
        DType out_reg = static_cast<DType>(0.0f);  // thread private
        const unsigned int l = laneId;
        if (l < out_len) {
          const DType b_val = B[row_b * out_len + (outloop + l)];
          /* iterate over elements of a row of A */
          for (unsigned int i = 0; i < a_tile; i++) {
            const DType a_val = sh_A[local_row * sh_a_tile + i];
            const Idx C_idx =
                C_offset + ((i + k_start) * out_len + (outloop + l));
            AtomicAdd(C + C_idx, a_val * b_val);
          }
        }
      }
    }
  }
}

}  // namespace cuda

/**
 * @brief Implementation of the SegmentMM operator. The input matrix A is
 *        expected to be sorted according to relation type.
 * @param A The input dense matrix of dimension m x k
 * @param B The input dense matrix of dimension k x n
 * @param C The output dense matrix of dimension m x n
 * @param seglen_A The input vector of size R. Each element
 *        is the length of a segment of input ``A``.
 * @param a_trans Whether matrix A is to be transposed
 * @param b_trans Whether matrix B is to be transposed
 */
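// In effect, seglen_A partitions the rows of A into R consecutive segments and
// each segment is multiplied with its own k x n slice of B:
//   C[off : off + seglen_A[r]] = A[off : off + seglen_A[r]] * B[r],
// where off is the running sum of the preceding segment lengths.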
template <int XPU, typename IdType, typename DType>
void SegmentMM(
    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
    bool a_trans, bool b_trans) {
  auto device = runtime::DeviceAPI::Get(A->ctx);
  cudaStream_t stream = runtime::getCurrentCUDAStream();
  const DType* A_data = A.Ptr<DType>();
  const DType* B_data = B.Ptr<DType>();
  const IdType* seglen_A_data = seglen_A.Ptr<IdType>();
  DType* C_data = C.Ptr<DType>();
  int64_t A_offset = 0, B_offset = 0, C_offset = 0;
  int64_t m, n, k;
  int64_t num_rel = seglen_A.NumElements();
  DType alpha = 1., beta = 0.;

  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  if (!thr_entry->cublas_handle)
    CUBLAS_CALL(cublasCreate(&(thr_entry->cublas_handle)));
  CUBLAS_CALL(cublasSetStream(thr_entry->cublas_handle, stream));

  IdType m_offset = 0;
  for (IdType etype = 0; etype < num_rel; ++etype) {
    m = seglen_A_data[etype];  // rows of A
    CHECK_LE(m_offset + m, A->shape[0])
        << "Segment index out of bound of A->shape[0].";
    n = B->shape[2];  // cols of B
    k = B->shape[1];  // cols of A == rows of B
    int ldb = n, lda = k, ldc = n;
    cublasOperation_t transB = CUBLAS_OP_N;
    cublasOperation_t transA = CUBLAS_OP_N;
    if (b_trans) {
      transB = CUBLAS_OP_T;
      ldb = n, lda = n, ldc = k;
      std::swap(n, k);
    }
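    // cuBLAS assumes column-major storage, so the row-major matmul C = A * B
    // is issued as C^T = B^T * A^T by passing B as the first operand and A as
    // the second; the column-major result C^T is exactly row-major C.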
    CUBLAS_CALL(cublasGemm<DType>(
        thr_entry->cublas_handle, transB, transA, n, m, k, &alpha,
        B_data + B_offset, ldb, A_data + A_offset, lda, &beta,
        C_data + C_offset, ldc));
    A_offset += m * k;
    B_offset += k * n;
    C_offset += m * n;
    m_offset += m;
  }
}

template <int XPU, typename IdType, typename DType>
void SegmentMMBackwardB(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen) {
  auto device = runtime::DeviceAPI::Get(A->ctx);
  cudaStream_t stream = runtime::getCurrentCUDAStream();
  const DType* A_data = A.Ptr<DType>();
  const DType* dC_data = dC.Ptr<DType>();
  const IdType* seglen_data = seglen.Ptr<IdType>();
  DType* dB_data = dB.Ptr<DType>();
  int64_t A_offset = 0, dC_offset = 0, dB_offset = 0;
  int64_t m, n, k;
  int64_t num_rel = seglen.NumElements();
  DType alpha = 1., beta = 0.;

  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  if (!thr_entry->cublas_handle)
    CUBLAS_CALL(cublasCreate(&(thr_entry->cublas_handle)));
  CUBLAS_CALL(cublasSetStream(thr_entry->cublas_handle, stream));

  IdType k_offset = 0;
  for (IdType etype = 0; etype < num_rel; ++etype) {
    m = dC->shape[1];
    n = A->shape[1];
    k = seglen_data[etype];
    CHECK_LE(k_offset + k, A->shape[0])
        << "Segement index out of bound of A->shape[0].";
    int lddC = m, ldA = n, lddB = m;
    cublasOperation_t trans_dC = CUBLAS_OP_N;
    cublasOperation_t trans_A = CUBLAS_OP_T;
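    // Per segment this computes dB[etype] = A_seg^T * dC_seg. In cuBLAS's
    // column-major view the row-major buffers appear transposed, so the call
    // evaluates dB^T = dC^T * A, with the transpose applied to the second
    // operand (A).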
    CUBLAS_CALL(cublasGemm<DType>(
        thr_entry->cublas_handle, trans_dC, trans_A, m, n, k, &alpha,
        dC_data + dC_offset, lddC, A_data + A_offset, ldA, &beta,
        dB_data + dB_offset, lddB));
    dC_offset += m * k;
    A_offset += n * k;
    dB_offset += m * n;
    k_offset += k;
  }
}

/**
 * @brief Implementation of Gather_mm operator. The input matrix A is
 *        expected to be sorted according to relation type.
 * @param A The input dense matrix of dimension m x k
 * @param B The input dense matrix of dimension k x n
 * @param C The output dense matrix of dimension m x n
 * @param idx_a The input vector to gather left hand operand on
 * @param idx_b The input vector to gather right hand operand on
 */
template <int XPU, typename IdType, typename DType>
void GatherMM(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b) {
  auto device = runtime::DeviceAPI::Get(A->ctx);
  cudaStream_t stream = runtime::getCurrentCUDAStream();
  int64_t out_len = B->shape[2];  // cols of B
  int64_t in_len = A->shape[1];   // cols of A
  const int64_t tot_num_rows = A->shape[0];
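  // Launch configuration: one warp (32 threads) per row of A, with 128-thread
  // blocks (4 warps each), so the grid size is the ceiling of
  // tot_num_rows * 32 / 128 blocks.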
  const int ntx = 128;
  const int warp_size = 32;
  const int nbx = ((tot_num_rows * warp_size + ntx - 1) / ntx);
  const dim3 nblks(nbx);
  const dim3 nthrs(ntx);
  CUDA_KERNEL_CALL(
      (cuda::GatherMMScatterKernel<IdType, DType>), nblks, nthrs, 0, stream,
      A.Ptr<DType>(), B.Ptr<DType>(), C.Ptr<DType>(), idx_a.Ptr<IdType>(),
      idx_b.Ptr<IdType>(), nullptr, tot_num_rows, in_len, out_len);
}

/**
 * @brief Implementation of Gather_mm operator. The input matrix A is
 *        expected to be sorted according to relation type.
 * @param A The input dense matrix of dimension m x k
 * @param B The input dense matrix of dimension k x n
 * @param C The output dense matrix of dimension m x n
 * @param idx_a The input vector to gather left hand operand on
 * @param idx_b The input vector to gather right hand operand on
 * @param idx_c The input vector to gather output operand on
 */
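// This entry point has two modes, selected by B->ndim below: with a 3-D B it
// runs the gather-mm kernel and scatters each row of the result to C[idx_c];
// with a 2-D B it runs the outer-product kernel that accumulates per-relation
// weight gradients into C.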
template <int XPU, typename IdType, typename DType>
void GatherMMScatter(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c) {
  auto device = runtime::DeviceAPI::Get(A->ctx);
  cudaStream_t stream = runtime::getCurrentCUDAStream();
  const IdType* idx_c_data = idx_c.Ptr<IdType>();
  int64_t out_len = (B->ndim == 2) ? B->shape[1] : B->shape[2];  // cols of B
  int64_t in_len = A->shape[1];                                  // cols of A
  int64_t tot_num_rows = A->shape[0];
  const int ntx = 128;
  const int warp_size = 32;
  const int nbx = ((tot_num_rows * warp_size + ntx - 1) / ntx);
  const dim3 nblks(nbx);
  const dim3 nthrs(ntx);
  if (B->ndim == 3) {
    CUDA_KERNEL_CALL(
        (cuda::GatherMMScatterKernel<IdType, DType>), nblks, nthrs, 0, stream,
        A.Ptr<DType>(), B.Ptr<DType>(), C.Ptr<DType>(), idx_a.Ptr<IdType>(),
        idx_b.Ptr<IdType>(), idx_c.Ptr<IdType>(), tot_num_rows, in_len,
        out_len);
  } else {
    // Custom kernel for W_grad[idx_c[i]] = H^T[i] * C.grad[i]
    // This kernel accesses rows of A in a transposed way w/o explicitly
    // converting A
    CUDA_KERNEL_CALL(
        (cuda::GatherMMScatterKernel2<IdType, DType>), nblks, nthrs, 0, stream,
        A.Ptr<DType>(), B.Ptr<DType>(), C.Ptr<DType>(), idx_a.Ptr<IdType>(),
        idx_b.Ptr<IdType>(), idx_c.Ptr<IdType>(), tot_num_rows, in_len,
        out_len);
  }
}

template void GatherMM<kDGLCUDA, int32_t, __half>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b);
template void GatherMM<kDGLCUDA, int64_t, __half>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b);
#if BF16_ENABLED
template void GatherMM<kDGLCUDA, int32_t, __nv_bfloat16>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b);
template void GatherMM<kDGLCUDA, int64_t, __nv_bfloat16>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b);
#endif  // BF16_ENABLED
template void GatherMM<kDGLCUDA, int32_t, float>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b);
template void GatherMM<kDGLCUDA, int64_t, float>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b);
template void GatherMM<kDGLCUDA, int32_t, double>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b);
template void GatherMM<kDGLCUDA, int64_t, double>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b);

template void GatherMMScatter<kDGLCUDA, int32_t, __half>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCUDA, int64_t, __half>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c);
#if BF16_ENABLED
template void GatherMMScatter<kDGLCUDA, int32_t, __nv_bfloat16>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCUDA, int64_t, __nv_bfloat16>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c);
#endif  // BF16_ENABLED
template void GatherMMScatter<kDGLCUDA, int32_t, float>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCUDA, int64_t, float>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCUDA, int32_t, double>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCUDA, int64_t, double>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c);

template void SegmentMM<kDGLCUDA, int32_t, __half>(
    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
    bool a_trans, bool b_trans);
template void SegmentMM<kDGLCUDA, int64_t, __half>(
    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
    bool a_trans, bool b_trans);
#if BF16_ENABLED
template void SegmentMM<kDGLCUDA, int32_t, __nv_bfloat16>(
    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
    bool a_trans, bool b_trans);
template void SegmentMM<kDGLCUDA, int64_t, __nv_bfloat16>(
    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
    bool a_trans, bool b_trans);
#endif  // BF16_ENABLED
template void SegmentMM<kDGLCUDA, int32_t, float>(
    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
    bool a_trans, bool b_trans);
template void SegmentMM<kDGLCUDA, int64_t, float>(
    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
    bool a_trans, bool b_trans);
template void SegmentMM<kDGLCUDA, int32_t, double>(
    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
    bool a_trans, bool b_trans);
template void SegmentMM<kDGLCUDA, int64_t, double>(
    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
    bool a_trans, bool b_trans);

template void SegmentMMBackwardB<kDGLCUDA, int32_t, __half>(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCUDA, int64_t, __half>(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
#if BF16_ENABLED
template void SegmentMMBackwardB<kDGLCUDA, int32_t, __nv_bfloat16>(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCUDA, int64_t, __nv_bfloat16>(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
#endif  // BF16_ENABLED
template void SegmentMMBackwardB<kDGLCUDA, int32_t, float>(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCUDA, int64_t, float>(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCUDA, int32_t, double>(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCUDA, int64_t, double>(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);

}  // namespace aten
}  // namespace dgl