// Copyright(c) Microsoft Corporation.
// Licensed under the MIT License.

#include <getopt.h>
#include <memory>
#include <stdio.h>

#include <cuda_fp16.h>
#include <cuda_fp8.h>

#include "cublaslt_utils.h"

// Short aliases for the floating-point element types exercised by the benchmark.
using fp64 = double;
using fp32 = float;
using fp16 = half;
using bf16 = nv_bfloat16;
using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2;

// Benchmark configuration, filled from the command line by process_args().
struct Args {
    int m = 16;     // GEMM M dimension
    int n = 16;     // GEMM N dimension
    int k = 16;     // GEMM K dimension
    int batch = 0;  // batch count; 0 selects a single (non-batched) GEMM
    int warmup = 20; // untimed warmup iterations before measurement
    int iter = 50;   // timed iterations averaged into the reported latency
    std::string in_type = "fp8e4m3"; // input element type name (see main() dispatch)
};

/**
 * @brief Parse command-line options into @p args.
 *
 * Options (all take a value): -m/-n/-k (GEMM dims), -b/--batch,
 * -w/--warmup, -i/--iter, -t/--in_type.
 * Numeric values are parsed with std::stoi, which throws
 * std::invalid_argument on non-numeric input. Unrecognized options
 * are ignored.
 */
void process_args(int argc, char **argv, Args *args) {
    const char *const short_opts = "m:n:k:b:w:i:t:";
    const option long_opts[] = {
        {"batch", required_argument, nullptr, 'b'},
        {"warmup", required_argument, nullptr, 'w'},
        {"iter", required_argument, nullptr, 'i'},
        {"in_type", required_argument, nullptr, 't'},
        // getopt_long requires a zero-filled sentinel to terminate the array;
        // without it the parser scans past the end of long_opts (undefined behavior).
        {nullptr, 0, nullptr, 0},
    };

    int opt = 0;
    while ((opt = getopt_long(argc, argv, short_opts, long_opts, nullptr)) != -1) {
        switch (opt) {
        case 'm':
            args->m = std::stoi(optarg);
            break;
        case 'n':
            args->n = std::stoi(optarg);
            break;
        case 'k':
            args->k = std::stoi(optarg);
            break;
        case 'b':
            args->batch = std::stoi(optarg);
            break;
        case 'w':
            args->warmup = std::stoi(optarg);
            break;
        case 'i':
            args->iter = std::stoi(optarg);
            break;
        case 't':
            args->in_type = std::string(optarg);
            break;
        default:
            // Unknown option ('?'): keep defaults rather than touching args.
            break;
        }
    }
}

/**
 * @brief Fill a device buffer with a constant value using a grid-stride loop.
 *
 * Works for any launch configuration (grid need not cover N).
 *
 * @param matrix device pointer to N elements of type T
 * @param val    fill value, converted element-wise to T
 * @param N      number of elements to write
 */
template <typename T> __global__ void init_matrix(T *matrix, const fp32 val, const size_t N) {
    // Widen to size_t BEFORE multiplying: blockIdx.x * blockDim.x and
    // gridDim.x * blockDim.x are 32-bit products that can silently overflow
    // for large launches, even though the results are stored in size_t.
    size_t tid = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    const size_t stride = static_cast<size_t>(gridDim.x) * blockDim.x;
    for (size_t i = tid; i < N; i += stride) {
        matrix[i] = T(val);
    }
}

// Map a compile-time element type onto the matching cudaDataType_t
// enumerator. At most one branch can match per instantiation, so the
// branch order is irrelevant; unsupported types are rejected at runtime.
template <typename T> cudaDataType_t get_datatype() {
    if (std::is_same<T, fp8e4m3>::value)
        return CUDA_R_8F_E4M3;
    if (std::is_same<T, fp8e5m2>::value)
        return CUDA_R_8F_E5M2;
    if (std::is_same<T, fp16>::value)
        return CUDA_R_16F;
    if (std::is_same<T, bf16>::value)
        return CUDA_R_16BF;
    if (std::is_same<T, fp32>::value)
        return CUDA_R_32F;
    if (std::is_same<T, fp64>::value)
        return CUDA_R_64F;
    throw std::invalid_argument("Unknown type");
}

/**
 * @brief Time a (possibly batched) TN GEMM through cuBLASLt.
 *
 * Computes Out[m x n] = A^T[m x k] * B[k x n] per batch entry, with A and B
 * filled with constants on device. Runs @p warmup untimed executions, then
 * times @p iter executions with CUDA events on the default stream.
 *
 * @return average time per iteration, in microseconds.
 */
