// Copyright(c) Microsoft Corporation.
// Licensed under the MIT License.

#include "cublaslt_utils.h"

#include <algorithm> // for std::sort
#include <cassert>   // for assert
#include <cstdint>   // for int64_t
#include <limits>    // for std::numeric_limits
#include <stdexcept> // for std::runtime_error
#include <vector>    // for std::vector

void cublasLtGemm::Init() {
    cublasLtHandle_t handle;
10
    CUBLAS_CHECK(cublasLtCreate(&handle));
11
12
13
14
    handle_.reset(handle);

    /* preference can be initialized without arguments */
    cublasLtMatmulPreference_t preference;
15
    CUBLAS_CHECK(cublasLtMatmulPreferenceCreate(&preference));
16
17
18
19
20
21
22
23
24
    preference_.reset(preference);
}

void cublasLtGemm::Setup(int m, int n, int k, int batch, int lda, int ldb, int ldd, cudaDataType_t a_type,
                         cudaDataType_t b_type, cudaDataType_t d_type, cublasOperation_t transa,
                         cublasOperation_t transb, cublasLtEpilogue_t epilogue,
                         void *a_scale_inverse, /* only need to be set for fp8 */
                         void *b_scale_inverse  /* only need to be set for fp8 */
) {
25
26
27
28
29
    // Store dimensions
    m_ = m;
    n_ = n;
    k_ = k;

30
31
    cublasLtMatrixLayout_t a_desc = nullptr, b_desc = nullptr, c_desc = nullptr, d_desc = nullptr;
    // force c_type
32
    cudaDataType_t c_type = d_type;
33
    // Create matrix descriptors.
34
    CUBLAS_CHECK(
35
        cublasLtMatrixLayoutCreate(&a_desc, a_type, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda));
36
    CUBLAS_CHECK(
37
        cublasLtMatrixLayoutCreate(&b_desc, b_type, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb));
38
39
    CUBLAS_CHECK(cublasLtMatrixLayoutCreate(&c_desc, c_type, m, n, ldd));
    CUBLAS_CHECK(cublasLtMatrixLayoutCreate(&d_desc, d_type, m, n, ldd));
40
41
42

    // strided batch gemm
    if (batch > 0) {
43
44
        int64_t stridea = static_cast<int64_t>(m) * k, strideb = static_cast<int64_t>(k) * n,
                stridec = static_cast<int64_t>(m) * n, strided = static_cast<int64_t>(m) * n;
45
        CUBLAS_CHECK(
46
            cublasLtMatrixLayoutSetAttribute(a_desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
47
48
49
        CUBLAS_CHECK(cublasLtMatrixLayoutSetAttribute(a_desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridea,
                                                      sizeof(stridea)));
        CUBLAS_CHECK(
50
            cublasLtMatrixLayoutSetAttribute(b_desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
51
52
53
        CUBLAS_CHECK(cublasLtMatrixLayoutSetAttribute(b_desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideb,
                                                      sizeof(strideb)));
        CUBLAS_CHECK(
54
            cublasLtMatrixLayoutSetAttribute(c_desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
55
56
57
        CUBLAS_CHECK(cublasLtMatrixLayoutSetAttribute(c_desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridec,
                                                      sizeof(stridec)));
        CUBLAS_CHECK(
58
            cublasLtMatrixLayoutSetAttribute(d_desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
59
60
        CUBLAS_CHECK(cublasLtMatrixLayoutSetAttribute(d_desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strided,
                                                      sizeof(strided)));
61
62
63
64
65
66
    }
    a_desc_.reset(a_desc);
    b_desc_.reset(b_desc);
    c_desc_.reset(c_desc);
    d_desc_.reset(d_desc);

67
68
69
70
71
    // Set compute type and scale type based on input types
    cublasComputeType_t gemm_compute_type;
    cudaDataType_t scale_type;

    if (a_type == CUDA_R_8F_E5M2 || b_type == CUDA_R_8F_E5M2 || a_type == CUDA_R_8F_E4M3 || b_type == CUDA_R_8F_E4M3) {
72
        gemm_compute_type = CUBLAS_COMPUTE_32F;
73
74
75
76
77
        scale_type = CUDA_R_32F;
    } else if (a_type == CUDA_R_16F || b_type == CUDA_R_16F || a_type == CUDA_R_16BF || b_type == CUDA_R_16BF) {
        gemm_compute_type = CUBLAS_COMPUTE_32F;
        scale_type = CUDA_R_32F;
    } else if (a_type == CUDA_R_64F || b_type == CUDA_R_64F) {
78
        gemm_compute_type = CUBLAS_COMPUTE_64F;
79
80
        scale_type = CUDA_R_64F;
    } else if (a_type == CUDA_R_8I) {
81
        gemm_compute_type = CUBLAS_COMPUTE_32I;
82
83
84
85
86
        scale_type = CUDA_R_32I;
    } else {
        gemm_compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
        scale_type = CUDA_R_32F;
    }
87
88

    cublasLtMatmulDesc_t op_desc = nullptr;
89
    CUBLAS_CHECK(cublasLtMatmulDescCreate(&op_desc, gemm_compute_type, scale_type));
90
91
92
93
    op_desc_.reset(op_desc);

    if (a_type == CUDA_R_8F_E5M2 || b_type == CUDA_R_8F_E5M2 || a_type == CUDA_R_8F_E4M3 || b_type == CUDA_R_8F_E4M3) {
        int8_t fastAccuMode = 1;
94
95
        CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAccuMode,
                                                    sizeof(fastAccuMode)));
96
97
    }

98
99
    CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(op_desc_.get(), CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)));
    CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(op_desc_.get(), CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb)));
100
101

    if (a_scale_inverse != nullptr) {
102
103
        CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(op_desc_.get(), CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                                                    &a_scale_inverse, sizeof(a_scale_inverse)));
104
105
    }
    if (b_scale_inverse != nullptr) {
106
107
        CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(op_desc_.get(), CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                                    &b_scale_inverse, sizeof(b_scale_inverse)));
108
    }
109
    CUBLAS_CHECK(
110
111
112
113
        cublasLtMatmulDescSetAttribute(op_desc_.get(), CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)));
}

