/*
 * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <vector>

// TODO: Need to remove the dependency on the layer module,
// e.g. refactor the Weight class into some base module.
#include "src/turbomind/layers/DenseWeight.h"
#include "src/turbomind/utils/Tensor.h"
#include "src/turbomind/utils/allocator.h"
#include "src/turbomind/utils/cublasAlgoMap.h"
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/logger.h"
#include "src/turbomind/utils/memory_utils.h"

#ifndef CUDART_VERSION
#error CUDART_VERSION Undefined!
#endif

// cuBLAS default workspace size: 32 MB. TODO: make this a property of Gemm.
#define WORKSPACE_SIZE 33554432

namespace turbomind {

// A wrapper of the cuBLAS or cuSPARSE matrix operation types.
//  - GEMM_OP_N = CUBLAS_OP_N or CUSPARSE_OP_N
//  - GEMM_OP_T = CUBLAS_OP_T or CUSPARSE_OP_T
enum GemmOp {
    GEMM_OP_N,
    GEMM_OP_T
};

// A base class of the GEMM family.
// In the current version, Gemm serves as both a base class and an interface.
class Gemm {

public:
    Gemm() = delete;  // Disable the default constructor.

    /**
     * A Gemm class.
     *
     * NOTE:
     *   A, B, C are assumed to have a row-major layout, while the backend CUDA
     *   libraries assume a column-major layout. The Gemm family already handles
     *   this discrepancy internally, so use it directly, without tricks such as
     *   swapping inputs A and B to align the matrix layout.
     *
     *   Restriction: supported in/out data or compute types: TYPE_FP16, TYPE_FP32.
     *
     * TODO:
     *   Unify resource allocation/release in a singleton GPU resource manager.
     *   The allocator and stream could then be replaced by a resource handle, e.g.
     *   Gemm(std::shared_ptr<ResourceManager> resource_manager), and
     *     stream_ = resource_manager.getCudaStream();
     *     buffer  = resource_manager.malloc(...);
     *
     * @param allocator   Resource allocator.
     * @param stream      A CUDA stream.
     * @param config_file A file path of a GEMM configuration.
     */
    Gemm(IAllocator* allocator, cudaStream_t stream, std::string config_file = GEMM_CONFIG);
    Gemm(Gemm const& other) = delete;
    virtual ~Gemm();

    virtual std::string toString();
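
    // Illustrative usage sketch (not part of the API; `allocator`, `stream` and the
    // d_A/d_B/d_C device buffers below are assumed to be set up by the caller):
    //
    //   Gemm gemm(allocator, stream);
    //   gemm.setDefaultTypes<half>();   // FP16 data, FP32 accumulation
    //   // C (m x n) = A (m x k) * B (k x n), all row-major
    //   gemm.gemm(GEMM_OP_N, GEMM_OP_N, m, n, k, d_A, d_B, d_C);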

    /**
     * @brief Set the GEMM compute type.
     *
     * @param compute_type The accumulation data type used inside the GEMM computation.
     *                     (Choices: TYPE_FP16, TYPE_FP32)
     *
     * @throw GemmNotSupportedException if a type is not TYPE_FP16 or TYPE_FP32.
     * @throw std::runtime_error if any exception inside CUDA.
     */
    void setComputeType(DataType compute_type);

    /**
     * @brief Set matrix data types and compute precision.
     *
     * Supported data or compute types: TYPE_FP16, TYPE_FP32
     *
     * @param a_type       The data type of a matrix A.
     * @param b_type       The data type of a matrix B.
     * @param c_type       The data type of a matrix C.
     * @param compute_type The accumulation data type used inside the GEMM computation.
     *
     * @throw GemmNotSupportedException if a type is not TYPE_FP16 or TYPE_FP32.
     * @throw std::runtime_error if any exception inside CUDA.
     */
    void setTypes(DataType a_type, DataType b_type, DataType c_type, DataType compute_type);

    /**
     * @brief Set matrix data and compute types to default values.
     *
     * Default configs:
     *  - T=float : data type=TYPE_FP32, compute type=TYPE_FP32
     *  - T=half  : data type=TYPE_FP16, compute type=TYPE_FP32
     */
    template<typename T>
    void setDefaultTypes();

    void loadGemmConfig(std::string config_file);
    void setAllocator(IAllocator* allocator);
    void setCudaStream(cudaStream_t& stream);

    // The APIs below show how the interface would change if it cooperated with
    // Tensor. To enable them, we need to update the Tensor class. For instance,
    // data needs to be of type (void*) rather than (const void*) so that it can
    // be passed as the output C of gemm.
    // virtual void gemm(Tensor&       C,
    //                   const GemmOp  transa,
    //                   const GemmOp  transb,
    //                   const Tensor& A,
    //                   const Tensor& B,
    //                   const float   alpha = 1.0f,
    //                   const float   beta  = 0.0f);
    //
    // virtual void batchedMatmul(std::vector<Tensor>       Carray,
    //                            const GemmOp              transa,
    //                            const GemmOp              transb,
    //                            const std::vector<Tensor> Aarray,
    //                            const std::vector<Tensor> Barray,
    //                            const float               alpha = 1.0f,
    //                            const float               beta  = 0.0f);
    //
    // virtual void stridedBatchedGemm(Tensor&       C,
    //                                 const GemmOp  transa,
    //                                 const GemmOp  transb,
    //                                 const Tensor& A,
    //                                 const Tensor& B,
    //                                 const float   alpha = 1.0f,
    //                                 const float   beta  = 0.0f);

    // TODO:
    // This function cooperates with a Weight object to simplify Gemm calls
    // inside layers, computing the following formula
    //   output(C) = input(A) * weight_kernel(B)
    // where the weight kernel may vary depending on the Gemm implementation.
    // DenseWeight is a template struct, so the method cannot be overridden
    // generically. We temporarily add an interface here for the two cases
    // float/half, but to finalize this function we need an interface of a
    // weight class which is not a template class.
    virtual void gemm(const GemmOp transa, const GemmOp transb,
                      const size_t m, const size_t n, const size_t k,
                      const void* input, const DenseWeight<float>& weight, void* output,
                      const float alpha = 1.0f, const float beta = 0.0f);

    virtual void gemm(const GemmOp transa, const GemmOp transb,
                      const size_t m, const size_t n, const size_t k,
                      const void* input, const DenseWeight<half>& weight, void* output,
                      const float alpha = 1.0f, const float beta = 0.0f);

    virtual void gemm(const GemmOp transa, const GemmOp transb,
                      const size_t m, const size_t n, const size_t k,
                      const void* A, const void* B, void* C,
                      const float alpha = 1.0f, const float beta = 0.0f);

    virtual void gemm(const GemmOp transa, const GemmOp transb,
                      const size_t m, const size_t n, const size_t k,
                      const void* A, const size_t lda,
                      const void* B, const size_t ldb,
                      void* C, const size_t ldc,
                      const float alpha = 1.0f, const float beta = 0.0f);
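
    // Sketch of the weight-based overload above (hypothetical buffers; `weight` is
    // assumed to hold the k x n kernel acting as B in C = A * B):
    //
    //   DenseWeight<half> weight;  // e.g. a layer's projection weight
    //   // d_output (m x n) = d_input (m x k) * weight kernel (k x n)
    //   gemm.gemm(GEMM_OP_N, GEMM_OP_N, m, n, k, d_input, weight, d_output);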

    /**
     * @brief Compute the matrix multiplication `C = \alpha * op(A) * op(B) + \beta * C`.
     *
     * @param transa A transpose operation of a matrix A (GEMM_OP_N or GEMM_OP_T).
     * @param transb A transpose operation of a matrix B (GEMM_OP_N or GEMM_OP_T).
     * @param m      A number of rows of a matrix op(A) and C.
     * @param n      A number of columns of a matrix op(B) and C.
     * @param k      A number of columns of op(A) and rows of op(B).
     * @param A      A device pointer of a matrix A of dimension (m x lda).
     * @param Atype  A data type of A (TYPE_FP16 or TYPE_FP32).
     * @param lda    A leading dimension of the matrix A.
     * @param B      A device pointer of a matrix B of dimension (k x ldb).
     * @param Btype  A data type of B (TYPE_FP16 or TYPE_FP32).
     * @param ldb    A leading dimension of the matrix B.
     * @param C      (Output) A device pointer of a matrix C of dimension (m x ldc).
     * @param Ctype  A data type of C (TYPE_FP16 or TYPE_FP32).
     * @param ldc    A leading dimension of the matrix C.
     * @param alpha  A scale factor for A*B (default: 1.0f).
     * @param beta   A scale factor for C (default: 0.0f).
     *
     * @throw GemmNotSupportedException if a type is not TYPE_FP16 or TYPE_FP32.
     * @throw std::runtime_error if any exception inside CUDA.
     */
    virtual void gemm(const GemmOp transa, const GemmOp transb,
                      const size_t m, const size_t n, const size_t k,
                      const void* A, const DataType Atype, const size_t lda,
                      const void* B, const DataType Btype, const size_t ldb,
                      void* C, const DataType Ctype, const size_t ldc,
                      const float alpha = 1.0f, const float beta = 0.0f);
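
    // Sketch of the fully-typed overload above with explicit leading dimensions,
    // assuming tightly packed row-major storage (so lda = k, ldb = n, ldc = n when
    // both operands use GEMM_OP_N; buffer names are hypothetical):
    //
    //   gemm.gemm(GEMM_OP_N, GEMM_OP_N, m, n, k,
    //             d_A, TYPE_FP16, /*lda=*/k,
    //             d_B, TYPE_FP16, /*ldb=*/n,
    //             d_C, TYPE_FP16, /*ldc=*/n);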

    virtual void batchedGemm(const GemmOp transa, const GemmOp transb,
                             const size_t m, const size_t n, const size_t k,
                             const void* const* A, const void* const* B, void* const* C,
                             const size_t batch_size,
                             const float alpha = 1.0f, const float beta = 0.0f);

    virtual void batchedGemm(const GemmOp transa, const GemmOp transb,
                             const size_t m, const size_t n, const size_t k,
                             const void* const* A, const size_t lda,
                             const void* const* B, const size_t ldb,
                             void* const* C, const size_t ldc,
                             const size_t batch_size,
                             const float alpha = 1.0f, const float beta = 0.0f);

    /**
     * @brief Compute the matrix multiplication of a batch of matrices A and B.
     *
     * For input batches A[i]/B[i] and output batch C[i], i = 0, ..., batch_size - 1,
     * `C[i] = \alpha * op(A[i]) * op(B[i]) + \beta * C[i]`.
     *
     * @param transa     A transpose operation of a matrix A (GEMM_OP_N or GEMM_OP_T).
     * @param transb     A transpose operation of a matrix B (GEMM_OP_N or GEMM_OP_T).
     * @param m          A number of rows of a matrix op(A) and C.
     * @param n          A number of columns of a matrix op(B) and C.
     * @param k          A number of columns of op(A) and rows of op(B).
     * @param A          An array of device pointers to the batch of input A matrices.
     * @param Atype      A data type of A (TYPE_FP16 or TYPE_FP32).
     * @param lda        A leading dimension of the matrix A.
     * @param B          An array of device pointers to the batch of input B matrices.
     * @param Btype      A data type of B (TYPE_FP16 or TYPE_FP32).
     * @param ldb        A leading dimension of the matrix B.
     * @param C          (Output) An array of device pointers to the batch of output C matrices.
     * @param Ctype      A data type of C (TYPE_FP16 or TYPE_FP32).
     * @param ldc        A leading dimension of the matrix C.
     * @param batch_size A number of matrices in the batch.
     * @param alpha      A scale factor for A*B (default: 1.0f).
     * @param beta       A scale factor for C (default: 0.0f).
     *
     * @throw GemmNotSupportedException if a type is not TYPE_FP16 or TYPE_FP32.
     * @throw std::runtime_error if any exception inside CUDA.
     */
    virtual void batchedGemm(const GemmOp transa, const GemmOp transb,
                             const size_t m, const size_t n, const size_t k,
                             const void* const* A, const DataType Atype, const size_t lda,
                             const void* const* B, const DataType Btype, const size_t ldb,
                             void* const* C, const DataType Ctype, const size_t ldc,
                             const size_t batch_size,
                             const float alpha = 1.0f, const float beta = 0.0f);
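
    // Sketch of batchedGemm above: A, B and C each hold `batch_size` device
    // pointers, one per matrix in the batch (names are hypothetical; where the
    // pointer arrays themselves must reside follows the backend cuBLAS batched API):
    //
    //   // A_array[i] -> i-th (m x k) matrix, B_array[i] -> i-th (k x n) matrix,
    //   // C_array[i] -> i-th (m x n) output matrix.
    //   gemm.batchedGemm(GEMM_OP_N, GEMM_OP_N, m, n, k,
    //                    A_array, B_array, C_array, batch_size);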

    virtual void stridedBatchedGemm(GemmOp transa, GemmOp transb,
                                    const size_t m, const size_t n, const size_t k,
                                    const void* A, const void* B, void* C,
                                    const size_t batch_size,
                                    const float alpha = 1.0f, const float beta = 0.0f);

    virtual void stridedBatchedGemm(GemmOp transa, GemmOp transb,
                                    const size_t m, const size_t n, const size_t k,
                                    const void* A, const int64_t strideA,
                                    const void* B, const int64_t strideB,
                                    void* C, const int64_t strideC,
                                    const size_t batch_size,
                                    const float alpha = 1.0f, const float beta = 0.0f);

    virtual void stridedBatchedGemm(GemmOp transa, GemmOp transb,
                                    const size_t m, const size_t n, const size_t k,
                                    const void* A, const size_t lda, const int64_t strideA,
                                    const void* B, const size_t ldb, const int64_t strideB,
                                    void* C, const size_t ldc, const int64_t strideC,
                                    const size_t batch_size,
                                    const float alpha = 1.0f, const float beta = 0.0f);

    /**
     * @brief Compute the strided matrix multiplication of a batch of matrices A and B.
     *
     * For input batches A[i]/B[i] and output batch C[i], i = 0, ..., batch_size - 1,
     * `C[i] = \alpha * op(A[i]) * op(B[i]) + \beta * C[i]`.
     *
     * @param transa       A transpose operation of a matrix A (GEMM_OP_N or GEMM_OP_T).
     * @param transb       A transpose operation of a matrix B (GEMM_OP_N or GEMM_OP_T).
     * @param m            A number of rows of a matrix op(A) and C.
     * @param n            A number of columns of a matrix op(B) and C.
     * @param k            A number of columns of op(A) and rows of op(B).
     * @param A            A device pointer of the batch of input A matrices.
     * @param Atype        A data type of A (TYPE_FP16 or TYPE_FP32).
     * @param lda          A leading dimension of the matrix A.
     * @param strideA      An offset in number of elements between matrix A[i] and A[i+1].
     * @param B            A device pointer of the batch of input B matrices.
     * @param Btype        A data type of B (TYPE_FP16 or TYPE_FP32).
     * @param ldb          A leading dimension of the matrix B.
     * @param strideB      An offset in number of elements between matrix B[i] and B[i+1].
     * @param C            (Output) A device pointer of the batch of output C matrices.
     * @param Ctype        A data type of C (TYPE_FP16 or TYPE_FP32).
     * @param ldc          A leading dimension of the matrix C.
     * @param strideC      An offset in number of elements between matrix C[i] and C[i+1].
     * @param batch_size   A number of matrices in the batch.
     * @param compute_type An accumulation type of the GEMM.
     * @param alpha        A scale factor for A*B (default: 1.0f).
     * @param beta         A scale factor for C (default: 0.0f).
     *
     * @throw GemmNotSupportedException if a type is not TYPE_FP16 or TYPE_FP32.
     * @throw std::runtime_error if any exception inside CUDA.
     */
    virtual void stridedBatchedGemm(GemmOp transa, GemmOp transb,
                                    const size_t m, const size_t n, const size_t k,
                                    const void* A, DataType Atype, const size_t lda, const int64_t strideA,
                                    const void* B, DataType Btype, const size_t ldb, const int64_t strideB,
                                    void* C, DataType Ctype, const size_t ldc, const int64_t strideC,
                                    const size_t batch_size, DataType compute_type,
                                    const float alpha = 1.0f, const float beta = 0.0f);
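
    // Sketch of stridedBatchedGemm above for tightly packed row-major batches
    // (hypothetical buffers; each A[i] is m x k, B[i] is k x n, C[i] is m x n,
    // laid out back to back, so the strides are the per-matrix element counts):
    //
    //   gemm.stridedBatchedGemm(GEMM_OP_N, GEMM_OP_N, m, n, k,
    //                           d_A, /*strideA=*/int64_t(m * k),
    //                           d_B, /*strideB=*/int64_t(k * n),
    //                           d_C, /*strideC=*/int64_t(m * n),
    //                           batch_size);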

protected:
    IAllocator*      allocator_       = nullptr;
    cudaStream_t     stream_;
    std::mutex*      mutex_           = nullptr;
    cublasAlgoMap*   cublas_algo_map_ = nullptr;
    cublasHandle_t   cublas_handle_;
    cublasLtHandle_t cublaslt_handle_;
    void*            workspace_       = nullptr;

    // Use FP32 as the default.
    DataType a_type_       = TYPE_FP32;
    DataType b_type_       = TYPE_FP32;
    DataType c_type_       = TYPE_FP32;
    DataType compute_type_ = TYPE_FP32;

    // Check if data types and inputs are valid in the Gemm class.
    virtual void checkDataTypeValidity(const DataType& type);
};

// class Int8Gemm : public Gemm {
// protected:
//     bool use_ORDER_COL32_2R_4R4_;  // what is this?
// };

#ifdef SPARSITY_ENABLED

/**
 * A Sparse Gemm class.
 *
 * NOTE:
 *   A, B, C are assumed to have a row-major layout.
 *   There are two restrictions:
 *    - Only the case where the matrix B is sparse is supported.
 *    - Only TYPE_FP16 is supported for in/out data and compute types.
 */
class SpGemm: public Gemm {

protected:
    cusparseLtHandle_t cusparselt_handle_;

    std::map<std::string, cusparseLtMatDescriptor_t> a_desc_map_;
    std::map<std::string, cusparseLtMatDescriptor_t> b_desc_map_;
    std::map<std::string, cusparseLtMatDescriptor_t> c_desc_map_;

    bool useBaseGemm(size_t batch_size, size_t m, size_t n, size_t k);

public:
    using Gemm::setComputeType;
    using Gemm::setTypes;
    using Gemm::setDefaultTypes;
    using Gemm::setAllocator;
    using Gemm::setCudaStream;
    using Gemm::gemm;
    using Gemm::batchedGemm;
    using Gemm::stridedBatchedGemm;

    /**
     * @param allocator     Resource allocator.
     * @param stream        A CUDA stream.
     * @param config_file   A file path of a GEMM configuration.
     * @param spconfig_file A file path of a sparse GEMM configuration.
     */
    // TODO: Let's unify the algo map loading part.
    SpGemm(IAllocator* allocator,
           cudaStream_t stream,
           std::string config_file = GEMM_CONFIG,
           std::string spconfig_file = SPGEMM_CONFIG);
    ~SpGemm();

    std::string toString() override;
    void loadGemmConfig(std::string config_file, std::string spconfig_file);

    // Template methods cannot be overridden.
    void gemm(const GemmOp transa, const GemmOp transb,
              const size_t m, const size_t n, const size_t k,
              const void* input, const DenseWeight<float>& weight, void* output,
              const float alpha = 1.0f, const float beta = 0.0f) override;

    void gemm(const GemmOp transa, const GemmOp transb,
              const size_t m, const size_t n, const size_t k,
              const void* input, const DenseWeight<half>& weight, void* output,
              const float alpha = 1.0f, const float beta = 0.0f) override;

    void gemm(const GemmOp transa, const GemmOp transb,
              const size_t m, const size_t n, const size_t k,
              const void* A, const DataType Atype, const size_t lda,
              const void* B, const DataType Btype, const size_t ldb,
              void* C, const DataType Ctype, const size_t ldc,
              const float alpha = 1.0f, const float beta = 0.0f) override;

private:
    void checkDataTypeValidity(const DataType& type) override;

    // Temporary gemm helper method that uses a template type T.
    template<typename T>
    void weightGemmHelper(const GemmOp transa, const GemmOp transb,
                          const size_t m, const size_t n, const size_t k,
                          const void* input, const DenseWeight<T>& weight, void* output,
                          const float alpha, const float beta);
};

// class Int8SpGemm : public Int8Gemm, public SpGemm {
// };

#endif

/* ***************************** GEMM Exceptions ******************************* */

class GemmInvalidShapeException: public std::exception {
private:
    std::string msg_ = "Invalid matrix shapes.";

public:
    explicit GemmInvalidShapeException() = default;
    template<typename... Args>
    explicit GemmInvalidShapeException(const std::string format, const Args&... args): msg_(fmtstr(format, args...))
    {
    }
    const char* what() const throw()
    {
        return msg_.c_str();
    }
};

class GemmNotSupportedException: public std::exception {
private:
    std::string msg_ = "Not supported exception.";

public:
    explicit GemmNotSupportedException() = default;
    template<typename... Args>
    explicit GemmNotSupportedException(const std::string format, const Args&... args): msg_(fmtstr(format, args...))
    {
    }
    const char* what() const throw()
    {
        return msg_.c_str();
    }
};

class GemmInvalidException: public std::exception {
private:
    std::string msg_ = "Invalid use of gemm.";

public:
    explicit GemmInvalidException() = default;
    template<typename... Args>
    explicit GemmInvalidException(const std::string format, const Args&... args): msg_(fmtstr(format, args...))
    {
    }
    const char* what() const throw()
    {
        return msg_.c_str();
    }
};

/* ************************ End of GEMM Exceptions ************************ */

/* ***************************** GEMM utils ******************************* */

/**
 * @brief A create method for the Gemm family.
 *
 * @param allocator Resource allocator.
 * @param stream    A CUDA stream.
 * @param sparse    Whether to use sparse GEMM.
 * @param quantized Whether to use int8 quantized GEMM.
 * @return A shared pointer to a Gemm instance.
 */
std::shared_ptr<Gemm>
createGemm(IAllocator* allocator, cudaStream_t stream, bool sparse = false, bool quantized = false);

cudaDataType_t getCublasDataType(DataType dtype);
// #if (CUDART_VERSION >= 11000)
// cublasComputeType_t getCublasComputeType(DataType dtype);
// #else
cudaDataType_t getCublasComputeType(DataType dtype);
// #endif
cublasOperation_t getCublasOperation(GemmOp op);
std::string getGemmOpString(const GemmOp& op);

#ifdef SPARSITY_ENABLED
cusparseOperation_t getCusparseOperation(GemmOp op);
cusparseComputeType getCusparseComputeType(DataType dtype);

/**
 * @brief Prune a weight matrix (in-place).
 *
 * SpGemm supports the case where the sparse matrix is B in C = A * B.
 *
 * @param data   A device pointer of the matrix to prune.
 * @param stream A CUDA stream object.
 * @param k      A number of rows of op(B).
 * @param n      A number of columns of op(B).
 * @param trans  A transpose operation that will be applied to the matrix
 *               (default: GEMM_OP_N).
 */
void pruneMatrixB(
    void* data, const cudaStream_t& stream, const size_t k, const size_t n, const GemmOp trans = GEMM_OP_N);

/**
 * @brief Compress the B matrix into a specific sparsity format.
 *
 * @param output    (Output) A pointer that receives an allocated device buffer storing the compressed matrix.
 * @param allocator A resource allocator.
 * @param stream    A CUDA stream object.
 * @param input     An input matrix to compress.
 * @param k         A number of rows of op(B).
 * @param n         A number of columns of op(B).
 * @param trans     A transpose operation that will be applied to the matrix (default: GEMM_OP_N).
 *
 * @return A size of the allocated device buffer of the compressed matrix.
 *
 * @throw GemmInvalidException if the input matrix does not have 2:4 sparsity,
 *        or if it fails to compute a correct buffer size to store the compressed matrix.
 * @throw std::runtime_error if any exception inside CUDA.
 */
size_t compressMatrixB(void** output,
                       IAllocator& allocator,
                       const cudaStream_t& stream,
                       const void* input,
                       const size_t k,
                       const size_t n,
                       const GemmOp trans = GEMM_OP_N);
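
// Sketch of a possible sparse workflow with the helpers above (hypothetical
// buffers; whether SpGemm consumes the compressed buffer directly as B is an
// assumption, not something this header guarantees):
//
//   pruneMatrixB(d_B, stream, k, n);  // enforce 2:4 sparsity in place
//   void*  d_B_compressed = nullptr;
//   size_t compressed_size = compressMatrixB(&d_B_compressed, allocator, stream, d_B, k, n);
//   sp_gemm.gemm(GEMM_OP_N, GEMM_OP_N, m, n, k,
//                d_A, TYPE_FP16, k,
//                d_B_compressed, TYPE_FP16, n,
//                d_C, TYPE_FP16, n);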

#endif

/* ************************* End of GEMM utils **************************** */

}  // end of namespace turbomind