cublasMMWrapper.cpp 14.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
/*
 * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/common/cublasMMWrapper.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cublasVersionCheck.h"
#include <algorithm>

#ifndef CUDART_VERSION
#error CUDART_VERSION Undefined!
#endif

namespace tensorrt_llm
{
namespace common
{

CublasMMWrapper::CublasMMWrapper(std::shared_ptr<cublasHandle_t> cublasHandle,
    std::shared_ptr<cublasLtHandle_t> cublasltHandle, cudaStream_t stream, void* workspace)
    : mCublasHandle(cublasHandle)
    , mCublasLtHandle(cublasltHandle)
    , mStream(stream)
    , mCublasWorkspace(workspace)
{
}

// Nothing is released explicitly here: the destructor body is empty and only the members' own destructors run.
CublasMMWrapper::~CublasMMWrapper() = default;

CublasMMWrapper::CublasMMWrapper(CublasMMWrapper const& wrapper)
    : mCublasHandle(wrapper.mCublasHandle)
    , mCublasLtHandle(wrapper.mCublasLtHandle)
    , mStream(wrapper.mStream)
{
}

// Creates the cuBLASLt layout descriptors for A, B and C plus the matmul operation descriptor.
//
// transa / transb : operation applied to A / B; they decide how (m, k) and (k, n) map to stored rows and columns.
// m, n, k         : GEMM problem size (C is m x n).
// lda, ldb, ldc   : leading dimensions of A, B and C.
// fastAcc         : value written to CUBLASLT_MATMUL_DESC_FAST_ACCUM.
//
// The element, compute and scale types come from the members set via setGemmConfig. The descriptors created here
// are released by destroyDescriptors.
void CublasMMWrapper::createDescriptors(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n,
    int const k, int const lda, int const ldb, int const ldc, int8_t fastAcc)
{
    // Stored shape of each operand: a non-transposed A is m x k, otherwise k x m; likewise B is k x n or n x k.
    bool const aNotTransposed = transa == CUBLAS_OP_N;
    bool const bNotTransposed = transb == CUBLAS_OP_N;
    int const aRows = aNotTransposed ? m : k;
    int const aCols = aNotTransposed ? k : m;
    int const bRows = bNotTransposed ? k : n;
    int const bCols = bNotTransposed ? n : k;

    // Matrix layouts.
    check_cuda_error(cublasLtMatrixLayoutCreate(&mADesc, mAType, aRows, aCols, lda));
    check_cuda_error(cublasLtMatrixLayoutCreate(&mBDesc, mBType, bRows, bCols, ldb));
    check_cuda_error(cublasLtMatrixLayoutCreate(&mCDesc, mCType, m, n, ldc));

    // Operation descriptor and its attributes.
    check_cuda_error(cublasLtMatmulDescCreate(&mOperationDesc, mComputeType, mScaleType));
    check_cuda_error(
        cublasLtMatmulDescSetAttribute(mOperationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)));
    check_cuda_error(
        cublasLtMatmulDescSetAttribute(mOperationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb)));
    check_cuda_error(
        cublasLtMatmulDescSetAttribute(mOperationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAcc, sizeof(fastAcc)));
}

// Registers the scale pointers for A and B on the current operation descriptor.
//
// The pointer values themselves (not the data they point to) are written to the
// CUBLASLT_MATMUL_DESC_A_SCALE_POINTER / CUBLASLT_MATMUL_DESC_B_SCALE_POINTER attributes of mOperationDesc.
void CublasMMWrapper::setScaleDescriptors(void* scale_a, void* scale_b)
{
    auto const attributeSize = sizeof(void*);

    check_cuda_error(
        cublasLtMatmulDescSetAttribute(mOperationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &scale_a, attributeSize));
    check_cuda_error(
        cublasLtMatmulDescSetAttribute(mOperationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &scale_b, attributeSize));
}

// Destroys the operation descriptor and the A/B/C layout descriptors and resets the members to null.
//
// Each handle is cleared immediately after its own destroy call. If a later destroy fails and check_cuda_error
// reports it, the handles that were already released are no longer left dangling, so a subsequent call cannot
// destroy them a second time.
void CublasMMWrapper::destroyDescriptors()
{
    check_cuda_error(cublasLtMatmulDescDestroy(mOperationDesc));
    mOperationDesc = NULL;

    check_cuda_error(cublasLtMatrixLayoutDestroy(mADesc));
    mADesc = NULL;

    check_cuda_error(cublasLtMatrixLayoutDestroy(mBDesc));
    mBDesc = NULL;

    check_cuda_error(cublasLtMatrixLayoutDestroy(mCDesc));
    mCDesc = NULL;
}

// Convenience overload: C = A * B, i.e. the alpha/beta overload called with alpha = 1 and beta = 0.
void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
    void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc)
{
    constexpr float kAlpha = 1.0f;
    constexpr float kBeta = 0.0f;

    Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, kAlpha, kBeta);
}

// C = A * B (alpha = 1, beta = 0) requested through the cuBLASLt path, optionally with a pre-selected tactic.
//
// heuristic : when it holds a value, its algo is forwarded and marked usable only if the heuristic's state is
//             CUBLAS_STATUS_SUCCESS and its workspace requirement fits in CUBLAS_WORKSPACE_SIZE. When it is empty,
//             or not usable, the core overload is told that no algo is available.
void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
    void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc,
    std::optional<cublasLtMatmulHeuristicResult_t> const& heuristic)
{
    // `<=` rather than `<`: a tactic that needs exactly CUBLAS_WORKSPACE_SIZE bytes still fits. This matches the
    // bound accepted by checkTactic and the CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES limit used by getTactics.
    bool const hasAlgo = heuristic.has_value() && heuristic->state == CUBLAS_STATUS_SUCCESS
        && heuristic->workspaceSize <= CUBLAS_WORKSPACE_SIZE;

    // A value-initialised algo stands in when no heuristic was supplied; it is ignored because hasAlgo is false.
    cublasLtMatmulAlgo_t const algo = heuristic.has_value() ? heuristic->algo : cublasLtMatmulAlgo_t{};

    Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, 1.0f, 0.0f, algo, hasAlgo, /* usingCublasLt */ true);
}

