cublaslt_utils.cc 17.3 KB
Newer Older
1
2
3
4
// Copyright(c) Microsoft Corporation.
// Licensed under the MIT License.

#include "cublaslt_utils.h"
5
6
#include <algorithm> // for std::sort
#include <cassert>   // for assert
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
#include <cuda.h>
#include <cuda_fp8.h>

#if CUDA_VERSION >= 12080
int GetScaleTensorSize(int inner, int outer, cublasLtMatmulMatrixScale_t scale_mode) {
    if (scale_mode == CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F) {
        return 1;
    }
    if (scale_mode == CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0 ||
        scale_mode == CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3) {
        const auto s_vscale = scale_mode == CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0 ? 32 : 16;
        const auto s_block_cols = 32;
        const auto s_block_rows = 4;
        const auto s_block_inner = 4;
        const auto block_rows = s_block_inner * s_vscale;
        const auto block_cols = s_block_cols * s_block_rows;
        const auto round_off = [](auto x, auto granularity) {
            return granularity * ((x + (granularity - 1)) / granularity);
        };
        const auto s_rows = round_off(inner, block_rows) / s_vscale;
        const auto s_cols = round_off(outer, block_cols);
        return s_rows * s_cols;
    }
    return 0;
}
#endif
33
34
35

void cublasLtGemm::Init() {
    cublasLtHandle_t handle;
36
    CUBLAS_CHECK(cublasLtCreate(&handle));
37
38
39
40
    handle_.reset(handle);

    /* preference can be initialized without arguments */
    cublasLtMatmulPreference_t preference;
41
    CUBLAS_CHECK(cublasLtMatmulPreferenceCreate(&preference));
42
43
44
    preference_.reset(preference);
}

45
46
void cublasLtGemm::Setup(int m, int n, int k, int batch, int lda, int ldb, int ldc, int ldd, cudaDataType_t a_type,
                         cudaDataType_t b_type, cudaDataType_t c_type, cudaDataType_t d_type, cublasOperation_t transa,
47
48
49
50
                         cublasOperation_t transb, cublasLtEpilogue_t epilogue,
                         void *a_scale_inverse, /* only need to be set for fp8 */
                         void *b_scale_inverse  /* only need to be set for fp8 */
) {
51
52
53
54
55
    // Store dimensions
    m_ = m;
    n_ = n;
    k_ = k;

56
57
    cublasLtMatrixLayout_t a_desc = nullptr, b_desc = nullptr, c_desc = nullptr, d_desc = nullptr;
    // Create matrix descriptors.
58
    CUBLAS_CHECK(
59
        cublasLtMatrixLayoutCreate(&a_desc, a_type, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda));
60
    CUBLAS_CHECK(
61
        cublasLtMatrixLayoutCreate(&b_desc, b_type, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb));
62
    CUBLAS_CHECK(cublasLtMatrixLayoutCreate(&c_desc, c_type, m, n, ldc));
63
    CUBLAS_CHECK(cublasLtMatrixLayoutCreate(&d_desc, d_type, m, n, ldd));
64
65
66

    // strided batch gemm
    if (batch > 0) {
67
68
        int64_t stridea = static_cast<int64_t>(m) * k, strideb = static_cast<int64_t>(k) * n,
                stridec = static_cast<int64_t>(m) * n, strided = static_cast<int64_t>(m) * n;
69
        CUBLAS_CHECK(
70
            cublasLtMatrixLayoutSetAttribute(a_desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
71
72
73
        CUBLAS_CHECK(cublasLtMatrixLayoutSetAttribute(a_desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridea,
                                                      sizeof(stridea)));
        CUBLAS_CHECK(
74
            cublasLtMatrixLayoutSetAttribute(b_desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
75
76
77
        CUBLAS_CHECK(cublasLtMatrixLayoutSetAttribute(b_desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideb,
                                                      sizeof(strideb)));
        CUBLAS_CHECK(
78
            cublasLtMatrixLayoutSetAttribute(c_desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
79
80
81
        CUBLAS_CHECK(cublasLtMatrixLayoutSetAttribute(c_desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridec,
                                                      sizeof(stridec)));
        CUBLAS_CHECK(
82
            cublasLtMatrixLayoutSetAttribute(d_desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
83
84
        CUBLAS_CHECK(cublasLtMatrixLayoutSetAttribute(d_desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strided,
                                                      sizeof(strided)));
85
86
87
88
89
90
    }
    a_desc_.reset(a_desc);
    b_desc_.reset(b_desc);
    c_desc_.reset(c_desc);
    d_desc_.reset(d_desc);

91
92
93
94
    // Set compute type and scale type based on input types
    cublasComputeType_t gemm_compute_type;
    cudaDataType_t scale_type;
    if (a_type == CUDA_R_8F_E5M2 || b_type == CUDA_R_8F_E5M2 || a_type == CUDA_R_8F_E4M3 || b_type == CUDA_R_8F_E4M3) {
95
        gemm_compute_type = CUBLAS_COMPUTE_32F;
96
97
98
99
        scale_type = CUDA_R_32F;
    } else if (a_type == CUDA_R_16F || b_type == CUDA_R_16F || a_type == CUDA_R_16BF || b_type == CUDA_R_16BF) {
        gemm_compute_type = CUBLAS_COMPUTE_32F;
        scale_type = CUDA_R_32F;
100
101
102
103
104
#if CUDA_VERSION >= 12080
    } else if (a_type == CUDA_R_4F_E2M1 || b_type == CUDA_R_4F_E2M1) {
        gemm_compute_type = CUBLAS_COMPUTE_32F;
        scale_type = CUDA_R_32F;
#endif
105
    } else if (a_type == CUDA_R_64F || b_type == CUDA_R_64F) {
106
        gemm_compute_type = CUBLAS_COMPUTE_64F;
107
108
        scale_type = CUDA_R_64F;
    } else if (a_type == CUDA_R_8I) {
109
        gemm_compute_type = CUBLAS_COMPUTE_32I;
110
111
112
113
114
        scale_type = CUDA_R_32I;
    } else {
        gemm_compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
        scale_type = CUDA_R_32F;
    }
115
116

    cublasLtMatmulDesc_t op_desc = nullptr;
117
    CUBLAS_CHECK(cublasLtMatmulDescCreate(&op_desc, gemm_compute_type, scale_type));
118
119
120
121
    op_desc_.reset(op_desc);

    if (a_type == CUDA_R_8F_E5M2 || b_type == CUDA_R_8F_E5M2 || a_type == CUDA_R_8F_E4M3 || b_type == CUDA_R_8F_E4M3) {
        int8_t fastAccuMode = 1;
122
123
        CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAccuMode,
                                                    sizeof(fastAccuMode)));
124
125
    }

126
127
    CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(op_desc_.get(), CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)));
    CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(op_desc_.get(), CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb)));
128
129

    if (a_scale_inverse != nullptr) {
130
131
        CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(op_desc_.get(), CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                                                    &a_scale_inverse, sizeof(a_scale_inverse)));
132
133
    }
    if (b_scale_inverse != nullptr) {
134
135
        CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(op_desc_.get(), CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                                    &b_scale_inverse, sizeof(b_scale_inverse)));
136
    }
137
    CUBLAS_CHECK(
138
        cublasLtMatmulDescSetAttribute(op_desc_.get(), CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)));
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208

