test_layernorm2d_fp32.cpp 1.63 KB
Newer Older
rocking5566's avatar
rocking5566 committed
1
2
3
4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include "gtest/gtest.h"
5
#include "profiler/profile_layernorm_impl.hpp"
rocking5566's avatar
rocking5566 committed
6

7
8
9
// Short element-type aliases for the typed-test parameter list; only F32 is
// referenced by the type list in this file.
using F32 = float;
using F16 = ck::half_t;
using ck::index_t;
rocking5566's avatar
rocking5566 committed
10
11

template <typename Tuple>
12
class TestLayernorm2d : public ::testing::Test
rocking5566's avatar
rocking5566 committed
13
{
14
    protected:
rocking5566's avatar
rocking5566 committed
15
16
17
18
19
    using XDataType       = std::tuple_element_t<0, Tuple>;
    using GammaDataType   = std::tuple_element_t<1, Tuple>;
    using BetaDataType    = std::tuple_element_t<2, Tuple>;
    using ComputeDataType = std::tuple_element_t<3, Tuple>;
    using YDataType       = std::tuple_element_t<4, Tuple>;
20
21
22
23
24
25
26
27
28
29
30
31

    void Run()
    {
        // [N, D], reduce D
        std::vector<std::vector<ck::index_t>> lengths = {
            {4, 256}, {8, 511}, {9, 1032}, {4, 2048}, {1, 8192}, {4000, 2000}};

        for(auto length : lengths)
        {
            bool success = ck::profiler::profile_layernorm_impl<XDataType,
                                                                GammaDataType,
                                                                BetaDataType,
rocking5566's avatar
rocking5566 committed
32
                                                                ComputeDataType,
33
34
35
36
37
                                                                YDataType,
                                                                2>(true, 2, false, false, length);
            EXPECT_TRUE(success);
        }
    }
rocking5566's avatar
rocking5566 committed
38
39
40
};

// Type tuples the typed test suite is instantiated with; each tuple is unpacked
// positionally by TestLayernorm2d.
using KernelTypes = ::testing::Types<
    // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType
    std::tuple<F32, F32, F32, F32, F32>>;

// Instantiate the fixture once per type tuple in KernelTypes, then define a
// single test that runs the whole shape sweep for each instantiation.
TYPED_TEST_SUITE(TestLayernorm2d, KernelTypes);

TYPED_TEST(TestLayernorm2d, Test_FP32)
{
    // `this->` is required: Run() is a member of a dependent base class.
    this->Run();
}