profile_gemm_add_add_fastgelu.cpp 5.78 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
2
3
4
5
6
#include <cstdio>
#include <cstdlib>
#include <initializer_list>
#include <iostream>
#include <numeric>
#include <stdlib.h>
#include <string>

Chao Liu's avatar
rename  
Chao Liu committed
7
#include "profile_gemm_add_add_fastgelu_impl.hpp"
Chao Liu's avatar
Chao Liu committed
8

Chao Liu's avatar
rename  
Chao Liu committed
9
int profile_gemm_add_add_fastgelu(int argc, char* argv[])
Chao Liu's avatar
Chao Liu committed
10
{
Chao Liu's avatar
Chao Liu committed
11
    enum struct MatrixLayout
Chao Liu's avatar
Chao Liu committed
12
    {
Chao Liu's avatar
Chao Liu committed
13
14
15
16
17
18
19
20
        MK_KN_MN_MN_MN, // 0
        MK_NK_MN_MN_MN, // 1
        KM_KN_MN_MN_MN, // 2
        KM_NK_MN_MN_MN, // 3
        MK_KN_NM_MN_MN, // 4
        MK_NK_NM_MN_MN, // 5
        KM_KN_NM_MN_MN, // 6
        KM_NK_NM_MN_MN, // 7
Chao Liu's avatar
Chao Liu committed
21
22
    };

Chao Liu's avatar
Chao Liu committed
23
    enum struct MatrixDataType
Chao Liu's avatar
Chao Liu committed
24
    {
Chao Liu's avatar
Chao Liu committed
25
26
27
28
        F32_F32_F32_F32_F32,         // 0
        F16_F16_F16_F16_F16_F16_F16, // 1
        BF16_BF16_BF16_BF16_BF16,    // 2
        INT8_INT8_INT8_INT8_INT8,    // 3
Chao Liu's avatar
Chao Liu committed
29
30
    };

Chao Liu's avatar
rename  
Chao Liu committed
31
    if(argc != 16)
Chao Liu's avatar
Chao Liu committed
32
    {
Chao Liu's avatar
rename  
Chao Liu committed
33
34
        // clang-format off
        printf("arg1: tensor operation (gemm_gelu: GEMM+Add+Add+GeLU)\n");
Chao Liu's avatar
Chao Liu committed
35
        printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n");
Chao Liu's avatar
rename  
Chao Liu committed
36
37
38
39
        printf("arg3: matrix layout (0: E[m, n] = FastGeLU(A[m, k] * B[k, n] + D0[m, n] + D1[m, n]);\n");
        printf("                     1: E[m, n] = FastGeLU(A[m, k] * B[n, k] + D0[m, n] + D1[m, n]);\n");
        printf("                     2: E[m, n] = FastGeLU(A[k, m] * B[k, n] + D0[m, n] + D1[m, n]);\n");
        printf("                     3: E[m, n] = FastGeLU(A[k, m] * B[n, k] + D0[m, n] + D1[m, n]))\n");
Chao Liu's avatar
Chao Liu committed
40
41
42
43
        printf("arg4: verification (0: no; 1: yes)\n");
        printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
        printf("arg6: print tensor value (0: no; 1: yes)\n");
        printf("arg7: time kernel (0=n0, 1=yes)\n");
Chao Liu's avatar
Chao Liu committed
44
        printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideE\n");
Chao Liu's avatar
rename  
Chao Liu committed
45
        // clang-format on
Chao Liu's avatar
Chao Liu committed
46
47
48
        exit(1);
    }

Chao Liu's avatar
Chao Liu committed
49
50
    const auto data_type       = static_cast<MatrixDataType>(std::stoi(argv[2]));
    const auto layout          = static_cast<MatrixLayout>(std::stoi(argv[3]));
Chao Liu's avatar
Chao Liu committed
51
52
53
54
55
56
57
58
59
    const bool do_verification = std::stoi(argv[4]);
    const int init_method      = std::stoi(argv[5]);
    const bool do_log          = std::stoi(argv[6]);
    const bool time_kernel     = std::stoi(argv[7]);

    const int M = std::stoi(argv[8]);
    const int N = std::stoi(argv[9]);
    const int K = std::stoi(argv[10]);

Chao Liu's avatar
rename  
Chao Liu committed
60
61
    const int StrideA  = std::stoi(argv[11]);
    const int StrideB  = std::stoi(argv[12]);
Chao Liu's avatar
Chao Liu committed
62
63
64
    const int StrideD0 = std::stoi(argv[13]);
    const int StrideD1 = std::stoi(argv[14]);
    const int StrideE  = std::stoi(argv[15]);
Chao Liu's avatar
Chao Liu committed
65
66
67
68
69
70

    using F16 = ck::half_t;

    using Row = ck::tensor_layout::gemm::RowMajor;
    using Col = ck::tensor_layout::gemm::ColumnMajor;

Chao Liu's avatar
Chao Liu committed
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
    auto profile = [&](auto a_type,
                       auto b_type,
                       auto d0_type,
                       auto d1_type,
                       auto e_type,
                       auto a_layout,
                       auto b_layout,
                       auto d0_layout,
                       auto d1_layout,
                       auto e_layout) {
        using ADataType  = decltype(a_type);
        using BDataType  = decltype(b_type);
        using D0DataType = decltype(d0_type);
        using D1DataType = decltype(d1_type);
        using EDataType  = decltype(e_type);
Chao Liu's avatar
Chao Liu committed
86

Chao Liu's avatar
Chao Liu committed
87
88
89
90
91
        using ALayout  = decltype(a_layout);
        using BLayout  = decltype(b_layout);
        using D0Layout = decltype(d0_layout);
        using D1Layout = decltype(d1_layout);
        using ELayout  = decltype(e_layout);
Chao Liu's avatar
Chao Liu committed
92

Chao Liu's avatar
Chao Liu committed
93
94
95
96
97
        const int DefaultStrideA  = ck::is_same_v<ALayout, Row> ? K : M;
        const int DefaultStrideB  = ck::is_same_v<BLayout, Row> ? N : K;
        const int DefaultStrideD0 = ck::is_same_v<D0Layout, Row> ? N : M;
        const int DefaultStrideD1 = ck::is_same_v<D1Layout, Row> ? N : M;
        const int DefaultStrideE  = ck::is_same_v<ELayout, Row> ? N : M;
Chao Liu's avatar
Chao Liu committed
98

Chao Liu's avatar
Chao Liu committed
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
        return ck::profiler::profile_gemm_add_add_gelu_impl<ADataType,
                                                            BDataType,
                                                            D0DataType,
                                                            D1DataType,
                                                            EDataType,
                                                            ALayout,
                                                            BLayout,
                                                            D0Layout,
                                                            D1Layout,
                                                            ELayout>(
            do_verification,
            init_method,
            do_log,
            time_kernel,
            M,
            N,
            K,
            (StrideA < 0) ? DefaultStrideA : StrideA,
            (StrideB < 0) ? DefaultStrideB : StrideB,
            (StrideD0 < 0) ? DefaultStrideD0 : StrideD0,
            (StrideD1 < 0) ? DefaultStrideD1 : StrideD1,
            (StrideE < 0) ? DefaultStrideE : StrideE);
    };

    if(data_type == MatrixDataType::F16_F16_F16_F16_F16 && layout == MatrixLayout::MK_KN_MN_MN_MN)
Chao Liu's avatar
Chao Liu committed
124
    {
Chao Liu's avatar
Chao Liu committed
125
        return profile(F16{}, F16{}, F16{}, F16{}, F16{}, Row{}, Row{}, Row{}, Row{}, Row{});
Chao Liu's avatar
Chao Liu committed
126
    }
Chao Liu's avatar
Chao Liu committed
127
128
    else if(data_type == MatrixDataType::F16_F16_F16_F16_F16 &&
            layout == MatrixLayout::MK_NK_MN_MN_MN)
Chao Liu's avatar
Chao Liu committed
129
    {
Chao Liu's avatar
Chao Liu committed
130
        return profile(F16{}, F16{}, F16{}, F16{}, F16{}, Row{}, Col{}, Row{}, Row{}, Row{});
Chao Liu's avatar
Chao Liu committed
131
    }
Chao Liu's avatar
Chao Liu committed
132
133
    else if(data_type == MatrixDataType::F16_F16_F16_F16_F16 &&
            layout == MatrixLayout::KM_KN_MN_MN_MN)
Chao Liu's avatar
Chao Liu committed
134
    {
Chao Liu's avatar
Chao Liu committed
135
        return profile(F16{}, F16{}, F16{}, F16{}, F16{}, Col{}, Row{}, Row{}, Row{}, Row{});
Chao Liu's avatar
Chao Liu committed
136
    }
Chao Liu's avatar
Chao Liu committed
137
138
    else if(data_type == MatrixDataType::F16_F16_F16_F16_F16 &&
            layout == MatrixLayout::KM_NK_MN_MN_MN)
Chao Liu's avatar
Chao Liu committed
139
    {
Chao Liu's avatar
Chao Liu committed
140
        return profile(F16{}, F16{}, F16{}, F16{}, F16{}, Col{}, Col{}, Row{}, Row{}, Row{});
Chao Liu's avatar
Chao Liu committed
141
142
143
144
145
146
147
148
    }
    else
    {
        std::cout << "this data_type & layout is not implemented" << std::endl;

        return 0;
    }
}