// Copyright(c) Microsoft Corporation.
// Licensed under the MIT License.

#include <getopt.h>
#include <stdio.h>

#include <algorithm>
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <string>
#include <type_traits>

#include <cuda_fp16.h>
#include <cuda_fp8.h>

#include "cublaslt_utils.h"

// Short aliases for the element types supported by the benchmark.
using fp64 = double;
using fp32 = float;
using fp16 = half;
using bf16 = nv_bfloat16;
using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2;
using int8 = int8_t;

// Command-line options for the GEMM benchmark, with their defaults.
struct Args {
    int m = 16; // GEMM dimension M
    int n = 16; // GEMM dimension N
    int k = 16; // GEMM dimension K
    // Batch count; 0 means a single non-batched GEMM (clamped to 1 at run time).
    int batch = 0;
    // Untimed warmup iterations run before measurement starts.
    int warmup = 20;
    // Timed iterations; the reported time is the average over these.
    int iter = 50;

    // Default warmup iterations for autotune
    int warmup_autotune = 20;
    // Default repeat iterations for autotune
    int iter_autotune = 50;

    // Input element type: fp64|fp32|fp16|bf16|fp8e4m3|fp8e5m2|int8.
    std::string in_type = "fp8e4m3";

    // Use exhaustive on-device algorithm search instead of the heuristic.
    bool autotune = false;
};

/**
 * Parse command-line options into `args`.
 *
 * Short options: -m/-n/-k (dims), -b (batch), -w (warmup), -i (iter),
 * -t (in_type), -a (autotune), -I (iter-autotune), -W (warmup-autotune).
 * Unrecognized options are silently ignored; values not given keep the
 * defaults from `Args`.
 */
void process_args(int argc, char **argv, Args *args) {
    const char *const short_opts = "m:n:k:b:w:i:t:aI:W:";

    const option long_opts[] = {
        {"batch", required_argument, nullptr, 'b'},
        {"warmup", required_argument, nullptr, 'w'},
        {"iter", required_argument, nullptr, 'i'},
        {"in_type", required_argument, nullptr, 't'},
        {"autotune", no_argument, nullptr, 'a'},
        {"iter-autotune", required_argument, nullptr, 'I'},
        {"warmup-autotune", required_argument, nullptr, 'W'},
        // getopt_long requires the array to end with a zero-filled entry.
        {nullptr, 0, nullptr, 0},
    };

    int opt = 0;
    while ((opt = getopt_long(argc, argv, short_opts, long_opts, nullptr)) != -1) {
        switch (opt) {
        case 'm':
            args->m = std::stoi(optarg);
            break;
        case 'n':
            args->n = std::stoi(optarg);
            break;
        case 'k':
            args->k = std::stoi(optarg);
            break;
        case 'b':
            args->batch = std::stoi(optarg);
            break;
        case 'w':
            args->warmup = std::stoi(optarg);
            break;
        case 'i':
            args->iter = std::stoi(optarg);
            break;
        case 't':
            args->in_type = std::string(optarg);
            break;
        case 'a':
            args->autotune = true;
            break;
        case 'I':
            args->iter_autotune = std::stoi(optarg);
            break;
        case 'W':
            args->warmup_autotune = std::stoi(optarg);
            break;
        }
    }
}

// Fill `matrix` with `val` (converted to T) using a grid-stride loop, so any
// launch configuration covers all N elements regardless of grid size.
template <typename T> __global__ void init_matrix(T *matrix, const fp32 val, const size_t N) {
    size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
    // Cast the stride to size_t so the unsigned-int product cannot overflow
    // for very large grids.
    for (size_t i = tid; i < N; i += static_cast<size_t>(gridDim.x) * blockDim.x) {
        matrix[i] = T(val);
    }
}

/**
 * Map a host element type to its cudaDataType_t enum value.
 *
 * @throws std::invalid_argument if T is not one of the supported aliases.
 */
template <typename T> cudaDataType_t get_datatype() {
    if (std::is_same<T, fp64>::value)
        return CUDA_R_64F;
    if (std::is_same<T, fp32>::value)
        return CUDA_R_32F;
    if (std::is_same<T, fp16>::value)
        return CUDA_R_16F;
    if (std::is_same<T, bf16>::value)
        return CUDA_R_16BF;
    if (std::is_same<T, fp8e4m3>::value)
        return CUDA_R_8F_E4M3;
    if (std::is_same<T, fp8e5m2>::value)
        return CUDA_R_8F_E5M2;
    if (std::is_same<T, int8>::value)
        return CUDA_R_8I;
    throw std::invalid_argument("Unknown type");
}

/**
 * Time a (batched) TN GEMM via cuBLASLt and return the average time per
 * iteration in microseconds.
 *
 * Allocates device matrices, fills them with constants, selects an algorithm
 * (heuristic, or exhaustive on-device autotune when `autotune` is set), runs
 * `warmup` untimed executions, then times `iter` executions with CUDA events.
 * All work is issued on the default stream, so event order matches kernel
 * order.
 */
