// test_grouped_gemm_util.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <string>
#include <sstream>
#include <tuple>
#include <vector>
#include <gtest/gtest.h>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "include/ck/utility/data_type.hpp"
#include "profiler/profile_grouped_gemm_impl.hpp"

namespace ck {
namespace test {

/// Joins the elements of @p range into one string, separated by ", ".
///
/// Each element is written with `operator<<` into a std::stringstream, so any
/// element type streamable to std::ostream is accepted.
///
/// @param range  Any range usable in a range-based for loop.
/// @return       e.g. "1, 2, 3" for {1, 2, 3}; an empty string for an empty range.
template <typename Range>
std::string serialize_range(const Range& range)
{
    std::stringstream ss;
    // Emit the separator *before* every element except the first, instead of
    // trimming a trailing ", " afterwards: trimming an empty result would step
    // an iterator before begin(), which is undefined behavior.
    const char* separator = "";
    // `const auto&` (not `auto&`) so proxy/prvalue element references such as
    // those of std::vector<bool> also bind.
    for(const auto& r : range)
    {
        ss << separator << r;
        separator = ", ";
    }
    return ss.str();
}

32
33
34
35
36
37
38
39
40
41
42
43
44
template <typename Tuple>
class TestGroupedGemm : public testing::TestWithParam<int>
{
    protected:
    using ALayout   = std::tuple_element_t<0, Tuple>;
    using BLayout   = std::tuple_element_t<1, Tuple>;
    using ELayout   = std::tuple_element_t<2, Tuple>;
    using ADataType = std::tuple_element_t<3, Tuple>;
    using BDataType = std::tuple_element_t<4, Tuple>;
    using EDataType = std::tuple_element_t<5, Tuple>;

    public:
    bool verify_     = true;
45
    int init_method_ = 0; // decimal value initialization
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
    bool log_        = false;
    bool bench_      = false; // measure kernel performance

    void SetUp() override {}

    void Run(const std::vector<int>& Ms,
             const std::vector<int>& Ns,
             const std::vector<int>& Ks,
             const std::vector<int>& StrideAs,
             const std::vector<int>& StrideBs,
             const std::vector<int>& StrideCs,
             int kbatch = 1)
    {
        bool pass = ck::profiler::profile_grouped_gemm_impl<ADataType,
                                                            BDataType,
                                                            EDataType,
                                                            float,
                                                            ALayout,
                                                            BLayout,
                                                            ELayout>(
            verify_, init_method_, log_, bench_, Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch);
        EXPECT_TRUE(pass);
    }
};

} // namespace test
} // namespace ck