// cublaslt_gemm.cu — cuBLASLt GEMM micro-benchmark
// Copyright(c) Microsoft Corporation.
// Licensed under the MIT License.

#include <getopt.h>
#include <stdio.h>

#include <algorithm>
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <string>
#include <type_traits>

#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>

#if CUDA_VERSION >= 12080
#include <cuda_fp4.h>
using fp4e2m1 = __nv_fp4_e2m1;
#endif

#include "cublaslt_utils.h"

// Short element-type names used by the benchmark's type dispatch.
using fp64 = double;
using fp32 = float;
using fp16 = half;
using bf16 = nv_bfloat16;
using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2;
using int8 = int8_t;

// Command-line options for the cuBLASLt GEMM benchmark.
struct Args {
    int m = 16;      // rows of the output (and of A^T)
    int n = 16;      // columns of the output (and of B)
    int k = 16;      // reduction dimension
    int batch = 0;   // batch count; 0 or 1 means a single (non-batched) GEMM
    int warmup = 20; // untimed warmup iterations
    int iter = 50;   // timed iterations
    // Default warmup iterations for autotune
    int warmup_autotune = 20;
    // Default repeat iterations for autotune
    int iter_autotune = 50;
    // Input element type name; see the dispatch in main() for accepted values.
    std::string in_type = "fp8e4m3";
    // When true, exhaustively search cuBLASLt algorithms before timing.
    bool autotune = false;
};

// Parse command-line flags into *args.
// Short options: -m/-n/-k (sizes), -b (batch), -w (warmup), -i (iter),
// -t (input type), -a (autotune), -I/-W (autotune iter/warmup).
// Values are not range-checked; std::stoi throws on non-numeric input.
void process_args(int argc, char **argv, Args *args) {
    const char *const short_opts = "m:n:k:b:w:i:t:aI:W:";
    const option long_opts[] = {
        {"batch", required_argument, nullptr, 'b'},
        {"warmup", required_argument, nullptr, 'w'},
        {"iter", required_argument, nullptr, 'i'},
        {"in_type", required_argument, nullptr, 't'},
        {"autotune", no_argument, nullptr, 'a'},
        {"iter-autotune", required_argument, nullptr, 'I'},
        {"warmup-autotune", required_argument, nullptr, 'W'},
        // getopt_long requires the longopts array to end with an all-zero
        // element; without it the scan runs past the array (undefined behavior).
        {nullptr, 0, nullptr, 0},
    };

    int opt = 0;
    while ((opt = getopt_long(argc, argv, short_opts, long_opts, nullptr)) != -1) {
        switch (opt) {
        case 'm':
            args->m = std::stoi(optarg);
            break;
        case 'n':
            args->n = std::stoi(optarg);
            break;
        case 'k':
            args->k = std::stoi(optarg);
            break;
        case 'b':
            args->batch = std::stoi(optarg);
            break;
        case 'w':
            args->warmup = std::stoi(optarg);
            break;
        case 'i':
            args->iter = std::stoi(optarg);
            break;
        case 't':
            args->in_type = std::string(optarg);
            break;
        case 'a':
            args->autotune = true;
            break;
        case 'I':
            args->iter_autotune = std::stoi(optarg);
            break;
        case 'W':
            args->warmup_autotune = std::stoi(optarg);
            break;
        }
    }
}

// Grid-stride fill kernel: sets all N elements of `matrix` to T(val).
// Correct for any launch configuration; N may exceed the grid size.
template <typename T> __global__ void init_matrix(T *matrix, const fp32 val, const size_t N) {
    size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
    // Compute the stride in size_t to avoid unsigned-int overflow on very
    // large grids (gridDim.x * blockDim.x is otherwise a 32-bit product).
    size_t stride = static_cast<size_t>(gridDim.x) * blockDim.x;
    for (size_t i = tid; i < N; i += stride) {
        matrix[i] = T(val);
    }
}

// Map a compile-time element type to the corresponding cudaDataType_t enum.
// Throws std::invalid_argument for any type not listed below.
template <typename T> cudaDataType_t get_datatype() {
    if (std::is_same<T, fp64>::value)
        return CUDA_R_64F;
    if (std::is_same<T, fp32>::value)
        return CUDA_R_32F;
    if (std::is_same<T, fp16>::value)
        return CUDA_R_16F;
    if (std::is_same<T, bf16>::value)
        return CUDA_R_16BF;
    if (std::is_same<T, fp8e4m3>::value)
        return CUDA_R_8F_E4M3;
    if (std::is_same<T, fp8e5m2>::value)
        return CUDA_R_8F_E5M2;
#if CUDA_VERSION >= 12080
    // FP4 support requires CUDA 12.8+ headers.
    if (std::is_same<T, fp4e2m1>::value)
        return CUDA_R_4F_E2M1;
#endif
    if (std::is_same<T, int8>::value)
        return CUDA_R_8I;
    throw std::invalid_argument("Unknown type");
}