template <typename Ta, typename Tb, typename Tout>
float timing_matmul_tn(size_t m, size_t n, size_t k, size_t batch, int warmup, int iter) {
    // init matrix
    Ta *matrix_a = nullptr;
    Tb *matrix_b = nullptr;
    Tout *matrix_out = nullptr;
    batch = std::max<size_t>(batch, 1); // batch == 0 means a single GEMM
    cudaMalloc(&matrix_a, m * k * batch * sizeof(Ta));
    cudaMalloc(&matrix_b, k * n * batch * sizeof(Tb));
    cudaMalloc(&matrix_out, m * n * batch * sizeof(Tout));

    init_matrix<Ta><<<216, 1024>>>(matrix_a, 1.f, m * k * batch);
    init_matrix<Tb><<<216, 1024>>>(matrix_b, 2.f, k * n * batch);

    // init gemm: A is stored k x m and transposed, B is k x n, output is m x n.
    size_t lda = k, ldb = k, ldd = m;
    std::unique_ptr<cublasLtGemm> gemm = std::make_unique<cublasLtGemm>();
    gemm->Init();
    gemm->Setup(m, n, k, batch, lda, ldb, ldd, get_datatype<Ta>(), get_datatype<Tb>(), get_datatype<Tout>(),
                CUBLAS_OP_T, CUBLAS_OP_N, CUBLASLT_EPILOGUE_DEFAULT);

    void *workspace = nullptr;
    size_t workspace_size = gemm->GetAlgorithm(1, 2 * m * n);
    cudaMalloc(&workspace, workspace_size);

    // timer
    float time;
    cudaEvent_t startTime, endTime;
    cudaEventCreate(&startTime);
    cudaEventCreate(&endTime);

    for (int i = 0; i < warmup; i++)
        gemm->Execute(reinterpret_cast<void *>(matrix_a), reinterpret_cast<void *>(matrix_b),
                      reinterpret_cast<void *>(matrix_out), reinterpret_cast<void *>(matrix_out), 1.f, 0.f, workspace,
                      workspace_size, 0);
    cudaEventRecord(startTime, 0);
    for (int i = 0; i < iter; i++)
        gemm->Execute(reinterpret_cast<void *>(matrix_a), reinterpret_cast<void *>(matrix_b),
                      reinterpret_cast<void *>(matrix_out), reinterpret_cast<void *>(matrix_out), 1.f, 0.f, workspace,
                      workspace_size, 0);
    cudaEventRecord(endTime, 0);
    cudaEventSynchronize(endTime);
    cudaEventElapsedTime(&time, startTime, endTime);

    // deallocate (fix: the two CUDA events were previously leaked)
    cudaEventDestroy(startTime);
    cudaEventDestroy(endTime);
    cudaFree(workspace);
    cudaFree(matrix_a);
    cudaFree(matrix_b);
    cudaFree(matrix_out);
    // `time` is total milliseconds over `iter` runs; return average microseconds.
    return (time * 1e3 / iter);
}

// Benchmark a single GEMM configuration and print one tab-separated row:
// m, n, k, batch, average time in microseconds, achieved TFLOPS.
template <typename Ta, typename Tb = Ta, typename Tout = Ta> void run(Args *args) {
    const float elapsed_us =
        timing_matmul_tn<Ta, Tb, Tout>(args->m, args->n, args->k, args->batch, args->warmup, args->iter);
    // FLOPs per GEMM = m * n * (2k - 1); scale by batch, convert µs -> TFLOPS.
    const double tflops =
        float(args->m) * float(args->n) * float(2 * args->k - 1) / 1e6 / elapsed_us * std::max(args->batch, 1);
    // m n k batch time_us tflops
    printf("%d\t%d\t%d\t%d\t%f\t%f\n", args->m, args->n, args->k, args->batch, elapsed_us, tflops);
}

/**
 * @brief Entry point: parse options, dispatch on the requested input element
 *        type, run the benchmark once, and print a TSV result row.
 *
 * NOTE(review): for fp8 inputs the B operand is always e4m3 — cuBLASLt pairs
 * e5m2 only with e4m3 (e5m2 x e5m2 is not a supported combination; confirm
 * against the cublasLtMatmul docs for the target CUDA version).
 *
 * @return 0 on success, 1 on an unrecognized --in_type.
 */
int main(int argc, char **argv) {
    Args args;
    process_args(argc, argv, &args);

    if (args.in_type == "fp64")
        run<fp64>(&args);
    else if (args.in_type == "fp32")
        run<fp32>(&args);
    else if (args.in_type == "fp16")
        run<fp16>(&args);
    else if (args.in_type == "bf16")
        run<bf16>(&args);
    else if (args.in_type == "fp8e4m3")
        run<fp8e4m3, fp8e4m3, fp16>(&args);
    else if (args.in_type == "fp8e5m2")
        run<fp8e5m2, fp8e4m3, fp16>(&args);
    else {
        // Fail with a clean diagnostic and nonzero exit code instead of
        // letting an uncaught exception escape main() into std::terminate.
        fprintf(stderr, "Unknown type %s\n", args.in_type.c_str());
        return 1;
    }

    return 0;
}