size_t cublasLtGemm::GetAlgorithm(int max_algorithm_count, size_t max_workspace_size) {
114
115
    CUBLAS_CHECK(cublasLtMatmulPreferenceSetAttribute(preference_.get(), CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
                                                      &max_workspace_size, sizeof(max_workspace_size)));
116
117
    int found_algorithm_count = 0;
    std::vector<cublasLtMatmulHeuristicResult_t> results(max_algorithm_count);
118
119
120
    CUBLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(handle_.get(), op_desc_.get(), a_desc_.get(), b_desc_.get(),
                                                c_desc_.get(), d_desc_.get(), preference_.get(), max_algorithm_count,
                                                results.data(), &found_algorithm_count));
121
122
123
124
125
126
127
128
129
    if (found_algorithm_count == 0) {
        throw std::runtime_error("Unable to find any suitable algorithms");
    }

    results.resize(found_algorithm_count);
    heuristic_results_ = std::move(results);
    return heuristic_results_.front().workspaceSize;
}

/**
 * @brief Autotune: time every heuristic candidate on the GPU and keep the fastest.
 *
 * Runs each candidate algorithm warmup_iterations times untimed, then
 * repeat_iterations times between CUDA events, ranks candidates by mean
 * elapsed time, and promotes the winner to heuristic_results_[0] so that
 * Execute() uses it. Also fills algo_metrics_ (sorted fastest-first).
 *
 * Fixes over the previous revision: the per-iteration time vector was sized
 * by repeat_iterations but indexed by the algorithm index (out-of-bounds when
 * more algorithms than iterations were found), and algorithms whose matmul
 * failed were still timed and ranked. Failing algorithms are now skipped.
 *
 * @param max_algorithm_count upper bound on candidates to request.
 * @param max_workspace_size  workspace budget in bytes (also allocated here).
 * @param alpha, beta         GEMM scalars used for the timing runs.
 * @param matrix_a..matrix_d  device buffers matching the layouts from Setup().
 * @param repeat_iterations   timed launches per algorithm.
 * @param warmup_iterations   untimed launches per algorithm.
 * @return workspace size in bytes of the winning algorithm.
 * @throws std::runtime_error if no algorithm is found or none runs successfully.
 */