// C = f_alpha * A * B + f_beta * C requested through the cuBLASLt path, optionally with a pre-selected tactic.
//
// heuristic : when it holds a value, its algo is forwarded and marked usable only if the heuristic's state is
//             CUBLAS_STATUS_SUCCESS and its workspace requirement fits in CUBLAS_WORKSPACE_SIZE. When it is empty,
//             or not usable, the core overload is told that no algo is available.
void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
    void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta,
    std::optional<cublasLtMatmulHeuristicResult_t> const& heuristic)
{
    // `<=` rather than `<`: a tactic that needs exactly CUBLAS_WORKSPACE_SIZE bytes still fits. This matches the
    // bound accepted by checkTactic and the CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES limit used by getTactics.
    bool const hasAlgo = heuristic.has_value() && heuristic->state == CUBLAS_STATUS_SUCCESS
        && heuristic->workspaceSize <= CUBLAS_WORKSPACE_SIZE;

    // A value-initialised algo stands in when no heuristic was supplied; it is ignored because hasAlgo is false.
    cublasLtMatmulAlgo_t const algo = heuristic.has_value() ? heuristic->algo : cublasLtMatmulAlgo_t{};

    Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, f_alpha, f_beta, algo, hasAlgo,
        /* usingCublasLt */ true);
}

// C = f_alpha * A * B + f_beta * C without a pre-selected tactic.
//
// The cuBLASLt path is requested only when the configured A type is fp16 or fp8 (E4M3); every other type goes
// through plain cuBLAS.
void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
    void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta)
{
    bool preferCublasLt = false;
    if (mAType == CUDA_R_16F || mAType == CUDA_R_8F_E4M3)
    {
        preferCublasLt = true;
    }

    Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, f_alpha, f_beta, {}, /* hasAlgo */ false,
        /* usingCublasLt */ preferCublasLt);
}

