// Copyright(c) Microsoft Corporation.
// Licensed under the MIT License.

#include <getopt.h>
#include <stdio.h>

#include <algorithm>
#include <memory>
#include <stdexcept>
#include <string>

#include <cuda_fp16.h>
#include <cuda_fp8.h>

#include "cublaslt_utils.h"

13
14
15
16
using fp64 = double;
using fp32 = float;
using fp16 = half;
using bf16 = nv_bfloat16;
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2;

struct Args {
    int m = 16;
    int n = 16;
    int k = 16;
    int batch = 0;
    int warmup = 20;
    int iter = 50;
    std::string in_type = "fp8e4m3";
};

void process_args(int argc, char **argv, Args *args) {
    const char *const short_opts = "m:n:k:b:w:i:t:";
    const option long_opts[] = {
        {"batch", required_argument, nullptr, 'b'},
        {"warmup", required_argument, nullptr, 'w'},
        {"iter", required_argument, nullptr, 'i'},
        {"in_type", required_argument, nullptr, 't'},
    };

    int opt = 0;
    while ((opt = getopt_long(argc, argv, short_opts, long_opts, nullptr)) != -1) {
        switch (opt) {
        case 'm':
            args->m = std::stoi(optarg);
            break;
        case 'n':
            args->n = std::stoi(optarg);
            break;
        case 'k':
            args->k = std::stoi(optarg);
            break;
        case 'b':
            args->batch = std::stoi(optarg);
            break;
        case 'w':
            args->warmup = std::stoi(optarg);
            break;
        case 'i':
            args->iter = std::stoi(optarg);
            break;
        case 't':
            args->in_type = std::string(optarg);
            break;
        }
    }
}

67
template <typename T> __global__ void init_matrix(T *matrix, const fp32 val, const size_t N) {
68
69
70
71
72
73
74
    size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
    for (size_t i = tid; i < N; i += gridDim.x * blockDim.x) {
        matrix[i] = T(val);
    }
}

template <typename T> cudaDataType_t get_datatype() {
75
76
77
78
    if (std::is_same<T, fp64>::value)
        return CUDA_R_64F;
    if (std::is_same<T, fp32>::value)
        return CUDA_R_32F;
79
80
    if (std::is_same<T, fp16>::value)
        return CUDA_R_16F;
81
82
    if (std::is_same<T, bf16>::value)
        return CUDA_R_16BF;
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
    if (std::is_same<T, fp8e4m3>::value)
        return CUDA_R_8F_E4M3;
    if (std::is_same<T, fp8e5m2>::value)
        return CUDA_R_8F_E5M2;
    throw std::invalid_argument("Unknown type");
}

template <typename Ta, typename Tb, typename Tout>
float timing_matmul_tn(int m, int n, int k, int batch, int warmup, int iter) {
    // init matrix
    Ta *matrix_a = nullptr;
    Tb *matrix_b = nullptr;
    Tout *matrix_out = nullptr;
    cudaMalloc(&matrix_a, m * k * std::max(batch, 1) * sizeof(Ta));
    cudaMalloc(&matrix_b, k * n * std::max(batch, 1) * sizeof(Tb));
    cudaMalloc(&matrix_out, m * n * std::max(batch, 1) * sizeof(Tout));

100
101
    init_matrix<Ta><<<216, 1024>>>(matrix_a, 1.f, m * k * std::max(batch, 1));
    init_matrix<Tb><<<216, 1024>>>(matrix_b, 2.f, k * n * std::max(batch, 1));
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140

    // init gemm
    int lda = k, ldb = k, ldd = m;
    std::unique_ptr<cublasLtGemm> gemm = std::make_unique<cublasLtGemm>();
    gemm->Init();
    gemm->Setup(m, n, k, batch, lda, ldb, ldd, get_datatype<Ta>(), get_datatype<Tb>(), get_datatype<Tout>(),
                CUBLAS_OP_T, CUBLAS_OP_N, CUBLASLT_EPILOGUE_DEFAULT);

    void *workspace = nullptr;
    size_t workspace_size = gemm->GetAlgorithm(1, 2 * m * n);
    cudaMalloc(&workspace, workspace_size);

    // timer
    float time;
    cudaEvent_t startTime, endTime;
    cudaEventCreate(&startTime);
    cudaEventCreate(&endTime);

    for (int i = 0; i < warmup; i++)
        gemm->Execute(reinterpret_cast<void *>(matrix_a), reinterpret_cast<void *>(matrix_b),
                      reinterpret_cast<void *>(matrix_out), reinterpret_cast<void *>(matrix_out), 1.f, 0.f, workspace,
                      workspace_size, 0);
    cudaEventRecord(startTime, 0);
    for (int i = 0; i < iter; i++)
        gemm->Execute(reinterpret_cast<void *>(matrix_a), reinterpret_cast<void *>(matrix_b),
                      reinterpret_cast<void *>(matrix_out), reinterpret_cast<void *>(matrix_out), 1.f, 0.f, workspace,
                      workspace_size, 0);
    cudaEventRecord(endTime, 0);
    cudaEventSynchronize(endTime);
    cudaEventElapsedTime(&time, startTime, endTime);

    // deallocate
    cudaFree(workspace);
    cudaFree(matrix_a);
    cudaFree(matrix_b);
    cudaFree(matrix_out);
    return (time * 1e3 / iter);
}

141
template <typename Ta, typename Tb = Ta, typename Tout = Ta> void run(Args *args) {
142
143
144
145
146
147
148
149
150
151
    float time_us = timing_matmul_tn<Ta, Tb, Tout>(args->m, args->n, args->k, args->batch, args->warmup, args->iter);
    // m n k batch time_us tflops
    printf("%d\t%d\t%d\t%d\t%f\t%f\n", args->m, args->n, args->k, args->batch, time_us,
           float(args->m) * float(args->n) * float(2 * args->k - 1) / 1e6 / time_us * std::max(args->batch, 1));
}

int main(int argc, char **argv) {
    Args args;
    process_args(argc, argv, &args);

152
153
154
155
156
    if (args.in_type == "fp64")
        run<fp64>(&args);
    else if (args.in_type == "fp32")
        run<fp32>(&args);
    else if (args.in_type == "fp16")
157
        run<fp16>(&args);
158
159
    else if (args.in_type == "bf16")
        run<bf16>(&args);
160
    else if (args.in_type == "fp8e4m3")
161
        run<fp8e4m3, fp8e4m3, fp16>(&args);
162
    else if (args.in_type == "fp8e5m2")
163
        run<fp8e5m2, fp8e4m3, fp16>(&args);
164
165
166
167
168
    else
        throw std::invalid_argument("Unknown type " + args.in_type);

    return 0;
}