// test_batched_gemm.cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.

#include <cstdlib>
#include <iostream>
#include <initializer_list>
#include <tuple>
#include <vector>

#include <gtest/gtest.h>

#include "profiler/profile_batched_gemm_impl.hpp"

#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp"

// One batched-GEMM problem size consumed by TestBatchedGemm::Run().
// Run() forwards M, N, K and BatchCount to
// ck::profiler::profile_batched_gemm_impl and derives the remaining size
// arguments (K, N, M, M * K, K * N, M * N) from these fields.
// NOTE(review): presumably A is M x K, B is K x N and C is M x N -- confirm
// against the profiler implementation.
struct GemmParams
{
    ck::index_t M;
    ck::index_t N;
    ck::index_t K;
    // Forwarded as the last argument of profile_batched_gemm_impl.
    ck::index_t BatchCount;
};

class TestBatchedGemm : public ::testing::Test
{
    protected:
    using Row = ck::tensor_layout::gemm::RowMajor;
    using Col = ck::tensor_layout::gemm::ColumnMajor;

    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    std::vector<GemmParams> params;

    template <typename DataType>
    void Run()
    {
        using namespace ck::tensor_operation::device;

        bool pass = true;
        for(auto& param : params)
        {
            const auto M          = param.M;
            const auto N          = param.N;
            const auto K          = param.K;
            const auto BatchCount = param.BatchCount;

            pass =
                pass && ck::profiler::profile_batched_gemm_impl<DataType,
                                                                DataType,
                                                                DataType,
                                                                Row,
                                                                Row,
                                                                Row,
                                                                PassThrough,
                                                                PassThrough,
                                                                PassThrough,
                                                                DeviceBatchedGemm<Row,
                                                                                  Row,
                                                                                  Row,
                                                                                  DataType,
                                                                                  DataType,
                                                                                  DataType,
                                                                                  PassThrough,
                                                                                  PassThrough,
                                                                                  PassThrough>>(
                            true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount);

            pass =
                pass && ck::profiler::profile_batched_gemm_impl<DataType,
                                                                DataType,
                                                                DataType,
                                                                Row,
                                                                Col,
                                                                Row,
                                                                PassThrough,
                                                                PassThrough,
                                                                PassThrough,
                                                                DeviceBatchedGemm<Row,
                                                                                  Col,
                                                                                  Row,
                                                                                  DataType,
                                                                                  DataType,
                                                                                  DataType,
                                                                                  PassThrough,
                                                                                  PassThrough,
                                                                                  PassThrough>>(
                            true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount);

            pass =
                pass && ck::profiler::profile_batched_gemm_impl<DataType,
                                                                DataType,
                                                                DataType,
                                                                Col,
                                                                Row,
                                                                Row,
                                                                PassThrough,
                                                                PassThrough,
                                                                PassThrough,
                                                                DeviceBatchedGemm<Col,
                                                                                  Row,
                                                                                  Row,
                                                                                  DataType,
                                                                                  DataType,
                                                                                  DataType,
                                                                                  PassThrough,
                                                                                  PassThrough,
                                                                                  PassThrough>>(
                            true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount);

            pass =
                pass && ck::profiler::profile_batched_gemm_impl<DataType,
                                                                DataType,
                                                                DataType,
                                                                Col,
                                                                Col,
                                                                Row,
                                                                PassThrough,
                                                                PassThrough,
                                                                PassThrough,
                                                                DeviceBatchedGemm<Col,
                                                                                  Col,
                                                                                  Row,
                                                                                  DataType,
                                                                                  DataType,
                                                                                  DataType,
                                                                                  PassThrough,
                                                                                  PassThrough,
                                                                                  PassThrough>>(
                            true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount);
        }
        EXPECT_TRUE(pass);
    }
};

#ifdef CK_ENABLE_INT8
// Batched GEMM with int8_t data over the standard set of problem sizes.
TEST_F(TestBatchedGemm, i8)
{
    // Each entry is {M, N, K, BatchCount}; appended in this order.
    this->params.insert(this->params.end(),
                        {{64, 64, 64, 2},
                         {64, 64, 64, 1},
                         {60, 60, 60, 2},
                         {68, 68, 68, 2},
                         {40, 40, 40, 2},
                         {256, 256, 128, 3}});

    this->template Run<int8_t>();
}
#endif

#ifdef CK_ENABLE_BF16
// Batched GEMM with ck::bhalf_t data over the standard set of problem sizes.
TEST_F(TestBatchedGemm, bf16)
{
    // Each entry is {M, N, K, BatchCount}; appended in this order.
    this->params.insert(this->params.end(),
                        {{64, 64, 64, 2},
                         {64, 64, 64, 1},
                         {60, 60, 60, 2},
                         {68, 68, 68, 2},
                         {40, 40, 40, 2},
                         {256, 256, 128, 3}});

    this->template Run<ck::bhalf_t>();
}
#endif

#ifdef CK_ENABLE_FP16
// Batched GEMM with ck::half_t data over the standard set of problem sizes.
TEST_F(TestBatchedGemm, fp16)
{
    // Each entry is {M, N, K, BatchCount}; appended in this order.
    this->params.insert(this->params.end(),
                        {{64, 64, 64, 2},
                         {64, 64, 64, 1},
                         {60, 60, 60, 2},
                         {68, 68, 68, 2},
                         {40, 40, 40, 2},
                         {256, 256, 128, 3}});

    this->template Run<ck::half_t>();
}
#endif

#ifdef CK_ENABLE_FP32
// Batched GEMM with float data over the standard set of problem sizes.
TEST_F(TestBatchedGemm, fp32)
{
    // Each entry is {M, N, K, BatchCount}; appended in this order.
    this->params.insert(this->params.end(),
                        {{64, 64, 64, 2},
                         {64, 64, 64, 1},
                         {60, 60, 60, 2},
                         {68, 68, 68, 2},
                         {40, 40, 40, 2},
                         {256, 256, 128, 3}});

    this->template Run<float>();
}
#endif