// Core GEMM entry point: C = f_alpha * op(A) * op(B) + f_beta * C, dispatched to cuBLASLt or cuBLAS.
//
// algo          : cuBLASLt tactic to use; only consulted when hasAlgo is true and the cuBLASLt path is taken.
// hasAlgo       : whether `algo` is meaningful. It is re-validated with checkTactic and dropped if that fails.
// usingCublasLt : request for the cuBLASLt path; honoured only when the configured A type is fp16 or fp8 (E4M3).
//
// The cuBLASLt path uses the descriptors built by createDescriptors, so they must exist before the call.
// The cuBLAS path binds mStream and the workspace to the handle and lets cuBLAS pick the algorithm.
void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
    void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta,
    cublasLtMatmulAlgo_t const& algo, bool hasAlgo, bool usingCublasLt)
{
    half h_alpha = (half) (f_alpha);
    half h_beta = (half) (f_beta);

    // TODO: default cublas libs
    usingCublasLt = usingCublasLt && (mAType == CUDA_R_16F || mAType == CUDA_R_8F_E4M3);

    // The scalars must be passed in the type matching the compute type: half for fp16 compute, float otherwise.
    bool const isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F;
    void const* alpha = isFp16ComputeType ? static_cast<void const*>(&h_alpha) : static_cast<void const*>(&f_alpha);
    void const* beta = isFp16ComputeType ? static_cast<void const*>(&h_beta) : static_cast<void const*>(&f_beta);

    // Without an attached workspace buffer no workspace bytes may be advertised.
    int workspaceSize = mCublasWorkspace == NULL ? 0 : CUBLAS_WORKSPACE_SIZE;

    if (usingCublasLt)
    {
        // Fail with a clear message instead of handing null descriptors to cublasLtMatmul.
        TLLM_CHECK_WITH_INFO(
            descriptorsCreated(), "Descriptors are not created! Call createDescriptors before calling this function");

        if (hasAlgo)
        {
            hasAlgo = checkTactic(transa, transb, m, n, k, lda, ldb, ldc, algo);
        }

        // A null algo pointer lets cuBLASLt choose the algorithm itself.
        check_cuda_error(cublasLtMatmul(getCublasLtHandle(), mOperationDesc, alpha, A, mADesc, B, mBDesc, beta, C,
            mCDesc, C, mCDesc, (hasAlgo ? (&algo) : NULL), mCublasWorkspace, workspaceSize, mStream));

        sync_check_cuda_error();
    }
    else
    {
        check_cuda_error(cublasSetStream(getCublasHandle(), mStream));
        check_cuda_error(cublasSetWorkspace(getCublasHandle(), mCublasWorkspace, workspaceSize));
        // Go with default heuristic to choose tactic as cuBLAS does not allow to choose tactics in Ampere+
        cublasGemmAlgo_t const cublasAlgo = CUBLAS_GEMM_DEFAULT;
        check_cuda_error(cublasGemmEx(getCublasHandle(), transa, transb, m, n, k, alpha, A, mAType, lda, B, mBType, ldb,
            beta, C, mCType, ldc, mComputeType, cublasAlgo));
        sync_check_cuda_error();
    }
}

// Strided batched GEMM using the element and compute types configured on this wrapper.
//
// strideA / strideB / strideC : element offset between consecutive matrices of the batch.
// batchCount                  : number of GEMMs in the batch.
// f_alpha / f_beta            : scalars; converted to half when the configured compute type is fp16.
//
// NOTE(review): unlike Gemm, this function does not call cublasSetStream with mStream, so the call runs on
// whatever stream the cuBLAS handle is currently bound to - confirm that callers bind it.
void CublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n,
    int const k, void const* A, int const lda, const int64_t strideA, void const* B, int const ldb,
    const int64_t strideB, void* C, int const ldc, const int64_t strideC, int const batchCount, float const f_alpha,
    float const f_beta)
{
    half halfAlpha = (half) f_alpha;
    half halfBeta = (half) f_beta;

    // Pick the scalar representation that matches the compute type.
    void const* alpha = nullptr;
    void const* beta = nullptr;
    if (mComputeType == CUBLAS_COMPUTE_16F)
    {
        alpha = &halfAlpha;
        beta = &halfBeta;
    }
    else
    {
        alpha = &f_alpha;
        beta = &f_beta;
    }

    cublasGemmAlgo_t const gemmAlgo = mAType == CUDA_R_32F ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP;

    check_cuda_error(cublasGemmStridedBatchedEx(getCublasHandle(), transa, transb, m, n, k, alpha, A, mAType, lda,
        strideA, B, mBType, ldb, strideB, beta, C, mCType, ldc, strideC, batchCount, mComputeType, gemmAlgo));
}