// Benchmark a TN GEMM (A transposed, B not) of size m x n x k with cuBLASLt,
// optionally batched and optionally autotuned, and return the average time
// per iteration in microseconds.
// NOTE(review): timing does not check CUDA/cuBLASLt return codes; a failed
// launch would silently report a near-zero time.
template <typename Ta, typename Tb, typename Tout, typename Tc>
float timing_matmul_tn(size_t m, size_t n, size_t k, size_t batch, int warmup, int iter, bool autotune,
                       int iter_autotune, int warmup_autotune) {
    // init matrix
    Ta *matrix_a = nullptr;
    Tb *matrix_b = nullptr;
    Tc *matrix_c = nullptr;
    Tout *matrix_out = nullptr;
    batch = std::max<size_t>(batch, 1);
    cudaMalloc(&matrix_a, m * k * batch * sizeof(Ta));
    cudaMalloc(&matrix_b, k * n * batch * sizeof(Tb));
    cudaMalloc(&matrix_c, m * n * batch * sizeof(Tc));
    cudaMalloc(&matrix_out, m * n * batch * sizeof(Tout));

    init_matrix<Ta><<<216, 1024>>>(matrix_a, 1.f, m * k * batch);
    init_matrix<Tb><<<216, 1024>>>(matrix_b, 2.f, k * n * batch);
    // NOTE(review): matrix_c is initialized but the Execute/autotune calls
    // below pass matrix_out for both C and D — confirm this is intended.
    init_matrix<Tc><<<216, 1024>>>(matrix_c, 3.f, m * n * batch);

    // init gemm: column-major layout; A is transposed, hence lda == k.
    size_t lda = k, ldb = k, ldc = m, ldd = m;
    std::unique_ptr<cublasLtGemm> gemm = std::make_unique<cublasLtGemm>();
    gemm->Init();
    gemm->Setup(m, n, k, batch, lda, ldb, ldc, ldd, get_datatype<Ta>(), get_datatype<Tb>(), get_datatype<Tc>(),
                get_datatype<Tout>(), CUBLAS_OP_T, CUBLAS_OP_N, CUBLASLT_EPILOGUE_DEFAULT);

    void *workspace = nullptr;
    size_t workspace_size;

    if (autotune) {
        // Exhaustive search (up to 8 candidate algorithms, 2*m*n workspace
        // budget) timed with the dedicated autotune warmup/iter counts.
        workspace_size = gemm->GetAlgorithmExhaustive(
            8, 2 * m * n, 1.0f, 0.0f, reinterpret_cast<void *>(matrix_a), reinterpret_cast<void *>(matrix_b),
            reinterpret_cast<void *>(matrix_out), reinterpret_cast<void *>(matrix_out), iter_autotune, warmup_autotune);
    } else {
        workspace_size = gemm->GetAlgorithm(1, 2 * m * n);
    }

    cudaMalloc(&workspace, workspace_size);

    // timer
    float time;
    cudaEvent_t startTime, endTime;
    cudaEventCreate(&startTime);
    cudaEventCreate(&endTime);

    for (int i = 0; i < warmup; i++)
        gemm->Execute(reinterpret_cast<void *>(matrix_a), reinterpret_cast<void *>(matrix_b),
                      reinterpret_cast<void *>(matrix_out), reinterpret_cast<void *>(matrix_out), 1.f, 0.f, workspace,
                      workspace_size, 0);
    cudaEventRecord(startTime, 0);
    for (int i = 0; i < iter; i++)
        gemm->Execute(reinterpret_cast<void *>(matrix_a), reinterpret_cast<void *>(matrix_b),
                      reinterpret_cast<void *>(matrix_out), reinterpret_cast<void *>(matrix_out), 1.f, 0.f, workspace,
                      workspace_size, 0);
    cudaEventRecord(endTime, 0);
    cudaEventSynchronize(endTime);
    cudaEventElapsedTime(&time, startTime, endTime);

    // deallocate (events and matrix_c were previously leaked)
    cudaEventDestroy(startTime);
    cudaEventDestroy(endTime);
    cudaFree(workspace);
    cudaFree(matrix_a);
    cudaFree(matrix_b);
    cudaFree(matrix_c);
    cudaFree(matrix_out);
    // Elapsed time is in milliseconds over `iter` runs -> microseconds per run.
    return (time * 1e3 / iter);
}

// Run the benchmark for one type combination and print one tab-separated
// result row: m n k batch time_us tflops.
// Ta/Tb are the input types, Tout the output type, Tc the C-matrix type.
template <typename Ta, typename Tb = Ta, typename Tout = Ta, typename Tc = Tout> void run(const Args *args) {
    float time_us = timing_matmul_tn<Ta, Tb, Tout, Tc>(args->m, args->n, args->k, args->batch, args->warmup, args->iter,
                                                       args->autotune, args->iter_autotune, args->warmup_autotune);
    // m n k batch time_us tflops
    // TFLOPS = m*n*(2k-1) flops per GEMM, times batch, over time in us.
    printf("%d\t%d\t%d\t%d\t%f\t%f\n", args->m, args->n, args->k, args->batch, time_us,
           float(args->m) * float(args->n) * float(2 * args->k - 1) / 1e6 / time_us * std::max(args->batch, 1));
}

// Entry point: parse flags, then dispatch on the requested input type.
// FP8 runs accumulate into fp16 output; FP4 (CUDA 12.8+) uses an fp16 C matrix.
int main(int argc, char **argv) {
    Args args;
    process_args(argc, argv, &args);

    if (args.in_type == "fp64")
        run<fp64>(&args);
    else if (args.in_type == "fp32")
        run<fp32>(&args);
    else if (args.in_type == "fp16")
        run<fp16>(&args);
    else if (args.in_type == "bf16")
        run<bf16>(&args);
    else if (args.in_type == "fp8e4m3")
        run<fp8e4m3, fp8e4m3, fp16>(&args);
    else if (args.in_type == "fp8e5m2")
        // NOTE(review): B is fp8e4m3 while A is fp8e5m2 — cuBLASLt does not
        // support e5m2 x e5m2; confirm this mixed pairing is intended.
        run<fp8e5m2, fp8e4m3, fp16>(&args);
#if CUDA_VERSION >= 12080
    else if (args.in_type == "fp4e2m1")
        run<fp4e2m1, fp4e2m1, fp4e2m1, fp16>(&args);
#endif
    else if (args.in_type == "int8")
        run<int8>(&args);
    else
        throw std::invalid_argument("Unknown type " + args.in_type);

    return 0;
}