#if CUDA_VERSION >= 12080
    if (a_type == CUDA_R_4F_E2M1 || b_type == CUDA_R_4F_E2M1) {
        // Allocate and copy device scale values
        const auto a_scale = __nv_fp8_e4m3{1.f}, b_scale = __nv_fp8_e4m3{1.f}, d_out_scale = __nv_fp8_e4m3{1.f};
        const auto d_scale = 1.f;
        void *AscaleDev, *BscaleDev, *DscaleDev, *DOutscaleDev;

        // Set scale modes
        cublasLtMatmulMatrixScale_t AScaleMode = CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3;
        cublasLtMatmulMatrixScale_t BScaleMode = CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3;
        cublasLtMatmulMatrixScale_t DScaleMode = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F;
        cublasLtMatmulMatrixScale_t DOutScaleMode = CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3;
        CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(op_desc_.get(), CUBLASLT_MATMUL_DESC_A_SCALE_MODE, &AScaleMode,
                                                    sizeof(AScaleMode)));
        CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(op_desc_.get(), CUBLASLT_MATMUL_DESC_B_SCALE_MODE, &BScaleMode,
                                                    sizeof(BScaleMode)));
        CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(op_desc_.get(), CUBLASLT_MATMUL_DESC_D_SCALE_MODE, &DScaleMode,
                                                    sizeof(DScaleMode)));
        CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(op_desc_.get(), CUBLASLT_MATMUL_DESC_D_OUT_SCALE_MODE,
                                                    &DOutScaleMode, sizeof(DOutScaleMode)));

        const auto a_scale_size =
            GetScaleTensorSize(transa != CUBLAS_OP_N ? k : m, transa != CUBLAS_OP_N ? m : k, AScaleMode);
        const auto b_scale_size =
            GetScaleTensorSize(transb != CUBLAS_OP_N ? n : k, transb != CUBLAS_OP_N ? k : n, BScaleMode);
        const auto d_scale_size = GetScaleTensorSize(m, n, DScaleMode);
        const auto d_out_scale_size = GetScaleTensorSize(m, n, DOutScaleMode);

        if (a_scale_size > 0) {
            __nv_fp8_e4m3 *a_scale_host = new __nv_fp8_e4m3[a_scale_size];
            std::fill_n(a_scale_host, a_scale_size, a_scale);
            cudaMalloc(&AscaleDev, a_scale_size * sizeof(__nv_fp8_e4m3));
            cudaMemcpy(AscaleDev, a_scale_host, a_scale_size * sizeof(__nv_fp8_e4m3), cudaMemcpyHostToDevice);
            delete[] a_scale_host;
        }
        if (b_scale_size > 0) {
            __nv_fp8_e4m3 *b_scale_host = new __nv_fp8_e4m3[b_scale_size];
            std::fill_n(b_scale_host, b_scale_size, b_scale);
            cudaMalloc(&BscaleDev, b_scale_size * sizeof(__nv_fp8_e4m3));
            cudaMemcpy(BscaleDev, b_scale_host, b_scale_size * sizeof(__nv_fp8_e4m3), cudaMemcpyHostToDevice);
            delete[] b_scale_host;
        }
        if (d_scale_size > 0) {
            float *d_scale_host = new float[d_scale_size];
            std::fill_n(d_scale_host, d_scale_size, d_scale);
            cudaMalloc(&DscaleDev, d_scale_size * sizeof(float));
            cudaMemcpy(DscaleDev, d_scale_host, d_scale_size * sizeof(float), cudaMemcpyHostToDevice);
            delete[] d_scale_host;
        }
        if (d_out_scale_size > 0) {
            __nv_fp8_e4m3 *d_out_scale_host = new __nv_fp8_e4m3[d_out_scale_size];
            std::fill_n(d_out_scale_host, d_out_scale_size, d_out_scale);
            cudaMalloc(&DOutscaleDev, d_out_scale_size * sizeof(__nv_fp8_e4m3));
            cudaMemcpy(DOutscaleDev, d_out_scale_host, d_out_scale_size * sizeof(__nv_fp8_e4m3),
                       cudaMemcpyHostToDevice);
            delete[] d_out_scale_host;
        }

        // Use device scale pointer attributes
        CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(op_desc_.get(), CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &AscaleDev,
                                                    sizeof(void *)));
        CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(op_desc_.get(), CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &BscaleDev,
                                                    sizeof(void *)));
        CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(op_desc_.get(), CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &DscaleDev,
                                                    sizeof(void *)));
        CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(op_desc_.get(), CUBLASLT_MATMUL_DESC_D_OUT_SCALE_POINTER,
                                                    &DOutscaleDev, sizeof(void *)));
    }