// Strided batched GEMM with element and compute types supplied by the caller instead of the wrapper's
// configured types.
//
// AType / BType / CType       : element types of A, B and C.
// computeType                 : compute type forwarded to cuBLAS; it also decides the scalar representation.
// strideA / strideB / strideC : element offset between consecutive matrices of the batch.
// batchCount                  : number of GEMMs in the batch.
// f_alpha / f_beta            : scalars; converted to half when computeType is CUDA_R_16F.
//
// NOTE(review): unlike Gemm, this function does not call cublasSetStream with mStream, so the call runs on
// whatever stream the cuBLAS handle is currently bound to - confirm that callers bind it.
void CublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n,
    int const k, float const f_alpha, void const* A, cudaDataType_t AType, int const lda, const int64_t strideA,
    void const* B, cudaDataType_t BType, int const ldb, const int64_t strideB, float const f_beta, void* C,
    cudaDataType_t CType, int const ldc, const int64_t strideC, int const batchCount, cudaDataType_t computeType)
{
    half h_alpha = (half) f_alpha;
    half h_beta = (half) f_beta;

    // The scalar type has to follow the compute type actually passed to cuBLAS for THIS call. Using the wrapper's
    // configured mComputeType here would hand cuBLAS a float where it reads a half (or vice versa) whenever the
    // two differ.
    bool const isFp16ComputeType = computeType == CUDA_R_16F;
    void const* alpha = isFp16ComputeType ? static_cast<void const*>(&h_alpha) : static_cast<void const*>(&f_alpha);
    void const* beta = isFp16ComputeType ? static_cast<void const*>(&h_beta) : static_cast<void const*>(&f_beta);

    // Likewise the algorithm hint is derived from the A type of this call, not from the configured member type.
    cublasGemmAlgo_t const gemmAlgo = AType == CUDA_R_32F ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP;

    check_cuda_error(cublasGemmStridedBatchedEx(getCublasHandle(), transa, transb, m, n, k, alpha, A, AType, lda,
        strideA, B, BType, ldb, strideB, beta, C, CType, ldc, strideC, batchCount, computeType, gemmAlgo));
}

void CublasMMWrapper::setWorkspace(void* workspace)
{
    mCublasWorkspace = workspace;
}

void CublasMMWrapper::setFP32GemmConfig()
{
    setGemmConfig(CUDA_R_32F, CUDA_R_32F, CUDA_R_32F, CUDA_R_32F);
}

// Configures fp16 inputs (A and B) with fp32 compute; the C type is chosen by the caller.
void CublasMMWrapper::setFP16GemmConfig(cudaDataType_t outputType)
{
    constexpr cudaDataType_t kInputType = CUDA_R_16F;
    constexpr cudaDataType_t kComputeType = CUDA_R_32F;
    setGemmConfig(kInputType, kInputType, outputType, kComputeType);
}

#ifdef ENABLE_BF16
// Configures bf16 inputs (A and B) with fp32 compute; the C type is chosen by the caller.
void CublasMMWrapper::setBF16GemmConfig(cudaDataType_t outputType)
{
    constexpr cudaDataType_t kInputType = CUDA_R_16BF;
    constexpr cudaDataType_t kComputeType = CUDA_R_32F;
    setGemmConfig(kInputType, kInputType, outputType, kComputeType);
}
#endif

