// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.

#include <cstdlib>
#include <iostream>
#include <memory>
#include <initializer_list>
#include <vector>
#include <tuple>
#include <gtest/gtest.h>

#include "profiler/profile_contraction_impl.hpp"
#include "profiler/profile_contraction_utils.hpp"

using F16  = ck::half_t;
using BF16 = ck::bhalf_t;
using F32  = float;
using F64  = double;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

using Bilinear = ck::tensor_operation::element_wise::Bilinear;
using Scale    = ck::tensor_operation::element_wise::Scale;

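// Extents of the M, N and K dimension groups of a contraction problem, NDims extents per group.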
template <ck::index_t NDims>
struct Dimensions
{
    constexpr static ck::index_t NumDimMNK = NDims;

    std::vector<ck::index_t> M;
    std::vector<ck::index_t> N;
    std::vector<ck::index_t> K;
};

template <typename Tuple>
class TestContraction : public ::testing::Test
{
    protected:
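    // Test parameters carried in the gtest type tuple: A/B input layouts, shared C/D layout,
    // tensor data type, D-tuple data type, compute data type, and the C/D element-wise operation.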
    using ALayout         = std::tuple_element_t<0, Tuple>;
    using BLayout         = std::tuple_element_t<1, Tuple>;
    using CDLayout        = std::tuple_element_t<2, Tuple>;
    using DataType        = std::tuple_element_t<3, Tuple>;
    using DTupleDataType  = std::tuple_element_t<4, Tuple>;
    using ComputeDataType = std::tuple_element_t<5, Tuple>;
    using CDElementOp     = std::tuple_element_t<6, Tuple>;

    std::vector<ck::index_t> init_methods = {1, 2};
    std::unique_ptr<CDElementOp> p_cd_element_op;

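    // Assigns default strides for the given extents and runs the contraction profiler
    // with verification enabled for every configured init method.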
    template <ck::index_t NumDim>
    void Run(Dimensions<NumDim> dimension_params)
    {
        constexpr ck::index_t NumDimMNK = ck::remove_cvref_t<decltype(dimension_params)>::NumDimMNK;

        std::vector<ck::index_t> StridesA(2 * NumDim);
        std::vector<ck::index_t> StridesB(2 * NumDim);
        std::vector<ck::index_t> StridesC(2 * NumDim);
        std::vector<ck::index_t> StridesD(2 * NumDim);

        const auto& M = dimension_params.M;
        const auto& N = dimension_params.N;
        const auto& K = dimension_params.K;

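        // Concatenate the extents of two dimension groups (e.g. M and K for tensor A)
        // into a single vector for stride assignment.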
        auto merge_dims = [](const std::vector<ck::index_t>& dims01,
                             const std::vector<ck::index_t>& dims23) {
            std::vector<ck::index_t> dims_szt(dims01.begin(), dims01.end());
            dims_szt.insert(dims_szt.end(), dims23.begin(), dims23.end());
            return dims_szt;
        };

        assign_default_strides(ALayout{}, StridesA, merge_dims(M, K));
        assign_default_strides(BLayout{}, StridesB, merge_dims(N, K));
        assign_default_strides(CDLayout{}, StridesC, merge_dims(M, N));
        assign_default_strides(CDLayout{}, StridesD, merge_dims(M, N));

        for(const ck::index_t init_method : init_methods)
        {
            bool pass =
                ck::profiler::profile_contraction_impl<NumDimMNK,
                                                       ALayout,
                                                       BLayout,
                                                       CDLayout,
                                                       DataType,
                                                       ComputeDataType,
                                                       DTupleDataType,
                                                       CDElementOp>(true /*do_verification*/,
                                                                    init_method,
                                                                    false /*do_logs*/,
                                                                    false /*time_kernel*/,
                                                                    *p_cd_element_op,
                                                                    dimension_params.M,
                                                                    dimension_params.N,
                                                                    dimension_params.K,
                                                                    StridesA,
                                                                    StridesB,
                                                                    StridesC,
                                                                    StridesD);
            EXPECT_TRUE(pass);
        }
    }
};

template <typename Tuple>
class TestContractionScale : public TestContraction<Tuple>
{
};

template <typename Tuple>
class TestContractionBilinear : public TestContraction<Tuple>
{
};

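// Instantiate every A/B layout combination with a row-major C/D layout for the given data types
// and element-wise operation.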
#define ALL_LAYOUT_COMBINATIONS(dt, tuple_dt, compute_dt, op)    \
    std::tuple<Row, Row, Row, dt, tuple_dt, compute_dt, op>,     \
        std::tuple<Row, Col, Row, dt, tuple_dt, compute_dt, op>, \
        std::tuple<Col, Row, Row, dt, tuple_dt, compute_dt, op>, \
        std::tuple<Col, Col, Row, dt, tuple_dt, compute_dt, op>

using BilinearKernelTypes =
    ::testing::Types<ALL_LAYOUT_COMBINATIONS(F32, ck::Tuple<F32>, F32, Bilinear),
                     ALL_LAYOUT_COMBINATIONS(F64, ck::Tuple<F64>, F64, Bilinear)>;

using ScaleKernelTypes = ::testing::Types<ALL_LAYOUT_COMBINATIONS(F32, ck::Tuple<>, F32, Scale),
                                          ALL_LAYOUT_COMBINATIONS(F64, ck::Tuple<>, F64, Scale)>;

TYPED_TEST_SUITE(TestContractionBilinear, BilinearKernelTypes);
TYPED_TEST_SUITE(TestContractionScale, ScaleKernelTypes);

TYPED_TEST(TestContractionBilinear, bilinear)
{
    this->p_cd_element_op = std::make_unique<Bilinear>(1.f, 1.f);
    this->template Run<6>({{2, 3, 2, 3, 2, 3}, {2, 3, 2, 3, 2, 3}, {2, 2, 2, 2, 2, 4}});
    this->template Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}});
    this->template Run<2>({{16, 8}, {16, 8}, {16, 8}});
    this->template Run<2>({{8, 16}, {16, 8}, {8, 16}});

    this->p_cd_element_op = std::make_unique<Bilinear>(-0.5f, 0.5f);
    this->template Run<6>({{2, 3, 2, 3, 2, 3}, {2, 3, 2, 3, 2, 3}, {2, 2, 2, 2, 2, 4}});
    this->template Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}});
    this->template Run<2>({{16, 8}, {16, 8}, {16, 8}});
    this->template Run<2>({{8, 16}, {16, 8}, {8, 16}});
}

TYPED_TEST(TestContractionScale, scale)
{
    this->p_cd_element_op = std::make_unique<Scale>(1.f);
    this->template Run<6>({{2, 3, 2, 3, 2, 3}, {2, 3, 2, 3, 2, 3}, {2, 2, 2, 2, 2, 4}});
    this->template Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}});
    this->template Run<2>({{16, 8}, {16, 8}, {16, 8}});
    this->template Run<2>({{8, 16}, {16, 8}, {8, 16}});

    this->p_cd_element_op = std::make_unique<Scale>(0.5f);
    this->template Run<6>({{2, 3, 2, 3, 2, 3}, {2, 3, 2, 3, 2, 3}, {2, 2, 2, 2, 2, 4}});
    this->template Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}});
    this->template Run<2>({{16, 8}, {16, 8}, {16, 8}});
    this->template Run<2>({{8, 16}, {16, 8}, {8, 16}});
}

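// Mixed-precision variants: the compute data type differs from the tensor data type.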
template <typename Tuple>
class TestContractionScaleMixedPrecision : public TestContraction<Tuple>
{
};

template <typename Tuple>
class TestContractionBilinearMixedPrecision : public TestContraction<Tuple>
{
};

using BilinearKernelTypesMixedPrecision =
    ::testing::Types<ALL_LAYOUT_COMBINATIONS(F32, ck::Tuple<F32>, F16, Bilinear),
                     ALL_LAYOUT_COMBINATIONS(F32, ck::Tuple<F32>, BF16, Bilinear),
                     ALL_LAYOUT_COMBINATIONS(F64, ck::Tuple<F64>, F32, Bilinear),
                     ALL_LAYOUT_COMBINATIONS(F16, ck::Tuple<F16>, F32, Bilinear),
                     ALL_LAYOUT_COMBINATIONS(BF16, ck::Tuple<BF16>, F32, Bilinear)>;

using ScaleKernelTypesMixedPrecision =
    ::testing::Types<ALL_LAYOUT_COMBINATIONS(F32, ck::Tuple<>, F16, Scale),
                     ALL_LAYOUT_COMBINATIONS(F32, ck::Tuple<>, BF16, Scale),
                     ALL_LAYOUT_COMBINATIONS(F64, ck::Tuple<>, F32, Scale),
                     ALL_LAYOUT_COMBINATIONS(F16, ck::Tuple<>, F32, Scale),
                     ALL_LAYOUT_COMBINATIONS(BF16, ck::Tuple<>, F32, Scale)>;

TYPED_TEST_SUITE(TestContractionBilinearMixedPrecision, BilinearKernelTypesMixedPrecision);
TYPED_TEST_SUITE(TestContractionScaleMixedPrecision, ScaleKernelTypesMixedPrecision);

TYPED_TEST(TestContractionBilinearMixedPrecision, bilinear)
{
    this->p_cd_element_op = std::make_unique<Bilinear>(1.f, 1.f);
    this->template Run<6>({{2, 3, 2, 3, 2, 3}, {2, 3, 2, 3, 2, 3}, {2, 2, 2, 2, 2, 4}});
    this->template Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}});
    this->template Run<2>({{16, 8}, {16, 8}, {16, 8}});
    this->template Run<2>({{8, 16}, {16, 8}, {8, 16}});

    this->p_cd_element_op = std::make_unique<Bilinear>(-0.5f, 0.5f);
    this->template Run<6>({{2, 3, 2, 3, 2, 3}, {2, 3, 2, 3, 2, 3}, {2, 2, 2, 2, 2, 4}});
    this->template Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}});
    this->template Run<2>({{16, 8}, {16, 8}, {16, 8}});
    this->template Run<2>({{8, 16}, {16, 8}, {8, 16}});
}

TYPED_TEST(TestContractionScaleMixedPrecision, scale)
{
    this->p_cd_element_op = std::make_unique<Scale>(1.f);
    this->template Run<6>({{2, 3, 2, 3, 2, 3}, {2, 3, 2, 3, 2, 3}, {2, 2, 2, 2, 2, 4}});
    this->template Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}});
    this->template Run<2>({{16, 8}, {16, 8}, {16, 8}});
    this->template Run<2>({{8, 16}, {16, 8}, {8, 16}});

    this->p_cd_element_op = std::make_unique<Scale>(0.5f);
    this->template Run<6>({{2, 3, 2, 3, 2, 3}, {2, 3, 2, 3, 2, 3}, {2, 2, 2, 2, 2, 4}});
    this->template Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}});
    this->template Run<2>({{16, 8}, {16, 8}, {16, 8}});
    this->template Run<2>({{8, 16}, {16, 8}, {8, 16}});

    // Special cases: degenerate extents of size 1 in M, N or K.
    this->template Run<2>({{1, 1}, {16, 8}, {8, 16}});
    this->template Run<2>({{8, 16}, {16, 8}, {1, 1}});
    this->template Run<2>({{8, 16}, {1, 1}, {8, 16}});
    this->template Run<2>({{1, 1}, {1, 1}, {1, 1}});
}