#endif
209
210
211
}

size_t cublasLtGemm::GetAlgorithm(int max_algorithm_count, size_t max_workspace_size) {
212
213
    CUBLAS_CHECK(cublasLtMatmulPreferenceSetAttribute(preference_.get(), CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
                                                      &max_workspace_size, sizeof(max_workspace_size)));
214
215
    int found_algorithm_count = 0;
    std::vector<cublasLtMatmulHeuristicResult_t> results(max_algorithm_count);
216
217
218
    CUBLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(handle_.get(), op_desc_.get(), a_desc_.get(), b_desc_.get(),
                                                c_desc_.get(), d_desc_.get(), preference_.get(), max_algorithm_count,
                                                results.data(), &found_algorithm_count));
219
220
221
222
223
224
225
226
227
    if (found_algorithm_count == 0) {
        throw std::runtime_error("Unable to find any suitable algorithms");
    }

    results.resize(found_algorithm_count);
    heuristic_results_ = std::move(results);
    return heuristic_results_.front().workspaceSize;
}

228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
size_t cublasLtGemm::GetAlgorithmExhaustive(int max_algorithm_count, size_t max_workspace_size, float alpha, float beta,
                                            void *matrix_a, void *matrix_b, void *matrix_c, void *matrix_d,
                                            int repeat_iterations, int warmup_iterations) {
    // Set workspace size in preference
    CUBLAS_CHECK(cublasLtMatmulPreferenceSetAttribute(preference_.get(), CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
                                                      &max_workspace_size, sizeof(max_workspace_size)));

    // Get heuristic algorithms
    int found_algorithm_count = 0;
    std::vector<cublasLtMatmulHeuristicResult_t> results(max_algorithm_count);
    CUBLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(handle_.get(), op_desc_.get(), a_desc_.get(), b_desc_.get(),
                                                c_desc_.get(), d_desc_.get(), preference_.get(), max_algorithm_count,
                                                results.data(), &found_algorithm_count));
    if (found_algorithm_count == 0) {
        throw std::runtime_error("Unable to find any suitable algorithms");
    }

    results.resize(found_algorithm_count);
    heuristic_results_ = std::move(results);

    // Create stream and events for timing
    cudaStream_t stream;
    cudaEvent_t startEvent, stopEvent;
    cudaStreamCreate(&stream);
    cudaEventCreate(&startEvent);
    cudaEventCreate(&stopEvent);

    // Test each algorithm multiple times to find the best one
    std::vector<float> algoTimes(repeat_iterations);

    // Allocate workspace
    void *workspace = nullptr;
    cudaMalloc(&workspace, max_workspace_size);

    // Test each algorithm
    algo_metrics_.clear();
    algo_metrics_.reserve(found_algorithm_count);

    for (int algoIdx = 0; algoIdx < found_algorithm_count; algoIdx++) {
        // Skip algorithms that require more workspace than available
        if (heuristic_results_[algoIdx].workspaceSize > max_workspace_size) {
            continue;
        }

        // warmup
        for (int warmupIdx = 0; warmupIdx < warmup_iterations; warmupIdx++) {
            cublasStatus_t status =
                cublasLtMatmul(handle_.get(), op_desc_.get(), &alpha, matrix_a, a_desc_.get(), matrix_b, b_desc_.get(),
                               &beta, matrix_c, c_desc_.get(), matrix_d, d_desc_.get(),
                               &heuristic_results_[algoIdx].algo, workspace, max_workspace_size, stream);
        }

        // Test each algorithm multiple times
        cudaEventRecord(startEvent, stream);
        for (int checkIdx = 0; checkIdx < repeat_iterations; checkIdx++) {
            cublasStatus_t status =
                cublasLtMatmul(handle_.get(), op_desc_.get(), &alpha, matrix_a, a_desc_.get(), matrix_b, b_desc_.get(),
                               &beta, matrix_c, c_desc_.get(), matrix_d, d_desc_.get(),
                               &heuristic_results_[algoIdx].algo, workspace, max_workspace_size, stream);

            // Skip if algorithm fails
            if (status != CUBLAS_STATUS_SUCCESS) {
                algoTimes[checkIdx] = std::numeric_limits<float>::max();
                continue;
            }
        }

        cudaEventRecord(stopEvent, stream);
        cudaEventSynchronize(stopEvent);

        float time = 0;
        cudaEventElapsedTime(&time, startEvent, stopEvent);
        algoTimes[algoIdx] = time / repeat_iterations;

        float meanTime = algoTimes[algoIdx];
        float flops = 2.0f * m_ * n_ * k_ / (meanTime * 1e-3f);

        // Store metrics
        AlgorithmMetrics metrics;
        metrics.algo = heuristic_results_[algoIdx].algo;
        metrics.workspace_size = heuristic_results_[algoIdx].workspaceSize;
        metrics.time = meanTime;
        metrics.flops = flops;
        algo_metrics_.push_back(metrics);
    }

    std::sort(algo_metrics_.begin(), algo_metrics_.end(),
              [](const AlgorithmMetrics &a, const AlgorithmMetrics &b) { return a.time < b.time; });

    if (!algo_metrics_.empty())
        heuristic_results_[0].algo = algo_metrics_.front().algo;

    // Clean up resources
    cudaFree(workspace);
    cudaEventDestroy(startEvent);
    cudaEventDestroy(stopEvent);
    cudaStreamDestroy(stream);

    if (!algo_metrics_.empty()) {
        return algo_metrics_.front().workspace_size;
    }

    throw std::runtime_error("No valid algorithms found during autotune");
}

333
334
void cublasLtGemm::Execute(void *matrix_a, void *matrix_b, void *matrix_c, void *matrix_d, float alpha, float beta,
                           void *workspace, size_t workspace_size, cudaStream_t stream) {
335

336
337
338
339
340
341
    CUBLAS_CHECK(cublasLtMatmul(handle_.get(), op_desc_.get(), static_cast<const void *>(&alpha), /* alpha */
                                matrix_a,                                                         /* A */
                                a_desc_.get(), matrix_b,                                          /* B */
                                b_desc_.get(), static_cast<const void *>(&beta),                  /* beta */
                                matrix_c,                                                         /* C */
                                c_desc_.get(), matrix_d,                                          /* D */
342
                                d_desc_.get(), &heuristic_results_.front().algo, workspace,       /* workspace */
343
                                workspace_size, stream));                                         /* stream */
344
}