#ifdef ENABLE_FP8
// Configures fp8 (E4M3) inputs (A and B) with fp32 compute; the C type is chosen by the caller.
void CublasMMWrapper::setFP8GemmConfig(cudaDataType_t outputType)
{
    constexpr cudaDataType_t kInputType = CUDA_R_8F_E4M3;
    constexpr cudaDataType_t kComputeType = CUDA_R_32F;
    setGemmConfig(kInputType, kInputType, outputType, kComputeType);
}
#endif

// Stores the element types of A, B and C and derives the cuBLAS compute and scale types.
//
// computeType : CUDA_R_16F selects fp16 compute with fp16 scalars; any other value selects fp32 compute with
//               fp32 scalars.
void CublasMMWrapper::setGemmConfig(
    cudaDataType_t aType, cudaDataType_t bType, cudaDataType_t cType, cudaDataType_t computeType)
{
    mAType = aType;
    mBType = bType;
    mCType = cType;

    // Everything except an explicit fp16 request computes and scales in fp32.
    if (computeType != CUDA_R_16F)
    {
        mComputeType = CUBLAS_COMPUTE_32F;
        mScaleType = CUDA_R_32F;
        return;
    }

    mComputeType = CUBLAS_COMPUTE_16F;
    mScaleType = CUDA_R_16F;
}

// Maps a CUDA data type onto the wrapper's CublasDataType enumeration.
//
// fp16, fp32 and int8 (plus bf16 when ENABLE_BF16 is defined) have dedicated values; every other input is
// reported as FLOAT_DATATYPE.
CublasDataType CublasMMWrapper::getCublasDataType(cudaDataType_t data_type)
{
    switch (data_type)
    {
    case CUDA_R_16F: return HALF_DATATYPE;
    case CUDA_R_32F: return FLOAT_DATATYPE;
    case CUDA_R_8I: return INT8_DATATYPE;
#ifdef ENABLE_BF16
    case CUDA_R_16BF: return BFLOAT16_DATATYPE;
#endif
    default: return FLOAT_DATATYPE; // fallback for unmapped types
    }
}

// Records the stream used by later GEMM calls. Only the member is updated; the cuBLAS handle itself is not
// rebound here.
void CublasMMWrapper::setStream(cudaStream_t stream)
{
    this->mStream = stream;
}

// Checks whether `algo` can be used with the currently created descriptors.
//
// Returns true only if cublasLtMatmulAlgoCheck succeeds, the reported heuristic state is CUBLAS_STATUS_SUCCESS and
// the workspace the tactic needs fits into the workspace this wrapper can actually provide (0 bytes when no
// workspace buffer is attached, CUBLAS_WORKSPACE_SIZE otherwise).
//
// The problem-size parameters are not consulted; the check is driven entirely by the descriptors, which must have
// been built with createDescriptors beforehand.
bool CublasMMWrapper::checkTactic(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n,
    int const k, int const lda, int const ldb, int const ldc, cublasLtMatmulAlgo_t const& algo)
{
    TLLM_CHECK_WITH_INFO(
        descriptorsCreated(), "Descriptors are not created! Call createDescriptors before calling this function");

    // Mirror the workspace size Gemm passes to cublasLtMatmul: without a buffer no workspace is available, so a
    // tactic that needs one must be rejected here rather than failing inside the matmul call.
    size_t const availableWorkspace
        = mCublasWorkspace == NULL ? static_cast<size_t>(0) : static_cast<size_t>(CUBLAS_WORKSPACE_SIZE);

    cublasLtMatmulHeuristicResult_t heurResult;
    cublasStatus_t const algoStatus = cublasLtMatmulAlgoCheck(
        getCublasLtHandle(), mOperationDesc, mADesc, mBDesc, mCDesc, mCDesc, &algo, &heurResult);

    // heurResult is only read when the check call itself succeeded (short-circuit evaluation).
    if (algoStatus != CUBLAS_STATUS_SUCCESS || heurResult.state != CUBLAS_STATUS_SUCCESS
        || heurResult.workspaceSize > availableWorkspace)
    {
        return false;
    }

    sync_check_cuda_error();

    return true;
}

