"vscode:/vscode.git/clone" did not exist on "a6336f7a9cd23c93c6b665aedc4cb0d91075af03"
gather_mm.cu 20.3 KB
Newer Older
1
/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cuda/gather_mm.cu
 * @brief GatherMM C APIs and definitions.
 */
#include <dgl/array.h>
#include <algorithm>  // std::swap
#include "./utils.h"
#include "./functor.cuh"
#include "./atomic.cuh"

namespace dgl {
using namespace cuda;
namespace aten {

namespace {

/** @brief Call cuBLAS GEMM API for dense matmul; specialized below for half, bfloat16, float, and double. */
template <typename DType>
cublasStatus_t cublasGemm(cublasHandle_t handle, cublasOperation_t transa,
    cublasOperation_t transb, int m, int n, int k,
    const DType* alpha, const DType* A, int lda,
    const DType* B, int ldb, const DType* beta,
    DType* C, int ldc) {
  LOG(INFO) << "Unsupported dtype";
  return CUBLAS_STATUS_EXECUTION_FAILED;
}
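
// The specializations below route cublasGemm<DType> to the matching typed
// cuBLAS entry point (e.g. cublasGemm<float> resolves to cublasSgemm); any
// DType without a specialization lands in the stub above and reports failure.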

template <>
cublasStatus_t cublasGemm<__half>(cublasHandle_t handle, cublasOperation_t transa,
    cublasOperation_t transb, int m, int n, int k,
    const __half* alpha, const __half* A, int lda,
    const __half* B, int ldb, const __half* beta,
    __half* C, int ldc) {
  return cublasHgemm(handle, transa, transb, m, n, k, alpha, A, lda,
      B, ldb, beta, C, ldc);
}

#if BF16_ENABLED
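// cuBLAS exposes no dedicated bfloat16 gemm entry point, so this
// specialization goes through cublasGemmEx with fp32 scalars and
// CUBLAS_COMPUTE_32F accumulation.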
template <>
cublasStatus_t cublasGemm<__nv_bfloat16>(cublasHandle_t handle, cublasOperation_t transa,
    cublasOperation_t transb, int m, int n, int k,
    const __nv_bfloat16* alpha, const __nv_bfloat16* A, int lda,
    const __nv_bfloat16* B, int ldb, const __nv_bfloat16* beta,
    __nv_bfloat16* C, int ldc) {
  float alpha_float = __bfloat162float(*alpha);
  float beta_float = __bfloat162float(*beta);
  return cublasGemmEx(handle, transa, transb, m, n, k,
      &alpha_float, A, CUDA_R_16BF, lda,
      B, CUDA_R_16BF, ldb,
      &beta_float, C, CUDA_R_16BF, ldc,
      CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
#endif  // BF16_ENABLED

template <>
cublasStatus_t cublasGemm<float>(cublasHandle_t handle, cublasOperation_t transa,
    cublasOperation_t transb, int m, int n, int k,
    const float* alpha, const float* A, int lda,
    const float* B, int ldb, const float* beta,
    float* C, int ldc) {
  return cublasSgemm(handle, transa, transb, m, n, k, alpha, A, lda,
      B, ldb, beta, C, ldc);
}

template <>
cublasStatus_t cublasGemm<double>(cublasHandle_t handle, cublasOperation_t transa,
    cublasOperation_t transb, int m, int n, int k,
    const double* alpha, const double* A, int lda,
    const double* B, int ldb, const double* beta,
    double* C, int ldc) {
  return cublasDgemm(handle, transa, transb, m, n, k, alpha, A, lda,
      B, ldb, beta, C, ldc);
}

}  // namespace

namespace cuda {

/**
 * @note Each row of A multiplies a segment of B of dimension in_len * out_len.
 * One warp is assigned to process one row of A. Each warp sequentially
 * multiplies one element of A by a row of B to compute a partial result of
 * the output. A is loaded into shared memory in a coalesced way, and the
 * output is accumulated in registers. B should benefit from the L2 cache.
 */
template <typename Idx, typename DType>
__global__ void GatherMMScatterKernel(
    const DType* __restrict__ A,
    const DType* __restrict__ B,
    DType* __restrict__ C,
    const Idx* __restrict__ idx_a,
    const Idx* __restrict__ idx_b,
    const Idx* __restrict__ idx_c,
    const int64_t num_rows,
    const int64_t in_len,
    const int64_t out_len) {

    unsigned int tId = threadIdx.x;
    unsigned int laneId = tId & 31;
    unsigned int gId = (blockIdx.x * blockDim.x + threadIdx.x);
    unsigned int warpId = gId >> 5;
    unsigned int row = warpId;
    if (row < num_rows) {
        const unsigned int local_row = row & 3;  // hardcoded for TB size 128 (4 warps)
        const Idx cur_rowA = (idx_a) ? idx_a[row] : row;
        const Idx cur_rowB = (idx_b) ? idx_b[row] : row;
        const Idx cur_rowC = (idx_c) ? idx_c[row] : row;
        const Idx B_offset = cur_rowB * in_len * out_len;
        const int sh_a_tile = 64;
        __shared__ DType sh_A[4 * sh_a_tile];
        int a_tile = sh_a_tile;
        for (unsigned int k_start = 0; k_start < in_len; k_start += 64) {
            if ((in_len - k_start) < a_tile) a_tile = in_len - k_start;
            // Load A in shared mem in a coalesced way
            for (unsigned int l = laneId; l < a_tile; l += 32)
                sh_A[local_row * sh_a_tile + l] = A[cur_rowA * in_len + (k_start + l)];
            __syncwarp();

            for (unsigned int outloop = 0; outloop < out_len; outloop += 32) {
                DType out_reg = static_cast<DType>(0.0f);  // thread private
                const unsigned int l = laneId;
                if (outloop + l < out_len) {  // guard the ragged last output tile
                    // iterate over elements of a row of A
                    for (unsigned int i = 0; i < a_tile; i++) {
                        const DType a_val =  sh_A[local_row * sh_a_tile + i];
                        // iterate over elements of a row of B in parallel
                        out_reg += a_val * B[B_offset + ((i + k_start) * out_len + (outloop + l))];
                    }
                    if (idx_c) {
                      AtomicAdd(C + cur_rowC * out_len + (outloop + l), out_reg);
                    } else {
                      C[cur_rowC * out_len + (outloop + l)] += out_reg;
                    }
                }
            }
        }
    }
}
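
// Reference semantics of GatherMMScatterKernel, as a sequential sketch
// (illustrative host-side pseudo-code only, not part of the implementation):
//
//   for (int64_t r = 0; r < num_rows; ++r) {
//     const DType* a = A + (idx_a ? idx_a[r] : r) * in_len;
//     const DType* b = B + (idx_b ? idx_b[r] : r) * in_len * out_len;
//     DType* c = C + (idx_c ? idx_c[r] : r) * out_len;
//     for (int64_t j = 0; j < out_len; ++j)
//       for (int64_t i = 0; i < in_len; ++i)
//         c[j] += a[i] * b[i * out_len + j];  // AtomicAdd when idx_c is set
//   }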

/**
 * @note The output matrix is accumulated via atomic operations. The rest of
 * the strategy matches GatherMMScatterKernel above: one warp is assigned to
 * process one row of A, and each warp sequentially multiplies one element of
 * A by a row of B to compute a partial result of the output. A is loaded
 * into shared memory in a coalesced way. B should benefit from the L2 cache.
 */
template <typename Idx, typename DType>
__global__ void GatherMMScatterKernel2(
    const DType* __restrict__ A,
    const DType* __restrict__ B,
    DType* __restrict__ C,
    const Idx* __restrict__ idx_a,
    const Idx* __restrict__ idx_b,
    const Idx* __restrict__ idx_c,
    const int64_t num_rows,
    const int64_t in_len,
    const int64_t out_len) {

    unsigned int tId = threadIdx.x;
    unsigned int laneId = tId & 31;
    unsigned int gId = (blockIdx.x * blockDim.x + threadIdx.x);
    unsigned int warpId = gId >> 5;
    unsigned int row = warpId;
    if (row < num_rows) {
        const unsigned int local_row = row & 3;  // hardcoded for TB size 128 (4 warps)
        const Idx row_a = (idx_a) ? idx_a[row] : row;
        const Idx row_b = (idx_b) ? idx_b[row] : row;
        const Idx row_c = (idx_c) ? idx_c[row] : row;
        const Idx C_offset = row_c * in_len * out_len;
        const int sh_a_tile = 64;
        __shared__ DType sh_A[4 * sh_a_tile];
        int a_tile = sh_a_tile;
        for (unsigned int k_start = 0; k_start < in_len; k_start += 64) {
            if ((in_len - k_start) < a_tile) a_tile = in_len - k_start;
            /* Load A in shared mem in a coalesced way */
            for (unsigned int l = laneId; l < a_tile; l += 32)
                sh_A[local_row * sh_a_tile + l] = A[row_a * in_len + (k_start + l)];
            __syncwarp();

            for (unsigned int outloop = 0; outloop < out_len; outloop += 32) {
                DType out_reg = static_cast<DType>(0.0f);  // thread private
                const unsigned int l = laneId;
                if (outloop + l < out_len) {  // guard the ragged last output tile
                    const DType b_val = B[row_b * out_len + (outloop + l)];
                    /* iterate over elements of a row of A */
                    for (unsigned int i = 0; i < a_tile; i++) {
                        const DType a_val = sh_A[local_row * sh_a_tile + i];
                        const Idx C_idx = C_offset + ((i + k_start) * out_len + (outloop + l));
                        AtomicAdd(C + C_idx, a_val * b_val);
                    }
                }
            }
        }
    }
}
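
// Reference semantics of GatherMMScatterKernel2, as a sequential sketch
// (illustrative host-side pseudo-code only): the outer product of row r of A
// with row r of B is scattered into the in_len x out_len block of C selected
// by idx_c:
//
//   for (int64_t r = 0; r < num_rows; ++r) {
//     const DType* a = A + (idx_a ? idx_a[r] : r) * in_len;
//     const DType* b = B + (idx_b ? idx_b[r] : r) * out_len;
//     DType* c = C + (idx_c ? idx_c[r] : r) * in_len * out_len;
//     for (int64_t i = 0; i < in_len; ++i)
//       for (int64_t j = 0; j < out_len; ++j)
//         c[i * out_len + j] += a[i] * b[j];  // performed with AtomicAdd
//   }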

}  // namespace cuda

/**
 * @brief Implementation of the SegmentMM operator. The input matrix A is
 *        expected to be sorted according to relation type.
 * @param A The input dense matrix of dimension m x k
 * @param B The input dense tensor of dimension R x k x n, one k x n
 *        weight matrix per relation
 * @param C The output dense matrix of dimension m x n
 * @param seglen_A The input vector of size R. Each element is the length
 *        of a segment of input ``A``
 * @param a_trans Whether matrix A is to be transposed
 * @param b_trans Whether matrix B is to be transposed
 */
template <int XPU, typename IdType, typename DType>
void SegmentMM(const NDArray A,
               const NDArray B,
               NDArray C,
               const NDArray seglen_A,
               bool a_trans, bool b_trans) {
    auto device = runtime::DeviceAPI::Get(A->ctx);
    cudaStream_t stream = runtime::getCurrentCUDAStream();
    const DType *A_data = A.Ptr<DType>();
    const DType *B_data = B.Ptr<DType>();
    const IdType* seglen_A_data = seglen_A.Ptr<IdType>();
    DType *C_data = C.Ptr<DType>();
    int64_t A_offset = 0, B_offset = 0, C_offset = 0;
    int64_t m, n, k;
    int64_t num_rel = seglen_A.NumElements();
    DType alpha = 1., beta = 0.;

    auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
    if (!thr_entry->cublas_handle)
        CUBLAS_CALL(cublasCreate(&(thr_entry->cublas_handle)));
    CUBLAS_CALL(cublasSetStream(thr_entry->cublas_handle, stream));

    IdType m_offset = 0;
    for (IdType etype = 0; etype < num_rel; ++etype) {
        m = seglen_A_data[etype];  // rows of A
        CHECK_LE(m_offset + m, A->shape[0]) << "Segment index out of bounds of A->shape[0].";
        n = B->shape[2];  // cols of B
        k = B->shape[1];  // cols of A == rows of B
        int ldb = n, lda = k, ldc = n;
        cublasOperation_t transB = CUBLAS_OP_N;
        cublasOperation_t transA = CUBLAS_OP_N;
        if (b_trans) {
            transB = CUBLAS_OP_T;
            ldb = n, lda = n, ldc = k;
            std::swap(n, k);
        }
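        // cuBLAS is column-major while DGL arrays are row-major, so the
        // row-major product C = A * B is issued with B first and the m/n
        // dimensions exchanged: computing B^T * A^T in column-major terms
        // yields (A * B)^T, which is exactly C in row-major layout.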
        CUBLAS_CALL(cublasGemm<DType>(
            thr_entry->cublas_handle,
            transB,
            transA,
            n, m, k,
            &alpha,
            B_data + B_offset, ldb,
            A_data + A_offset, lda,
            &beta,
            C_data + C_offset, ldc));
        A_offset += m * k;
        B_offset += k * n;
        C_offset += m * n;
        m_offset += m;
    }
}

template <int XPU, typename IdType, typename DType>
void SegmentMMBackwardB(const NDArray A,
                        const NDArray dC,
                        NDArray dB,
                        const NDArray seglen) {
    auto device = runtime::DeviceAPI::Get(A->ctx);
    cudaStream_t stream = runtime::getCurrentCUDAStream();
    const DType *A_data = A.Ptr<DType>();
    const DType *dC_data = dC.Ptr<DType>();
    const IdType* seglen_data = seglen.Ptr<IdType>();
    DType *dB_data = dB.Ptr<DType>();
    int64_t A_offset = 0, dC_offset = 0, dB_offset = 0;
    int64_t m, n, k;
    int64_t num_rel = seglen.NumElements();
    DType alpha = 1., beta = 1.;

    auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
    if (!thr_entry->cublas_handle)
        CUBLAS_CALL(cublasCreate(&(thr_entry->cublas_handle)));
    CUBLAS_CALL(cublasSetStream(thr_entry->cublas_handle, stream));

    IdType k_offset = 0;
    for (IdType etype = 0; etype < num_rel; ++etype) {
        m = dC->shape[1];
        n = A->shape[1];
        k = seglen_data[etype];
        CHECK_LE(k_offset + k, A->shape[0]) << "Segment index out of bounds of A->shape[0].";
        int lddC = m, ldA = n, lddB = m;
        cublasOperation_t trans_dC = CUBLAS_OP_N;
        cublasOperation_t trans_A = CUBLAS_OP_T;
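        // Row-major reading of the call below (a sketch of the math, not
        // extra computation): per segment it accumulates
        //   dB[etype] += A_seg^T * dC_seg
        // where A_seg / dC_seg are the k rows of A / dC in this segment,
        // and beta == 1 accumulates into the existing contents of dB.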
        CUBLAS_CALL(cublasGemm<DType>(
            thr_entry->cublas_handle,
            trans_dC,
            trans_A,
            m, n, k,
            &alpha,
            dC_data + dC_offset, lddC,
            A_data + A_offset, ldA,
            &beta,
            dB_data + dB_offset, lddB));
        dC_offset += m * k;
        A_offset += n * k;
        dB_offset += m * n;
        k_offset += k;
    }
}

/**
 * @brief Implementation of the GatherMM operator. The input matrix A is
 *        expected to be sorted according to relation type.
 * @param A The input dense matrix of dimension m x k
 * @param B The input dense tensor of dimension R x k x n, one k x n
 *        weight matrix per relation
 * @param C The output dense matrix of dimension m x n
 * @param idx_a The input vector to gather the left-hand operand on
 * @param idx_b The input vector to gather the right-hand operand on
 */
template <int XPU, typename IdType, typename DType>
void GatherMM(const NDArray A,
              const NDArray B,
              NDArray C,
              const NDArray idx_a,
              const NDArray idx_b) {
    auto device = runtime::DeviceAPI::Get(A->ctx);
    cudaStream_t stream = runtime::getCurrentCUDAStream();
    int64_t out_len = B->shape[2];  // cols of B
    int64_t in_len = A->shape[1];  // cols of A
    const int64_t tot_num_rows = A->shape[0];
    const int ntx = 128;
    const int warp_size = 32;
    const int nbx = ((tot_num_rows * warp_size + ntx - 1) / ntx);
    const dim3 nblks(nbx);
    const dim3 nthrs(ntx);
    CUDA_KERNEL_CALL((cuda::GatherMMScatterKernel<IdType, DType>),
        nblks, nthrs, 0, stream,
        A.Ptr<DType>(),
        B.Ptr<DType>(),
        C.Ptr<DType>(),
        idx_a.Ptr<IdType>(),
        idx_b.Ptr<IdType>(),
        nullptr,
        tot_num_rows, in_len, out_len);
}
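
// In matrix terms, GatherMM above computes (a minimal sketch, with B holding
// one in_len x out_len weight matrix per relation):
//   C[i] += A[idx_a[i]] * B[idx_b[i]]   for each row i of A.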

/**
 * @brief Implementation of the GatherMMScatter operator. The input matrix A
 *        is expected to be sorted according to relation type.
 * @param A The input dense matrix of dimension m x k
 * @param B The input dense matrix of dimension k x n (2-D), or a tensor of
 *        dimension R x k x n (3-D, one matrix per relation)
 * @param C The output dense matrix of dimension m x n
 * @param idx_a The input vector to gather the left-hand operand on
 * @param idx_b The input vector to gather the right-hand operand on
 * @param idx_c The input vector to scatter the output on
 */
template <int XPU, typename IdType, typename DType>
void GatherMMScatter(const NDArray A,
                     const NDArray B,
                     NDArray C,
                     const NDArray idx_a,
                     const NDArray idx_b,
                     const NDArray idx_c) {
    auto device = runtime::DeviceAPI::Get(A->ctx);
    cudaStream_t stream = runtime::getCurrentCUDAStream();
    const IdType *idx_c_data = idx_c.Ptr<IdType>();
    int64_t out_len = (B->ndim == 2) ? B->shape[1] : B->shape[2];  // cols of B
    int64_t in_len = A->shape[1];  // cols of A
    int64_t tot_num_rows = A->shape[0];
    const int ntx = 128;
    const int warp_size = 32;
    const int nbx = ((tot_num_rows * warp_size + ntx - 1) / ntx);
    const dim3 nblks(nbx);
    const dim3 nthrs(ntx);
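    // Dispatch on B's rank: a 3-D B (num_rel x in_len x out_len) is the
    // forward gather-mm path; a 2-D B is the backward path that scatters
    // per-row outer products into the weight gradient.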
    if (B->ndim == 3) {
        CUDA_KERNEL_CALL((cuda::GatherMMScatterKernel<IdType, DType>),
            nblks, nthrs, 0, stream,
            A.Ptr<DType>(),
            B.Ptr<DType>(),
            C.Ptr<DType>(),
            idx_a.Ptr<IdType>(),
            idx_b.Ptr<IdType>(),
            idx_c.Ptr<IdType>(),
            tot_num_rows, in_len, out_len);
    } else {
        // Custom kernel for W_grad[idx_c[i]] = H^T[i] * C.grad[i].
        // It accesses rows of A in a transposed way without explicitly
        // transposing A.
        CUDA_KERNEL_CALL((cuda::GatherMMScatterKernel2<IdType, DType>),
            nblks, nthrs, 0, stream,
            A.Ptr<DType>(),
            B.Ptr<DType>(),
            C.Ptr<DType>(),
            idx_a.Ptr<IdType>(),
            idx_b.Ptr<IdType>(),
            idx_c.Ptr<IdType>(),
            tot_num_rows, in_len, out_len);
    }
}

template void GatherMM<kDGLCUDA, int32_t, __half>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDGLCUDA, int64_t, __half>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b);
#if BF16_ENABLED
template void GatherMM<kDGLCUDA, int32_t, __nv_bfloat16>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDGLCUDA, int64_t, __nv_bfloat16>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b);
#endif  // BF16_ENABLED
template void GatherMM<kDGLCUDA, int32_t, float>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDGLCUDA, int64_t, float>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDGLCUDA, int32_t, double>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDGLCUDA, int64_t, double>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b);

template void GatherMMScatter<kDGLCUDA, int32_t, __half>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCUDA, int64_t, __half>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
#if BF16_ENABLED
template void GatherMMScatter<kDGLCUDA, int32_t, __nv_bfloat16>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCUDA, int64_t, __nv_bfloat16>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
#endif  // BF16_ENABLED
template void GatherMMScatter<kDGLCUDA, int32_t, float>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCUDA, int64_t, float>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCUDA, int32_t, double>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCUDA, int64_t, double>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);

template void SegmentMM<kDGLCUDA, int32_t, __half>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDGLCUDA, int64_t, __half>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray seglen_A, bool a_trans, bool b_trans);
#if BF16_ENABLED
template void SegmentMM<kDGLCUDA, int32_t, __nv_bfloat16>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDGLCUDA, int64_t, __nv_bfloat16>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray seglen_A, bool a_trans, bool b_trans);
#endif  // BF16_ENABLED
template void SegmentMM<kDGLCUDA, int32_t, float>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDGLCUDA, int64_t, float>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDGLCUDA, int32_t, double>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDGLCUDA, int64_t, double>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray seglen_A, bool a_trans, bool b_trans);

template void SegmentMMBackwardB<kDGLCUDA, int32_t, __half>(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCUDA, int64_t, __half>(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
#if BF16_ENABLED
template void SegmentMMBackwardB<kDGLCUDA, int32_t, __nv_bfloat16>(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCUDA, int64_t, __nv_bfloat16>(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
#endif  // BF16_ENABLED
template void SegmentMMBackwardB<kDGLCUDA, int32_t, float>(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCUDA, int64_t, float>(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCUDA, int32_t, double>(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCUDA, int64_t, double>(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);

}  // namespace aten
}  // namespace dgl