size_t cublasLtGemm::GetAlgorithmExhaustive(int max_algorithm_count, size_t max_workspace_size, float alpha, float beta,
                                            void *matrix_a, void *matrix_b, void *matrix_c, void *matrix_d,
                                            int repeat_iterations, int warmup_iterations) {
    // Set workspace size in preference
    CUBLAS_CHECK(cublasLtMatmulPreferenceSetAttribute(preference_.get(), CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
                                                      &max_workspace_size, sizeof(max_workspace_size)));

    // Get heuristic algorithms
    int found_algorithm_count = 0;
    std::vector<cublasLtMatmulHeuristicResult_t> results(max_algorithm_count);
    CUBLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(handle_.get(), op_desc_.get(), a_desc_.get(), b_desc_.get(),
                                                c_desc_.get(), d_desc_.get(), preference_.get(), max_algorithm_count,
                                                results.data(), &found_algorithm_count));
    if (found_algorithm_count == 0) {
        throw std::runtime_error("Unable to find any suitable algorithms");
    }

    results.resize(found_algorithm_count);
    heuristic_results_ = std::move(results);

    // Create a dedicated stream and events for timing.
    cudaStream_t stream;
    cudaEvent_t startEvent, stopEvent;
    cudaStreamCreate(&stream);
    cudaEventCreate(&startEvent);
    cudaEventCreate(&stopEvent);

    // Allocate the (shared) workspace once for all candidates.
    void *workspace = nullptr;
    cudaMalloc(&workspace, max_workspace_size);

    // Benchmark each candidate algorithm.
    algo_metrics_.clear();
    algo_metrics_.reserve(found_algorithm_count);

    for (int algoIdx = 0; algoIdx < found_algorithm_count; algoIdx++) {
        // Skip algorithms that require more workspace than available
        if (heuristic_results_[algoIdx].workspaceSize > max_workspace_size) {
            continue;
        }

        // Warmup (untimed); a failure here means the algorithm cannot run
        // with this configuration, so skip it entirely.
        bool algo_failed = false;
        for (int warmupIdx = 0; warmupIdx < warmup_iterations; warmupIdx++) {
            cublasStatus_t status =
                cublasLtMatmul(handle_.get(), op_desc_.get(), &alpha, matrix_a, a_desc_.get(), matrix_b, b_desc_.get(),
                               &beta, matrix_c, c_desc_.get(), matrix_d, d_desc_.get(),
                               &heuristic_results_[algoIdx].algo, workspace, max_workspace_size, stream);
            if (status != CUBLAS_STATUS_SUCCESS) {
                algo_failed = true;
                break;
            }
        }
        if (algo_failed) {
            continue;
        }

        // Timed runs.
        cudaEventRecord(startEvent, stream);
        for (int checkIdx = 0; checkIdx < repeat_iterations; checkIdx++) {
            cublasStatus_t status =
                cublasLtMatmul(handle_.get(), op_desc_.get(), &alpha, matrix_a, a_desc_.get(), matrix_b, b_desc_.get(),
                               &beta, matrix_c, c_desc_.get(), matrix_d, d_desc_.get(),
                               &heuristic_results_[algoIdx].algo, workspace, max_workspace_size, stream);
            if (status != CUBLAS_STATUS_SUCCESS) {
                algo_failed = true;
                break;
            }
        }
        cudaEventRecord(stopEvent, stream);
        cudaEventSynchronize(stopEvent);
        if (algo_failed) {
            continue; // do not rank algorithms that failed mid-run
        }

        float time = 0;
        cudaEventElapsedTime(&time, startEvent, stopEvent);
        float meanTime = time / repeat_iterations; // ms per launch
        // 2*m*n*k flops per GEMM; meanTime is in ms, hence the 1e-3 factor.
        float flops = 2.0f * m_ * n_ * k_ / (meanTime * 1e-3f);

        // Store metrics
        AlgorithmMetrics metrics;
        metrics.algo = heuristic_results_[algoIdx].algo;
        metrics.workspace_size = heuristic_results_[algoIdx].workspaceSize;
        metrics.time = meanTime;
        metrics.flops = flops;
        algo_metrics_.push_back(metrics);
    }

    // Fastest first.
    std::sort(algo_metrics_.begin(), algo_metrics_.end(),
              [](const AlgorithmMetrics &a, const AlgorithmMetrics &b) { return a.time < b.time; });

    // Promote the winner so Execute() (which uses heuristic_results_.front()) picks it up.
    if (!algo_metrics_.empty())
        heuristic_results_[0].algo = algo_metrics_.front().algo;

    // Clean up resources
    cudaFree(workspace);
    cudaEventDestroy(startEvent);
    cudaEventDestroy(stopEvent);
    cudaStreamDestroy(stream);

    if (!algo_metrics_.empty()) {
        return algo_metrics_.front().workspace_size;
    }

    throw std::runtime_error("No valid algorithms found during autotune");
}

void cublasLtGemm::Execute(void *matrix_a, void *matrix_b, void *matrix_c, void *matrix_d, float alpha, float beta,
                           void *workspace, size_t workspace_size, cudaStream_t stream) {
237

238
239
240
241
242
243
    CUBLAS_CHECK(cublasLtMatmul(handle_.get(), op_desc_.get(), static_cast<const void *>(&alpha), /* alpha */
                                matrix_a,                                                         /* A */
                                a_desc_.get(), matrix_b,                                          /* B */
                                b_desc_.get(), static_cast<const void *>(&beta),                  /* beta */
                                matrix_c,                                                         /* C */
                                c_desc_.get(), matrix_d,                                          /* D */
244
                                d_desc_.get(), &heuristic_results_.front().algo, workspace,       /* workspace */
245
                                workspace_size, stream));                                         /* stream */
246
}