// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cuda/gather_mm.cu
 * @brief GatherMM C APIs and definitions.
 */
#include <dgl/array.h>

#include <algorithm>  // std::swap

#include "atomic.cuh"
#include "functor.cuh"
#include "utils.h"

namespace dgl {
using namespace cuda;
namespace aten {

namespace {

/** @brief Call cuBLAS GEMM API for dense matmul operation for float and
 * double. */
template <typename DType>
hipblasStatus_t cublasGemm(
    hipblasHandle_t handle, hipblasOperation_t transa,
    hipblasOperation_t transb, int m, int n, int k, const DType* alpha,
    const DType* A, int lda, const DType* B, int ldb, const DType* beta,
    DType* C, int ldc) {
  LOG(INFO) << "Not supported dtype";
  return HIPBLAS_STATUS_EXECUTION_FAILED;
}

template <>
hipblasStatus_t cublasGemm<__half>(
    hipblasHandle_t handle, hipblasOperation_t transa,
    hipblasOperation_t transb, int m, int n, int k, const __half* alpha,
    const __half* A, int lda, const __half* B, int ldb, const __half* beta,
    __half* C, int ldc) {
  return hipblasHgemm(
      handle, transa, transb, m, n, k, (hipblasHalf*)alpha, (hipblasHalf*)A,
      lda, (hipblasHalf*)B, ldb, (hipblasHalf*)beta, (hipblasHalf*)C, ldc);
}

// template <>
// hipblasStatus_t cublasGemm<__half>(
//     hipblasHandle_t handle, hipblasOperation_t transa,
//     hipblasOperation_t transb, int m, int n, int k, const __half* alpha,
//     const __half* A, int lda, const __half* B, int ldb, const __half* beta,
//     __half* C, int ldc) {
//   return hipblasHgemm(
//       handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
// }

#if BF16_ENABLED
template <>
hipblasStatus_t cublasGemm<__hip_bfloat16>(
    hipblasHandle_t handle, hipblasOperation_t transa,
    hipblasOperation_t transb, int m, int n, int k,
    const __hip_bfloat16* alpha, const __hip_bfloat16* A, int lda,
    const __hip_bfloat16* B, int ldb, const __hip_bfloat16* beta,
    __hip_bfloat16* C, int ldc) {
  float alpha_float = __bfloat162float(*alpha);
  float beta_float = __bfloat162float(*beta);
  return hipblasGemmEx(
      handle, transa, transb, m, n, k, &alpha_float, A, HIPBLAS_R_16B, lda, B,
      HIPBLAS_R_16B, ldb, &beta_float, C, HIPBLAS_R_16B, ldc, HIPBLAS_R_32F,
      HIPBLAS_GEMM_DEFAULT);
}
#endif  // BF16_ENABLED

template <>
hipblasStatus_t cublasGemm<float>(
    hipblasHandle_t handle, hipblasOperation_t transa,
    hipblasOperation_t transb, int m, int n, int k, const float* alpha,
    const float* A, int lda, const float* B, int ldb, const float* beta,
    float* C, int ldc) {
  return hipblasSgemm(
      handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
}

template <>
hipblasStatus_t cublasGemm<double>(
    hipblasHandle_t handle, hipblasOperation_t transa,
    hipblasOperation_t transb, int m, int n, int k, const double* alpha,
    const double* A, int lda, const double* B, int ldb, const double* beta,
    double* C, int ldc) {
  return hipblasDgemm(
      handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
}

}  // namespace

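namespace {

// Exposition-only sketch (not part of the upstream logic): the cublasGemm
// wrappers above take column-major operands, as hipBLAS/cuBLAS do, so the
// callers below (e.g. SegmentMM) pass their row-major buffers with the
// operand order swapped, computing C^T = B^T * A^T. Under that assumption, a
// minimal row-major C = A * B (A is m x k, B is k x n) for float would look
// like this; the function name and the pre-initialized handle are
// illustrative assumptions only.
inline hipblasStatus_t RowMajorGemmSketch(
    hipblasHandle_t handle, int m, int n, int k, const float* A,
    const float* B, float* C) {
  const float alpha = 1.f, beta = 0.f;
  // Column-major view: C^T (n x m) = B^T (n x k) * A^T (k x m), so B is
  // passed first with leading dimension n, A second with leading dimension k.
  return cublasGemm<float>(
      handle, HIPBLAS_OP_N, HIPBLAS_OP_N, n, m, k, &alpha, B, n, A, k, &beta,
      C, n);
}

}  // namespace
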
namespace cuda {

/**
 * @note Each row of A multiplies a segment of matrix B of dimension
 * in_len * out_len. One warp is assigned to process one row of A. Each warp
 * sequentially multiplies one element of A and a row of B to compute a
 * partial result of the output. A is loaded into shared memory in a coalesced
 * way. The output matrix is kept in registers. B should benefit from the L2
 * cache.
 */
template <typename Idx, typename DType>
__global__ void GatherMMScatterKernel(
    const DType* __restrict__ A, const DType* __restrict__ B,
    DType* __restrict__ C, const Idx* __restrict__ idx_a,
    const Idx* __restrict__ idx_b, const Idx* __restrict__ idx_c,
    const int64_t num_rows, const int64_t in_len, const int64_t out_len) {
  unsigned int tId = threadIdx.x;
  unsigned int laneId = tId & 31;
  unsigned int gId = (blockIdx.x * blockDim.x + threadIdx.x);
  unsigned int warpId = gId >> 5;
  unsigned int row = warpId;
  if (row < num_rows) {
    const unsigned int local_row =
        row & 3;  // hardcoded for TB size 128 (4 warps)
    const Idx cur_rowA = (idx_a) ? idx_a[row] : row;
    const Idx cur_rowB = (idx_b) ? idx_b[row] : row;
    const Idx cur_rowC = (idx_c) ? idx_c[row] : row;
    const Idx B_offset = cur_rowB * in_len * out_len;
    const int sh_a_tile = 64;
    __shared__ DType sh_A[4 * sh_a_tile];
    int a_tile = sh_a_tile;
    for (unsigned int k_start = 0; k_start < in_len; k_start += 64) {
      if ((in_len - k_start) < a_tile) a_tile = in_len - k_start;
      // Load A in shared mem in a coalesced way
      for (unsigned int l = laneId; l < a_tile; l += 32)
        sh_A[local_row * sh_a_tile + l] = A[cur_rowA * in_len + (k_start + l)];
      // __syncwarp();

      for (unsigned int outloop = 0; outloop < out_len; outloop += 32) {
        DType out_reg = static_cast<DType>(0.0f);  // thread private
        const unsigned int l = laneId;
        if (l < out_len) {
          // iterate over elements of a row of A
          for (unsigned int i = 0; i < a_tile; i++) {
            const DType a_val = sh_A[local_row * sh_a_tile + i];
            // iterate over elements of a row of B in parallel
            out_reg +=
                a_val * B[B_offset + ((i + k_start) * out_len + (outloop + l))];
          }
          if (idx_c) {
            AtomicAdd(C + cur_rowC * out_len + (outloop + l), out_reg);
          } else {
            C[cur_rowC * out_len + (outloop + l)] += out_reg;
          }
        }
      }
    }
  }
}

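// A small host-side helper sketch (added for exposition; the launchers below
// inline this arithmetic themselves): both kernels assume 128-thread blocks,
// i.e. 4 warps per block with one warp per row of A, so the number of blocks
// needed for num_rows rows is ceil(num_rows * 32 / 128). The function name is
// an illustrative assumption, not part of the upstream API.
inline int64_t NumBlocksForWarpPerRow(
    int64_t num_rows, int block_size = 128, int warp_size = 32) {
  return (num_rows * warp_size + block_size - 1) / block_size;
}
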
/**
 * @note The output matrix is accumulated via atomic operations. The rest of
 * the strategy is similar to GatherMMScatterKernel. One warp is assigned to
 * process one row of A. Each warp sequentially multiplies one element of A
 * and a row of B to compute a partial result of the output. A is loaded into
 * shared memory in a coalesced way. B should benefit from the L2 cache.
 */
template <typename Idx, typename DType>
__global__ void GatherMMScatterKernel2(
    const DType* __restrict__ A, const DType* __restrict__ B,
    DType* __restrict__ C, const Idx* __restrict__ idx_a,
    const Idx* __restrict__ idx_b, const Idx* __restrict__ idx_c,
    const int64_t num_rows, const int64_t in_len, const int64_t out_len) {
  unsigned int tId = threadIdx.x;
  unsigned int laneId = tId & 31;
  unsigned int gId = (blockIdx.x * blockDim.x + threadIdx.x);
  unsigned int warpId = gId >> 5;
  unsigned int row = warpId;
  if (row < num_rows) {
    const unsigned int local_row =
        row & 3;  // hardcoded for TB size 128 (4 warps)
    const Idx row_a = (idx_a) ? idx_a[row] : row;
    const Idx row_b = (idx_b) ? idx_b[row] : row;
    const Idx row_c = (idx_c) ? idx_c[row] : row;
    const Idx C_offset = row_c * in_len * out_len;
    const int sh_a_tile = 64;
    __shared__ DType sh_A[4 * sh_a_tile];
    int a_tile = sh_a_tile;
    for (unsigned int k_start = 0; k_start < in_len; k_start += 64) {
      if ((in_len - k_start) < a_tile) a_tile = in_len - k_start;
      /* Load A in shared mem in a coalesced way */
      for (unsigned int l = laneId; l < a_tile; l += 32)
        sh_A[local_row * sh_a_tile + l] = A[row_a * in_len + (k_start + l)];
      // __syncwarp();

      for (unsigned int outloop = 0; outloop < out_len; outloop += 32) {
        DType out_reg = static_cast<DType>(0.0f);  // thread private
        const unsigned int l = laneId;
        if (l < out_len) {
          const DType b_val = B[row_b * out_len + (outloop + l)];
          /* iterate over elements of a row of A */
          for (unsigned int i = 0; i < a_tile; i++) {
            const DType a_val = sh_A[local_row * sh_a_tile + i];
            const Idx C_idx =
                C_offset + ((i + k_start) * out_len + (outloop + l));
            AtomicAdd(C + C_idx, a_val * b_val);
          }
        }
      }
    }
  }
}

}  // namespace cuda

/**
 * @brief Implementation of SegmentMM operator. The input matrix A is
 *        expected to be sorted according to relation type.
 * @param A The input dense matrix of dimension m x k
 * @param B The input dense matrix of dimension k x n
 * @param C The output dense matrix of dimension m x n
 * @param seglen_A The input vector of size R. Each element
 *        is the length of segments of input ``A``
 * @param a_trans Matrix A to be transposed
 * @param b_trans Matrix B to be transposed
 */
template <int XPU, typename IdType, typename DType>
void SegmentMM(
    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
    bool a_trans, bool b_trans) {
  auto device = runtime::DeviceAPI::Get(A->ctx);
  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
  const DType* A_data = A.Ptr<DType>();
  const DType* B_data = B.Ptr<DType>();
  const IdType* seglen_A_data = seglen_A.Ptr<IdType>();
  DType* C_data = C.Ptr<DType>();
  int64_t A_offset = 0, B_offset = 0, C_offset = 0;
  int64_t m, n, k;
  int64_t num_rel = seglen_A.NumElements();
  DType alpha = 1., beta = 0.;

  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  if (!thr_entry->cublas_handle)
    CUBLAS_CALL(hipblasCreate(&(thr_entry->cublas_handle)));
  CUBLAS_CALL(hipblasSetStream(thr_entry->cublas_handle, stream));

  IdType m_offset = 0;
  for (IdType etype = 0; etype < num_rel; ++etype) {
    m = seglen_A_data[etype];  // rows of A
    CHECK_LE(m_offset + m, A->shape[0])
        << "Segment index out of bound of A->shape[0].";
    n = B->shape[2];  // cols of B
    k = B->shape[1];  // cols of A == rows of B
    int ldb = n, lda = k, ldc = n;
    hipblasOperation_t transB = HIPBLAS_OP_N;
    hipblasOperation_t transA = HIPBLAS_OP_N;
    if (b_trans) {
      transB = HIPBLAS_OP_T;
      ldb = n, lda = n, ldc = k;
      std::swap(n, k);
    }
    CUBLAS_CALL(cublasGemm<DType>(
        thr_entry->cublas_handle, transB, transA, n, m, k, &alpha,
        B_data + B_offset, ldb, A_data + A_offset, lda, &beta,
        C_data + C_offset, ldc));
    A_offset += m * k;
    B_offset += k * n;
    C_offset += m * n;
    m_offset += m;
  }
}

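namespace {

// Exposition-only reference of the SegmentMM semantics above (no-transpose
// case), written as plain host loops over row-major float buffers: the
// etype-th segment of A (seglen[etype] rows, k columns) is multiplied by the
// etype-th (k x n) slice of B into the matching segment of C. The function
// name and the raw-pointer interface are illustrative assumptions.
inline void ReferenceSegmentMM(
    const float* A, const float* B, float* C, const int64_t* seglen,
    int64_t num_rel, int64_t k, int64_t n) {
  int64_t A_off = 0, B_off = 0, C_off = 0;
  for (int64_t etype = 0; etype < num_rel; ++etype) {
    const int64_t m = seglen[etype];  // rows in this segment
    for (int64_t r = 0; r < m; ++r) {
      for (int64_t j = 0; j < n; ++j) {
        float acc = 0.f;
        for (int64_t i = 0; i < k; ++i)
          acc += A[A_off + r * k + i] * B[B_off + i * n + j];
        C[C_off + r * n + j] = acc;
      }
    }
    A_off += m * k;
    B_off += k * n;
    C_off += m * n;
  }
}

}  // namespace
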
template <int XPU, typename IdType, typename DType>
void SegmentMMBackwardB(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen) {
  auto device = runtime::DeviceAPI::Get(A->ctx);
  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
  const DType* A_data = A.Ptr<DType>();
  const DType* dC_data = dC.Ptr<DType>();
  const IdType* seglen_data = seglen.Ptr<IdType>();
  DType* dB_data = dB.Ptr<DType>();
  int64_t A_offset = 0, dC_offset = 0, dB_offset = 0;
  int64_t m, n, k;
  int64_t num_rel = seglen.NumElements();
  DType alpha = 1., beta = 0.;

  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  if (!thr_entry->cublas_handle)
    CUBLAS_CALL(hipblasCreate(&(thr_entry->cublas_handle)));
  CUBLAS_CALL(hipblasSetStream(thr_entry->cublas_handle, stream));

  IdType k_offset = 0;
  for (IdType etype = 0; etype < num_rel; ++etype) {
    m = dC->shape[1];
    n = A->shape[1];
    k = seglen_data[etype];
    CHECK_LE(k_offset + k, A->shape[0])
        << "Segment index out of bound of A->shape[0].";
    int lddC = m, ldA = n, lddB = m;
    hipblasOperation_t trans_dC = HIPBLAS_OP_N;
    hipblasOperation_t trans_A = HIPBLAS_OP_T;
    CUBLAS_CALL(cublasGemm<DType>(
        thr_entry->cublas_handle, trans_dC, trans_A, m, n, k, &alpha,
        dC_data + dC_offset, lddC, A_data + A_offset, ldA, &beta,
        dB_data + dB_offset, lddB));
    dC_offset += m * k;
    A_offset += n * k;
    dB_offset += m * n;
    k_offset += k;
  }
}

/**
 * @brief Implementation of Gather_mm operator. The input matrix A is
 *        expected to be sorted according to relation type.
 * @param A The input dense matrix of dimension m x k
 * @param B The input dense matrix of dimension k x n
 * @param C The output dense matrix of dimension m x n
 * @param idx_a The input vector to gather left hand operand on
 * @param idx_b The input vector to gather right hand operand on
 */
template <int XPU, typename IdType, typename DType>
void GatherMM(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b) {
  auto device = runtime::DeviceAPI::Get(A->ctx);
  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
  int64_t out_len = B->shape[2];  // cols of B
  int64_t in_len = A->shape[1];   // cols of A
  const int64_t tot_num_rows = A->shape[0];
  const int ntx = 128;
  const int warp_size = 32;
  const int nbx = ((tot_num_rows * warp_size + ntx - 1) / ntx);
  const dim3 nblks(nbx);
  const dim3 nthrs(ntx);
  CUDA_KERNEL_CALL(
      (cuda::GatherMMScatterKernel<IdType, DType>), nblks, nthrs, 0, stream,
      A.Ptr<DType>(), B.Ptr<DType>(), C.Ptr<DType>(), idx_a.Ptr<IdType>(),
      idx_b.Ptr<IdType>(), nullptr, tot_num_rows, in_len, out_len);
}

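namespace {

// Exposition-only reference of the GatherMM semantics above, as plain host
// loops over row-major float buffers: row r of the output accumulates
// A[idx_a[r]] multiplied by the (in_len x out_len) slice B[idx_b[r]], i.e.
// C[r] += A[idx_a[r]] * B[idx_b[r]]. A null index array means the identity
// mapping, matching the kernel. Names and the raw-pointer interface are
// illustrative assumptions.
inline void ReferenceGatherMM(
    const float* A, const float* B, float* C, const int64_t* idx_a,
    const int64_t* idx_b, int64_t num_rows, int64_t in_len, int64_t out_len) {
  for (int64_t r = 0; r < num_rows; ++r) {
    const int64_t ra = idx_a ? idx_a[r] : r;
    const int64_t rb = idx_b ? idx_b[r] : r;
    for (int64_t j = 0; j < out_len; ++j) {
      float acc = 0.f;
      for (int64_t i = 0; i < in_len; ++i)
        acc += A[ra * in_len + i] * B[rb * in_len * out_len + i * out_len + j];
      C[r * out_len + j] += acc;  // the kernel accumulates into C
    }
  }
}

}  // namespace
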
/**
 * @brief Implementation of Gather_mm operator with scattered output. The
 *        input matrix A is expected to be sorted according to relation type.
 * @param A The input dense matrix of dimension m x k
 * @param B The input dense matrix of dimension k x n
 * @param C The output dense matrix of dimension m x n
 * @param idx_a The input vector to gather left hand operand on
 * @param idx_b The input vector to gather right hand operand on
 * @param idx_c The input vector to gather output operand on
 */
template <int XPU, typename IdType, typename DType>
void GatherMMScatter(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c) {
  auto device = runtime::DeviceAPI::Get(A->ctx);
  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
  const IdType* idx_c_data = idx_c.Ptr<IdType>();
  int64_t out_len = (B->ndim == 2) ? B->shape[1] : B->shape[2];  // cols of B
  int64_t in_len = A->shape[1];                                  // cols of A
  int64_t tot_num_rows = A->shape[0];
  const int ntx = 128;
  const int warp_size = 32;
  const int nbx = ((tot_num_rows * warp_size + ntx - 1) / ntx);
  const dim3 nblks(nbx);
  const dim3 nthrs(ntx);
  if (B->ndim == 3) {
    CUDA_KERNEL_CALL(
        (cuda::GatherMMScatterKernel<IdType, DType>), nblks, nthrs, 0, stream,
        A.Ptr<DType>(), B.Ptr<DType>(), C.Ptr<DType>(), idx_a.Ptr<IdType>(),
        idx_b.Ptr<IdType>(), idx_c.Ptr<IdType>(), tot_num_rows, in_len,
        out_len);
  } else {
    // Custom kernel for W_grad[idx_c[i]] = H^T[i] * C.grad[i].
    // This kernel accesses rows of A in a transposed way w/o explicitly
    // converting A.
    CUDA_KERNEL_CALL(
        (cuda::GatherMMScatterKernel2<IdType, DType>), nblks, nthrs, 0, stream,
        A.Ptr<DType>(), B.Ptr<DType>(), C.Ptr<DType>(), idx_a.Ptr<IdType>(),
        idx_b.Ptr<IdType>(), idx_c.Ptr<IdType>(), tot_num_rows, in_len,
        out_len);
  }
}

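namespace {

// Exposition-only reference of the 2-D-B branch of GatherMMScatter above
// (the GatherMMScatterKernel2 path): each gathered row of A forms an outer
// product with the matching row of B and is accumulated (atomically in the
// kernel) into the (in_len x out_len) slice of C selected by idx_c, i.e.
// C[idx_c[r]] += outer(A[idx_a[r]], B[idx_b[r]]). Names and the raw-pointer
// interface are illustrative assumptions.
inline void ReferenceGatherMMScatterOuter(
    const float* A, const float* B, float* C, const int64_t* idx_a,
    const int64_t* idx_b, const int64_t* idx_c, int64_t num_rows,
    int64_t in_len, int64_t out_len) {
  for (int64_t r = 0; r < num_rows; ++r) {
    const int64_t ra = idx_a ? idx_a[r] : r;
    const int64_t rb = idx_b ? idx_b[r] : r;
    const int64_t rc = idx_c ? idx_c[r] : r;
    for (int64_t i = 0; i < in_len; ++i)
      for (int64_t j = 0; j < out_len; ++j)
        C[rc * in_len * out_len + i * out_len + j] +=
            A[ra * in_len + i] * B[rb * out_len + j];
  }
}

}  // namespace
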
template void GatherMM<kDGLCUDA, int32_t, __half>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b);
template void GatherMM<kDGLCUDA, int64_t, __half>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b);
#if BF16_ENABLED
template void GatherMM<kDGLCUDA, int32_t, __hip_bfloat16>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b);
template void GatherMM<kDGLCUDA, int64_t, __hip_bfloat16>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b);
#endif  // BF16_ENABLED
template void GatherMM<kDGLCUDA, int32_t, float>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b);
template void GatherMM<kDGLCUDA, int64_t, float>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b);
template void GatherMM<kDGLCUDA, int32_t, double>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b);
template void GatherMM<kDGLCUDA, int64_t, double>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b);

template void GatherMMScatter<kDGLCUDA, int32_t, __half>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCUDA, int64_t, __half>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c);
#if BF16_ENABLED
template void GatherMMScatter<kDGLCUDA, int32_t, __hip_bfloat16>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCUDA, int64_t, __hip_bfloat16>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c);
#endif  // BF16_ENABLED
template void GatherMMScatter<kDGLCUDA, int32_t, float>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCUDA, int64_t, float>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCUDA, int32_t, double>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCUDA, int64_t, double>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c);

template void SegmentMM<kDGLCUDA, int32_t, __half>(
    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
    bool a_trans, bool b_trans);
template void SegmentMM<kDGLCUDA, int64_t, __half>(
    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
    bool a_trans, bool b_trans);
#if BF16_ENABLED
template void SegmentMM<kDGLCUDA, int32_t, __hip_bfloat16>(
    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
    bool a_trans, bool b_trans);
template void SegmentMM<kDGLCUDA, int64_t, __hip_bfloat16>(
    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
    bool a_trans, bool b_trans);
#endif  // BF16_ENABLED
template void SegmentMM<kDGLCUDA, int32_t, float>(
    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
    bool a_trans, bool b_trans);
template void SegmentMM<kDGLCUDA, int64_t, float>(
    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
    bool a_trans, bool b_trans);
template void SegmentMM<kDGLCUDA, int32_t, double>(
    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
    bool a_trans, bool b_trans);
template void SegmentMM<kDGLCUDA, int64_t, double>(
    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
    bool a_trans, bool b_trans);

template void SegmentMMBackwardB<kDGLCUDA, int32_t, __half>(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCUDA, int64_t, __half>(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
#if BF16_ENABLED
template void SegmentMMBackwardB<kDGLCUDA, int32_t, __hip_bfloat16>(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCUDA, int64_t, __hip_bfloat16>(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
#endif  // BF16_ENABLED
template void SegmentMMBackwardB<kDGLCUDA, int32_t, float>(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCUDA, int64_t, float>(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCUDA, int32_t, double>(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCUDA, int64_t, double>(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);

}  // namespace aten
}  // namespace dgl