template <typename Ta, typename Tb, typename Tout>
float timing_matmul_tn(size_t m, size_t n, size_t k, size_t batch, int warmup, int iter, bool autotune,
                       int iter_autotune, int warmup_autotune) {
    // init matrix (batch of 0 means a single non-batched GEMM)
    Ta *matrix_a = nullptr;
    Tb *matrix_b = nullptr;
    Tout *matrix_out = nullptr;
    batch = std::max<size_t>(batch, 1);
    cudaMalloc(&matrix_a, m * k * batch * sizeof(Ta));
    cudaMalloc(&matrix_b, k * n * batch * sizeof(Tb));
    cudaMalloc(&matrix_out, m * n * batch * sizeof(Tout));

    init_matrix<Ta><<<216, 1024>>>(matrix_a, 1.f, m * k * batch);
    init_matrix<Tb><<<216, 1024>>>(matrix_b, 2.f, k * n * batch);

    // init gemm: A is transposed (TN); leading dims follow from that layout.
    size_t lda = k, ldb = k, ldd = m;
    std::unique_ptr<cublasLtGemm> gemm = std::make_unique<cublasLtGemm>();
    gemm->Init();
    gemm->Setup(m, n, k, batch, lda, ldb, ldd, get_datatype<Ta>(), get_datatype<Tb>(), get_datatype<Tout>(),
                CUBLAS_OP_T, CUBLAS_OP_N, CUBLASLT_EPILOGUE_DEFAULT);

    // Algorithm selection; the chosen algorithm reports its workspace need.
    void *workspace = nullptr;
    size_t workspace_size;

    if (autotune) {
        workspace_size = gemm->GetAlgorithmExhaustive(
            8, 2 * m * n, 1.0f, 0.0f, reinterpret_cast<void *>(matrix_a), reinterpret_cast<void *>(matrix_b),
            reinterpret_cast<void *>(matrix_out), reinterpret_cast<void *>(matrix_out), iter_autotune, warmup_autotune);
    } else {
        workspace_size = gemm->GetAlgorithm(1, 2 * m * n);
    }

    cudaMalloc(&workspace, workspace_size);

    // timer
    float time;
    cudaEvent_t startTime, endTime;
    cudaEventCreate(&startTime);
    cudaEventCreate(&endTime);

    for (int i = 0; i < warmup; i++)
        gemm->Execute(reinterpret_cast<void *>(matrix_a), reinterpret_cast<void *>(matrix_b),
                      reinterpret_cast<void *>(matrix_out), reinterpret_cast<void *>(matrix_out), 1.f, 0.f, workspace,
                      workspace_size, 0);
    cudaEventRecord(startTime, 0);
    for (int i = 0; i < iter; i++)
        gemm->Execute(reinterpret_cast<void *>(matrix_a), reinterpret_cast<void *>(matrix_b),
                      reinterpret_cast<void *>(matrix_out), reinterpret_cast<void *>(matrix_out), 1.f, 0.f, workspace,
                      workspace_size, 0);
    cudaEventRecord(endTime, 0);
    cudaEventSynchronize(endTime);
    cudaEventElapsedTime(&time, startTime, endTime);

    // deallocate (events were previously leaked; destroy them too)
    cudaEventDestroy(startTime);
    cudaEventDestroy(endTime);
    cudaFree(workspace);
    cudaFree(matrix_a);
    cudaFree(matrix_b);
    cudaFree(matrix_out);
    // `time` is milliseconds for all `iter` runs; convert to us per iteration.
    return (time * 1e3 / iter);
}

/**
 * Run the benchmark for one type combination and print a tab-separated
 * result line: m, n, k, batch, average time (us), throughput (TFLOPS).
 */
template <typename Ta, typename Tb = Ta, typename Tout = Ta> void run(const Args *args) {
    float time_us = timing_matmul_tn<Ta, Tb, Tout>(args->m, args->n, args->k, args->batch, args->warmup, args->iter,
                                                   args->autotune, args->iter_autotune, args->warmup_autotune);
    // Each output element costs k multiplies + (k-1) adds -> m*n*(2k-1) flops
    // per GEMM, times batch; flops / (time_us * 1e6) yields TFLOPS.
    // m n k batch time_us tflops
    printf("%d\t%d\t%d\t%d\t%f\t%f\n", args->m, args->n, args->k, args->batch, time_us,
           float(args->m) * float(args->n) * float(2 * args->k - 1) / 1e6 / time_us * std::max(args->batch, 1));
}

/**
 * Entry point: parse command-line options, then dispatch the benchmark on the
 * requested input element type. fp8 inputs produce fp16 outputs; all other
 * types use the same type for inputs and output.
 *
 * @throws std::invalid_argument for an unrecognized --in_type value.
 */
int main(int argc, char **argv) {
    Args args;
    process_args(argc, argv, &args);

    if (args.in_type == "fp64")
        run<fp64>(&args);
    else if (args.in_type == "fp32")
        run<fp32>(&args);
    else if (args.in_type == "fp16")
        run<fp16>(&args);
    else if (args.in_type == "bf16")
        run<bf16>(&args);
    else if (args.in_type == "fp8e4m3")
        run<fp8e4m3, fp8e4m3, fp16>(&args);
    else if (args.in_type == "fp8e5m2")
        // NOTE(review): e5m2 A is paired with an e4m3 B operand here —
        // presumably to satisfy cuBLASLt's fp8 operand rules; confirm.
        run<fp8e5m2, fp8e4m3, fp16>(&args);
    else if (args.in_type == "int8")
        run<int8>(&args);
    else
        throw std::invalid_argument("Unknown type " + args.in_type);

    return 0;
}