gather_mm.cu 18.2 KB
Newer Older
Israt Nisa's avatar
Israt Nisa committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
/*!
 *  Copyright (c) 2020 by Contributors
 * \file array/cuda/gather_mm.cu
 * \brief GatherMM C APIs and definitions.
 */
#include <dgl/array.h>
#include <algorithm>  // std::swap
#include "./utils.h"
#include "./functor.cuh"
#include "./atomic.cuh"

namespace dgl {
using namespace cuda;
namespace aten {

namespace {

/*!
 * \brief Call the cuBLAS GEMM routine matching DType.
 *
 * Arguments are forwarded unchanged and follow cuBLAS's column-major gemm
 * convention. The primary template is the fallback for element types that
 * have no specialization below: it logs and returns
 * CUBLAS_STATUS_EXECUTION_FAILED (call sites in this file wrap the result in
 * CUBLAS_CALL).
 */
template <typename DType>
cublasStatus_t cublasGemm(cublasHandle_t handle, cublasOperation_t transa,
    cublasOperation_t transb, int m, int n, int k,
    const DType* alpha, const DType* A, int lda,
    const DType* B, int ldb, const DType* beta,
    DType* C, int ldc) {
  LOG(INFO) << "Not supported dtype";
  return CUBLAS_STATUS_EXECUTION_FAILED;
}

/*!
 * \brief Half precision GEMM. The 16-bit instantiations of the segment
 *        matmul operators route here; without this specialization they
 *        always hit the unsupported-dtype fallback.
 */
template <>
cublasStatus_t cublasGemm<__half>(cublasHandle_t handle, cublasOperation_t transa,
    cublasOperation_t transb, int m, int n, int k,
    const __half* alpha, const __half* A, int lda,
    const __half* B, int ldb, const __half* beta,
    __half* C, int ldc) {
  return cublasHgemm(handle, transa, transb, m, n, k, alpha, A, lda,
      B, ldb, beta, C, ldc);
}

/*! \brief Single precision GEMM via cublasSgemm. */
template <>
cublasStatus_t cublasGemm<float>(cublasHandle_t handle, cublasOperation_t transa,
    cublasOperation_t transb, int m, int n, int k,
    const float* alpha, const float* A, int lda,
    const float* B, int ldb, const float* beta,
    float* C, int ldc) {
  return cublasSgemm(handle, transa, transb, m, n, k, alpha, A, lda,
      B, ldb, beta, C, ldc);
}

/*! \brief Double precision GEMM via cublasDgemm. */
template <>
cublasStatus_t cublasGemm<double>(cublasHandle_t handle, cublasOperation_t transa,
    cublasOperation_t transb, int m, int n, int k,
    const double* alpha, const double* A, int lda,
    const double* B, int ldb, const double* beta,
    double* C, int ldc) {
  return cublasDgemm(handle, transa, transb, m, n, k, alpha, A, lda,
      B, ldb, beta, C, ldc);
}

}  // namespace

namespace cuda {

/* \Note Each row of A multiplies a segment of matrix of B of dimension in_len * out_len.
  One warp is assigned to process one row of A. Each WARP sequentially multiplies
  one element of A and a row of B to compute partial result of the output. A
  is loaded in shared memory in a coalesced way. Output matrix is loaded in
  registers. B should get benefit from L2 cache.

  For every row r < num_rows this computes
      C[idx_c[r]] += A[idx_a[r]] * B[idx_b[r]]
  where rows of A have length in_len, rows of C have length out_len and
  B[idx_b[r]] is an in_len x out_len row-major matrix. A null index pointer
  means the identity mapping (row r itself). With idx_c, several rows may hit
  the same output row, so the update is atomic; without it each output row is
  touched by a single warp and a plain read-modify-write is used. C is
  accumulated into, never overwritten.

  Launch contract: blockDim.x must be 128 (4 warps; the static shared buffer
  holds one 64-element tile of A per warp) and the grid must provide at least
  num_rows warps.
*/
template <typename Idx, typename DType>
__global__ void GatherMMScatterKernel(
    const DType* __restrict__ A,
    const DType* __restrict__ B,
    DType* __restrict__ C,
    const Idx* __restrict__ idx_a,
    const Idx* __restrict__ idx_b,
    const Idx* __restrict__ idx_c,
    const int64_t num_rows,
    const int64_t in_len,
    const int64_t out_len) {
    constexpr int kWarpSize = 32;
    constexpr int kWarpsPerBlock = 4;  // hardcoded for TB size 128
    constexpr int sh_a_tile = 64;
    __shared__ DType sh_A[kWarpsPerBlock * sh_a_tile];

    const int laneId = threadIdx.x & (kWarpSize - 1);
    // 64-bit global warp id so very large inputs cannot wrap around.
    const int64_t row =
        (static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x) / kWarpSize;
    // row is identical for all lanes of a warp, so every lane either enters
    // or skips this branch and the __syncwarp() calls below see the full warp.
    if (row < num_rows) {
        DType* sh_row = sh_A + (row & (kWarpsPerBlock - 1)) * sh_a_tile;
        const int64_t cur_rowA = (idx_a) ? idx_a[row] : row;
        const int64_t cur_rowB = (idx_b) ? idx_b[row] : row;
        const int64_t cur_rowC = (idx_c) ? idx_c[row] : row;
        // Offsets in 64 bits: cur_rowB * in_len * out_len easily exceeds the
        // range of a 32-bit Idx.
        const int64_t A_offset = cur_rowA * in_len;
        const int64_t B_offset = cur_rowB * in_len * out_len;
        const int64_t C_offset = cur_rowC * out_len;
        for (int64_t k_start = 0; k_start < in_len; k_start += sh_a_tile) {
            const int a_tile = static_cast<int>(
                (in_len - k_start < sh_a_tile) ? in_len - k_start : sh_a_tile);
            // Load A in shared mem in a coalesced way
            for (int l = laneId; l < a_tile; l += kWarpSize)
                sh_row[l] = A[A_offset + k_start + l];
            __syncwarp();

            for (int64_t outloop = 0; outloop < out_len; outloop += kWarpSize) {
                const int64_t col = outloop + laneId;
                // Guard the absolute column, not the lane id: the last chunk
                // is partial when out_len is not a multiple of 32.
                if (col < out_len) {
                    DType out_reg = 0;  // thread private
                    // iterate over the A tile; lanes walk a row of B in parallel
                    for (int i = 0; i < a_tile; i++)
                        out_reg += sh_row[i] * B[B_offset + (k_start + i) * out_len + col];
                    if (idx_c) {
                        AtomicAdd(C + C_offset + col, out_reg);
                    } else {
                        C[C_offset + col] += out_reg;
                    }
                }
            }
            // Every lane must be done reading this tile before any lane
            // overwrites it with the next one.
            __syncwarp();
        }
    }
}

/* \Note Output matrix is accumulated via atomic operations. Rest of the strategies
  are similar to GatherMMScatterKernel. One warp is assigned to process one row of A.
  Each WARP sequentially multiplies one element of A and a row of B to compute partial
  result of the output. A is loaded in shared memory in a coalesced way. B should
  get benefit from L2 cache.

  For every row r < num_rows this adds the outer product of row idx_a[r] of A
  (length in_len) and row idx_b[r] of B (length out_len) into the
  in_len x out_len row-major matrix C[idx_c[r]]:
      C[idx_c[r]][k][j] += A[idx_a[r]][k] * B[idx_b[r]][j]
  A null index pointer means the identity mapping (row r itself).

  Launch contract: blockDim.x must be 128 (4 warps; the static shared buffer
  holds one 64-element tile of A per warp) and the grid must provide at least
  num_rows warps.
*/
template <typename Idx, typename DType>
__global__ void GatherMMScatterKernel2(
    const DType* __restrict__ A,
    const DType* __restrict__ B,
    DType* __restrict__ C,
    const Idx* __restrict__ idx_a,
    const Idx* __restrict__ idx_b,
    const Idx* __restrict__ idx_c,
    const int64_t num_rows,
    const int64_t in_len,
    const int64_t out_len) {
    constexpr int kWarpSize = 32;
    constexpr int kWarpsPerBlock = 4;  // hardcoded for TB size 128
    constexpr int sh_a_tile = 64;
    __shared__ DType sh_A[kWarpsPerBlock * sh_a_tile];

    const int laneId = threadIdx.x & (kWarpSize - 1);
    // 64-bit global warp id so very large inputs cannot wrap around.
    const int64_t row =
        (static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x) / kWarpSize;
    // row is identical for all lanes of a warp, so the __syncwarp() calls
    // below are always reached by the full warp.
    if (row < num_rows) {
        DType* sh_row = sh_A + (row & (kWarpsPerBlock - 1)) * sh_a_tile;
        const int64_t row_a = (idx_a) ? idx_a[row] : row;
        const int64_t row_b = (idx_b) ? idx_b[row] : row;
        const int64_t row_c = (idx_c) ? idx_c[row] : row;
        // Offsets in 64 bits: row_c * in_len * out_len easily exceeds the
        // range of a 32-bit Idx.
        const int64_t A_offset = row_a * in_len;
        const int64_t B_offset = row_b * out_len;
        const int64_t C_offset = row_c * in_len * out_len;
        for (int64_t k_start = 0; k_start < in_len; k_start += sh_a_tile) {
            const int a_tile = static_cast<int>(
                (in_len - k_start < sh_a_tile) ? in_len - k_start : sh_a_tile);
            /* Load A in shared mem in a coalesced way */
            for (int l = laneId; l < a_tile; l += kWarpSize)
                sh_row[l] = A[A_offset + k_start + l];
            __syncwarp();

            for (int64_t outloop = 0; outloop < out_len; outloop += kWarpSize) {
                const int64_t col = outloop + laneId;
                // Guard the absolute column, not the lane id: the last chunk
                // is partial when out_len is not a multiple of 32.
                if (col < out_len) {
                    const DType b_val = B[B_offset + col];
                    /* iterate over elements of the A tile */
                    for (int i = 0; i < a_tile; i++) {
                        const int64_t C_idx = C_offset + (k_start + i) * out_len + col;
                        AtomicAdd(C + C_idx, sh_row[i] * b_val);
                    }
                }
            }
            // Every lane must be done reading this tile before any lane
            // overwrites it with the next one.
            __syncwarp();
        }
    }
}

}  // namespace cuda

/*!
 * \brief Segmented matrix multiplication on GPU via cuBLAS.
 *
 * The rows of A are split into consecutive segments whose lengths come from
 * seglen_A. Segment r of A is multiplied by the r-th matrix of B and the
 * result is written to the matching rows of C (beta is 0, so C is
 * overwritten). seglen_A is dereferenced on the host, so it must live in
 * host memory.
 *
 * \param A The input dense matrix; segment r spans seglen_A[r] consecutive rows
 * \param B The input stack of matrices, shape (R, B->shape[1], B->shape[2])
 * \param C The output dense matrix, one row per row of A
 * \param seglen_A The input vector of size R. Each element
 *        is the length of segments of input ``A``
 * \param a_trans Not consulted by this implementation
 * \param b_trans Use each B[r] transposed, i.e. C_r = A_r * B[r]^T
 */
template <int XPU, typename IdType, int bits>
void SegmentMM(const NDArray A,
               const NDArray B,
               NDArray C,
               const NDArray seglen_A,
               bool a_trans, bool b_trans) {
    SWITCH_BITS(bits, DType, {
        auto device = runtime::DeviceAPI::Get(A->ctx);
        auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
        if (!thr_entry->cublas_handle)
            CUBLAS_CALL(cublasCreate(&(thr_entry->cublas_handle)));
        CUBLAS_CALL(cublasSetStream(thr_entry->cublas_handle,
            thr_entry->stream));

        const DType alpha = 1., beta = 0.;
        const IdType* seg_lengths = seglen_A.Ptr<IdType>();
        const int64_t num_segments = seglen_A.NumElements();

        // cuBLAS is column-major, so the row-major product C = A * op(B) is
        // issued as C^T = op(B)^T * A^T, with B as the first operand.
        const int64_t b_rows = B->shape[1];
        const int64_t b_cols = B->shape[2];
        const cublasOperation_t op_B = b_trans ? CUBLAS_OP_T : CUBLAS_OP_N;
        const int64_t inner = b_trans ? b_cols : b_rows;  // cols of A
        const int64_t width = b_trans ? b_rows : b_cols;  // cols of C

        const DType* A_seg = A.Ptr<DType>();
        const DType* B_mat = B.Ptr<DType>();
        DType* C_seg = C.Ptr<DType>();
        IdType rows_done = 0;
        for (IdType seg = 0; seg < num_segments; ++seg) {
            const int64_t seg_rows = seg_lengths[seg];  // rows of A in this segment
            CHECK_LE(rows_done + seg_rows, A->shape[0]) << "Segment index out of bound of A->shape[0].";
            CUBLAS_CALL(cublasGemm<DType>(
                thr_entry->cublas_handle,
                op_B, CUBLAS_OP_N,
                width, seg_rows, inner,
                &alpha,
                B_mat, b_cols,
                A_seg, inner,
                &beta,
                C_seg, width));
            A_seg += seg_rows * inner;
            B_mat += inner * width;
            C_seg += seg_rows * width;
            rows_done += seg_rows;
        }
    });
}

/*!
 * \brief Gradient of the segmented matmul with respect to B, via cuBLAS.
 *
 * Segment r covers seglen[r] consecutive rows of both A and dC. For each
 * segment this accumulates, in row-major terms,
 *     dB[r] += A_r^T * dC_r
 * where A_r is seglen[r] x A->shape[1] and dC_r is seglen[r] x dC->shape[1].
 * beta is 1, so dB is added to rather than overwritten (presumably the caller
 * zero-initializes it -- TODO confirm). seglen is dereferenced on the host,
 * so it must live in host memory.
 *
 * \param A The forward input matrix
 * \param dC The gradient of the forward output
 * \param dB The gradient of the weight stack (accumulated into)
 * \param seglen Segment lengths, one per matrix in dB
 */
template <int XPU, typename IdType, int bits>
void SegmentMMBackwardB(const NDArray A,
                        const NDArray dC,
                        NDArray dB,
                        const NDArray seglen) {
    SWITCH_BITS(bits, DType, {
        auto device = runtime::DeviceAPI::Get(A->ctx);
        auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
        if (!thr_entry->cublas_handle)
            CUBLAS_CALL(cublasCreate(&(thr_entry->cublas_handle)));
        CUBLAS_CALL(cublasSetStream(thr_entry->cublas_handle,
            thr_entry->stream));

        const DType alpha = 1., beta = 1.;
        const IdType* seg_lengths = seglen.Ptr<IdType>();
        const int64_t num_segments = seglen.NumElements();
        // These widths are identical for every segment.
        const int64_t out_width = dC->shape[1];  // cols of dC
        const int64_t in_width = A->shape[1];    // cols of A

        const DType* dC_seg = dC.Ptr<DType>();
        const DType* A_seg = A.Ptr<DType>();
        DType* dB_mat = dB.Ptr<DType>();
        IdType rows_done = 0;
        for (IdType seg = 0; seg < num_segments; ++seg) {
            const int64_t seg_rows = seg_lengths[seg];
            CHECK_LE(rows_done + seg_rows, A->shape[0]) << "Segement index out of bound of A->shape[0].";
            // Column-major view: dB[r]^T = dC_r^T * A_r.
            CUBLAS_CALL(cublasGemm<DType>(
                thr_entry->cublas_handle,
                CUBLAS_OP_N, CUBLAS_OP_T,
                out_width, in_width, seg_rows,
                &alpha,
                dC_seg, out_width,
                A_seg, in_width,
                &beta,
                dB_mat, out_width));
            dC_seg += out_width * seg_rows;
            A_seg += in_width * seg_rows;
            dB_mat += out_width * in_width;
            rows_done += seg_rows;
        }
    });
}

/*!
 * \brief Implementation of Gather_mm operator on GPU.
 *
 * One warp is launched per row i of A (i < A->shape[0]) and computes
 *     C[i] += A[idx_a[i]] * B[idx_b[i]]
 * where B is a stack of in_len x out_len matrices. C is accumulated into
 * rather than overwritten (presumably zero-initialized by the caller --
 * TODO confirm). idx_a / idx_b are presumably expected to hold at least
 * A->shape[0] entries -- TODO confirm against callers.
 *
 * \param A The input dense matrix of dimension m x in_len
 * \param B The input stack of matrices, shape (R, in_len, out_len)
 * \param C The output dense matrix of dimension m x out_len
 * \param idx_a The input vector to gather left hand operand on
 * \param idx_b The input vector to gather right hand operand on
 */
template <int XPU, typename IdType, int bits>
void GatherMM(const NDArray A,
              const NDArray B,
              NDArray C,
              const NDArray idx_a,
              const NDArray idx_b) {
  const int64_t tot_num_rows = A->shape[0];
  // Nothing to compute. This also avoids launching a grid with zero blocks,
  // which the CUDA runtime rejects as an invalid configuration.
  if (tot_num_rows == 0) return;
  SWITCH_BITS(bits, DType, {
        auto device = runtime::DeviceAPI::Get(A->ctx);
        auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
        const int64_t out_len = B->shape[2];  // cols of B
        const int64_t in_len = A->shape[1];  // cols of A
        // The kernel hardcodes 4 warps (128 threads) per block.
        const int ntx = 128;
        const int warp_size = 32;
        // One warp per row of A; computed in 64 bits to avoid overflow.
        const int64_t nbx = (tot_num_rows * warp_size + ntx - 1) / ntx;
        const dim3 nblks(static_cast<unsigned int>(nbx));
        const dim3 nthrs(ntx);
        CUDA_KERNEL_CALL((cuda::GatherMMScatterKernel<IdType, DType>),
            nblks, nthrs, 0, thr_entry->stream,
            A.Ptr<DType>(),
            B.Ptr<DType>(),
            C.Ptr<DType>(),
            idx_a.Ptr<IdType>(),
            idx_b.Ptr<IdType>(),
            nullptr,
            tot_num_rows, in_len, out_len);
    });
}

/*!
 * \brief Gather-scatter matrix multiplication on GPU.
 *
 * One warp is launched per row i of A (i < A->shape[0]). The mode is chosen
 * by the rank of B:
 *  - 3D B of shape (R, in_len, out_len):
 *        C[idx_c[i]] += A[idx_a[i]] * B[idx_b[i]]
 *  - 2D B of shape (N, out_len), used for the weight gradient:
 *        C[idx_c[i]] += A[idx_a[i]]^T * B[idx_b[i]]
 *    i.e. an outer product accumulated into in_len x out_len matrices of C.
 * Output rows may repeat in idx_c, so the kernels accumulate into C rather
 * than overwrite it.
 *
 * \param A The input dense matrix of dimension m x in_len
 * \param B The input 3D stack of matrices or 2D matrix (see above)
 * \param C The output, accumulated into
 * \param idx_a The input vector to gather left hand operand on
 * \param idx_b The input vector to gather right hand operand on
 * \param idx_c The input vector to scatter the output on
 */
template <int XPU, typename IdType, int bits>
void GatherMMScatter(const NDArray A,
                     const NDArray B,
                     NDArray C,
                     const NDArray idx_a,
                     const NDArray idx_b,
                     const NDArray idx_c) {
    // Any other rank would read past B->shape or pick the wrong kernel.
    CHECK(B->ndim == 2 || B->ndim == 3)
        << "GatherMMScatter expects B to be 2D or 3D, got " << B->ndim << "D.";
    const int64_t tot_num_rows = A->shape[0];
    // Nothing to accumulate. This also avoids launching a grid with zero
    // blocks, which the CUDA runtime rejects as an invalid configuration.
    if (tot_num_rows == 0) return;
    SWITCH_BITS(bits, DType, {
        auto device = runtime::DeviceAPI::Get(A->ctx);
        auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
        int64_t out_len = (B->ndim == 2)? B->shape[1] : B->shape[2];  // cols of B
        int64_t in_len = A->shape[1];  // cols of A
        // Both kernels hardcode 4 warps (128 threads) per block.
        const int ntx = 128;
        const int warp_size = 32;
        // One warp per row of A; computed in 64 bits to avoid overflow.
        const int64_t nbx = (tot_num_rows * warp_size + ntx - 1) / ntx;
        const dim3 nblks(static_cast<unsigned int>(nbx));
        const dim3 nthrs(ntx);
        if (B->ndim == 3) {
          CUDA_KERNEL_CALL((cuda::GatherMMScatterKernel<IdType, DType>),
              nblks, nthrs, 0, thr_entry->stream,
              A.Ptr<DType>(),
              B.Ptr<DType>(),
              C.Ptr<DType>(),
              idx_a.Ptr<IdType>(),
              idx_b.Ptr<IdType>(),
              idx_c.Ptr<IdType>(),
              tot_num_rows, in_len, out_len);
        } else {
          // Custom kernel for W_grad[idx_c[i]] = H^T[i] * C.grad[i]
          // This kernel accesses rows of A in a transposed way w/o explicitly converting A
          CUDA_KERNEL_CALL((cuda::GatherMMScatterKernel2<IdType, DType>),
              nblks, nthrs, 0, thr_entry->stream,
              A.Ptr<DType>(),
              B.Ptr<DType>(),
              C.Ptr<DType>(),
              idx_a.Ptr<IdType>(),
              idx_b.Ptr<IdType>(),
              idx_c.Ptr<IdType>(),
              tot_num_rows, in_len, out_len);
        }
    });
}


// Explicit GPU instantiations of every entry point in this file for each
// supported (index type, floating-point width) pair; ``bits`` selects the
// element type through SWITCH_BITS.
#define DGL_GATHER_MM_GPU_INSTANTIATE(IdType, bits)                         \
  template void GatherMM<kDLGPU, IdType, bits>(                             \
      const NDArray A, const NDArray B, NDArray C,                          \
      const NDArray idx_a, const NDArray idx_b);                            \
  template void GatherMMScatter<kDLGPU, IdType, bits>(                      \
      const NDArray A, const NDArray B, NDArray C,                          \
      const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);       \
  template void SegmentMM<kDLGPU, IdType, bits>(                            \
      const NDArray A, const NDArray B, NDArray C,                          \
      const NDArray seglen_A, bool a_trans, bool b_trans);                  \
  template void SegmentMMBackwardB<kDLGPU, IdType, bits>(                   \
      const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);

DGL_GATHER_MM_GPU_INSTANTIATE(int32_t, 16)
DGL_GATHER_MM_GPU_INSTANTIATE(int64_t, 16)
DGL_GATHER_MM_GPU_INSTANTIATE(int32_t, 32)
DGL_GATHER_MM_GPU_INSTANTIATE(int64_t, 32)
DGL_GATHER_MM_GPU_INSTANTIATE(int32_t, 64)
DGL_GATHER_MM_GPU_INSTANTIATE(int64_t, 64)

#undef DGL_GATHER_MM_GPU_INSTANTIATE
}  // namespace aten
}  // namespace dgl