// Returns the cuBLASLt heuristic results for the currently created descriptors.
//
// The problem-size parameters are not consulted; the query is driven by the descriptors, which must have been
// built with createDescriptors beforehand. C's layout is used for both the C and D operands.
std::vector<cublasLtMatmulHeuristicResult_t> CublasMMWrapper::getTactics(cublasOperation_t transa,
    cublasOperation_t transb, int const m, int const n, int const k, int const lda, int const ldb, int const ldc)
{
    TLLM_CHECK_WITH_INFO(
        descriptorsCreated(), "Descriptors are not created! Call createDescriptors before calling this function");

    std::vector<cublasLtMatmulHeuristicResult_t> tactics
        = getTactics(getCublasLtHandle(), mOperationDesc, mADesc, mBDesc, mCDesc, mCDesc);

    sync_check_cuda_error();

    return tactics;
}

// Queries cuBLASLt for candidate algorithms for the given operation and layouts.
//
// Up to 200 heuristic results are requested, limited to algorithms whose workspace requirement does not exceed
// CUBLAS_WORKSPACE_SIZE. The returned vector is trimmed to the number of results cuBLASLt actually produced.
// With cuBLAS <= 11.4.2 the query is not supported and the function reports an error instead.
std::vector<cublasLtMatmulHeuristicResult_t> CublasMMWrapper::getTactics(cublasLtHandle_t lightHandle,
    cublasLtMatmulDesc_t computeDesc, cublasLtMatrixLayout_t Adesc, cublasLtMatrixLayout_t Bdesc,
    cublasLtMatrixLayout_t Cdesc, cublasLtMatrixLayout_t Ddesc)
{
#if TLLM_CUBLAS_VER_LE(11, 4, 2)
    TLLM_CHECK_WITH_INFO(false, "CUBLAS version too low, must be > 11.4.2.");
    return {};
#else
    std::vector<cublasLtMatmulHeuristicResult_t> heuristics(200);
    cublasLtMatmulPreference_t preference;
    check_cuda_error(cublasLtMatmulPreferenceCreate(&preference));
    check_cuda_error(cublasLtMatmulPreferenceInit(preference));
    uint64_t workspace_size = CUBLAS_WORKSPACE_SIZE;
    check_cuda_error(cublasLtMatmulPreferenceSetAttribute(
        preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspace_size, sizeof(workspace_size)));
    // Reduction schemes the heuristic search may consider.
    // NOTE(review): the original comment said this restricts reduction algorithms for numerical stability and
    // determinism, but CUBLASLT_REDUCTION_SCHEME_MASK looks like the mask of all schemes - confirm the intent.
    uint32_t reduction_mask = CUBLASLT_REDUCTION_SCHEME_MASK;
    check_cuda_error(cublasLtMatmulPreferenceSetAttribute(
        preference, CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK, &reduction_mask, sizeof(reduction_mask)));
#if TLLM_CUBLAS_VER_LT(12, 0, 0)
    // The value is written to the epilogue-mask preference (only set for cuBLAS < 12).
    uint32_t epilogue_mask = 0;
    check_cuda_error(cublasLtMatmulPreferenceSetAttribute(
        preference, CUBLASLT_MATMUL_PREF_EPILOGUE_MASK, &epilogue_mask, sizeof(epilogue_mask)));
#endif

    int return_count = 0;
    cublasStatus_t const heuristicStatus = cublasLtMatmulAlgoGetHeuristic(lightHandle, computeDesc, Adesc, Bdesc,
        Cdesc, Ddesc, preference, static_cast<int>(heuristics.size()), heuristics.data(), &return_count);

    // The preference object is only needed for the query above. Release it before the query status is checked so
    // it is freed on the success path and also when the query itself failed.
    check_cuda_error(cublasLtMatmulPreferenceDestroy(preference));
    check_cuda_error(heuristicStatus);

    heuristics.resize(return_count);

    return heuristics;
#endif
}

} // namespace common

} // namespace tensorrt_llm