/*
 * Copyright (c) 2019-2022, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/common/cudaUtils.h"
#include <array>
#include <cstdint>
#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <tuple>
#include <vector>

namespace tensorrt_llm
{
namespace common
{

//! Thin convenience wrapper around cuBLAS / cuBLASLt GEMM calls.
//!
//! Owns (via shared_ptr) the library handles, caches the operand data types
//! configured through the set*GemmConfig() helpers, and manages the cuBLASLt
//! operation / matrix-layout descriptors used by the Lt code path.
//! NOTE(review): thread-safety is not established by this header — presumably
//! one wrapper per stream/thread; confirm against the .cpp and call sites.
class CublasMMWrapper
{
protected:
    // Handles are shared so several wrappers can reuse one cuBLAS/cuBLASLt context.
    std::shared_ptr<cublasHandle_t> mCublasHandle;
    std::shared_ptr<cublasLtHandle_t> mCublasLtHandle;

    // Operand/compute/scale types last set via setGemmConfig() (or the typed helpers).
    cudaDataType_t mAType{};
    cudaDataType_t mBType{};
    cudaDataType_t mCType{};
    cublasComputeType_t mComputeType{};
    cudaDataType_t mScaleType{};

    // cuBLASLt descriptors created by createDescriptors() and released by
    // destroyDescriptors(); nullptr while no descriptors are live.
    cublasLtMatmulDesc_t mOperationDesc{nullptr};
    cublasLtMatrixLayout_t mADesc{nullptr};
    cublasLtMatrixLayout_t mBDesc{nullptr};
    cublasLtMatrixLayout_t mCDesc{nullptr};

    // Stream all GEMMs are issued on; value-initialized for consistency with
    // the other members (constructors are expected to assign it explicitly).
    cudaStream_t mStream{};

    // Scratch space handed to cuBLASLt matmuls; not owned by this wrapper.
    void* mCublasWorkspace = nullptr;

private:
    //! True when the full set of Lt descriptors (operation + A/B/C layouts) is live.
    bool descriptorsCreated() const
    {
        return mOperationDesc != nullptr && mADesc != nullptr && mBDesc != nullptr && mCDesc != nullptr;
    }

public:
    //! \param workspace scratch buffer for cuBLASLt; caller retains ownership.
    CublasMMWrapper(std::shared_ptr<cublasHandle_t> cublasHandle, std::shared_ptr<cublasLtHandle_t> cublasLtHandle,
        cudaStream_t stream, void* workspace);

    ~CublasMMWrapper();

    CublasMMWrapper(CublasMMWrapper const& wrapper);

    /********************** GEMMs **********************/
    //! C = op(A) * op(B) with alpha = 1, beta = 0 and the configured data types.
    void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A,
        int const lda, void const* B, int const ldb, void* C, int const ldc);

    //! Same as above, optionally forcing a specific cuBLASLt heuristic result.
    void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A,
        int const lda, void const* B, int const ldb, void* C, int const ldc,
        std::optional<cublasLtMatmulHeuristicResult_t> const& algo);

    //! C = f_alpha * op(A) * op(B) + f_beta * C, optionally with a chosen tactic.
    void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A,
        int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta,
        std::optional<cublasLtMatmulHeuristicResult_t> const& algo);

    void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A,
        int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta);

    //! Lowest-level overload: \p hasAlgo selects whether \p algo is honored and
    //! \p usingCublasLt picks the Lt vs classic cuBLAS path.
    void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A,
        int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta,
        cublasLtMatmulAlgo_t const& algo, bool hasAlgo, bool usingCublasLt);

    //! Batched GEMM over \p batchCount matrices separated by the given strides,
    //! using the configured data types.
    void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
        void const* A, int const lda, const int64_t strideA, void const* B, int const ldb, const int64_t strideB,
        void* C, int const ldc, const int64_t strideC, int const batchCount, float const f_alpha = 1.0f,
        float const f_beta = 0.0f);

    //! Batched GEMM with per-call operand/compute types (ignores the configured ones).
    void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
        float const f_alpha, void const* A, cudaDataType_t AType, int const lda, const int64_t strideA, void const* B,
        cudaDataType_t BType, int const ldb, const int64_t strideB, float const f_beta, void* C, cudaDataType_t CType,
        int const ldc, const int64_t strideC, int const batchCount, cudaDataType_t computeType);

    /********************** Tactic selection helpers **********************/
    //! Checks whether \p algo is usable for the given problem shape.
    bool checkTactic(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
        int const lda, int const ldb, int const ldc, cublasLtMatmulAlgo_t const& algo);

    //! Enumerates cuBLASLt heuristic candidates for the given problem shape.
    std::vector<cublasLtMatmulHeuristicResult_t> getTactics(cublasOperation_t transa, cublasOperation_t transb,
        int const m, int const n, int const k, int const lda, int const ldb, int const ldc);

    std::vector<cublasLtMatmulHeuristicResult_t> getTactics(cublasLtHandle_t lightHandle,
        cublasLtMatmulDesc_t computeDesc, cublasLtMatrixLayout_t Adesc, cublasLtMatrixLayout_t Bdesc,
        cublasLtMatrixLayout_t Cdesc, cublasLtMatrixLayout_t Ddesc);

    // (type, order, rows, cols) summary of a matrix layout, used as a cache key part.
    using MatrixLayout = std::tuple<cudaDataType_t, cublasLtOrder_t, uint64_t, uint64_t>;
    using cache_idx_t = std::tuple<cublasLtMatmulDesc_t, std::array<MatrixLayout, 4>>;

    MatrixLayout createMatrixLayout(cublasLtMatrixLayout_t Mdesc);

    /********************** Utils **********************/
    void setWorkspace(void* workspace);

    void setFP32GemmConfig();
    void setFP16GemmConfig(cudaDataType_t outputType = CUDA_R_16F);
#ifdef ENABLE_BF16
    void setBF16GemmConfig(cudaDataType_t outputType = CUDA_R_16BF);
#endif
#ifdef ENABLE_FP8
    void setFP8GemmConfig(cudaDataType_t outputType = CUDA_R_16F);
#endif

    void setStream(cudaStream_t stream);

    //! Sets the A/B/C operand types and the compute type for subsequent GEMMs.
    void setGemmConfig(cudaDataType_t aType, cudaDataType_t bType, cudaDataType_t cType, cudaDataType_t computeType);

    CublasDataType getCublasDataType(cudaDataType_t data_type);

    //! Creates the cuBLASLt operation and A/B/C layout descriptors for one problem.
    //! \p fastAcc enables the FP8 fast-accumulation mode when non-zero.
    void createDescriptors(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
        int const lda, int const ldb, int const ldc, int8_t fastAcc = 0);
    void setScaleDescriptors(void* scale_a, void* scale_b);
    void destroyDescriptors();

    cublasHandle_t getCublasHandle()
    {
        return *(this->mCublasHandle);
    }

    cublasLtHandle_t getCublasLtHandle() const
    {
        return *(this->mCublasLtHandle);
    }
};

} // namespace common

} // namespace tensorrt_llm