"tools/vscode:/vscode.git/clone" did not exist on "4f1a5e52472c740e392b889f28021bc60a11b880"
test_groupnorm_fp16.cpp 2.25 KB
Newer Older
rocking5566's avatar
rocking5566 committed
1
// SPDX-License-Identifier: MIT
Illia Silin's avatar
Illia Silin committed
2
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
rocking5566's avatar
rocking5566 committed
3
4

#include "gtest/gtest.h"
5
#include "profiler/profile_groupnorm_impl.hpp"
rocking5566's avatar
rocking5566 committed
6
7
8
9
10
11
12
13
14

// Short element-type aliases used by the typed-test parameter list in this file.
using ck::index_t;
using F32 = float;
using F16 = ck::half_t;

// Typed GoogleTest fixture for the group-normalization profiler.
//
// `Tuple` is a std::tuple of six types, read positionally:
//   0: XDataType, 1: GammaDataType, 2: BetaDataType,
//   3: ComputeDataType, 4: YDataType, 5: SaveMeanInvStdDataType.
template <typename Tuple>
class TestGroupnorm : public ::testing::Test
{
    protected:
    using XDataType              = std::tuple_element_t<0, Tuple>;
    using GammaDataType          = std::tuple_element_t<1, Tuple>;
    using BetaDataType           = std::tuple_element_t<2, Tuple>;
    using ComputeDataType        = std::tuple_element_t<3, Tuple>;
    using YDataType              = std::tuple_element_t<4, Tuple>;
    using SaveMeanInvStdDataType = std::tuple_element_t<5, Tuple>;

    // Calls ck::profiler::profile_groupnorm_impl once per shape in `lengths`
    // and expects each call to return true.
    void Run()
    {
        // [N, H, W, G, C], reduce H, W, C
        const std::vector<std::vector<ck::index_t>> lengths = {{1, 1, 1, 1, 1},
                                                               {1, 2, 3, 4, 5},
                                                               {256, 9, 9, 9, 9},
                                                               {1, 64, 64, 32, 10},
                                                               {1, 32, 32, 32, 20},
                                                               {2, 32, 32, 32, 30},
                                                               {2, 32, 32, 32, 40},
                                                               {1, 16, 16, 32, 40}};

        // Bind by const reference: each shape is a std::vector, so iterating by
        // value would copy it on every pass.
        for(const auto& length : lengths)
        {
            // NOTE(review): the positional arguments are presumably
            // (do_verification, init_method, do_log, time_kernel, length), and the
            // trailing `true` template argument presumably enables saving the
            // mean / inverse std-dev -- confirm against profile_groupnorm_impl.
            const bool success =
                ck::profiler::profile_groupnorm_impl<XDataType,
                                                     GammaDataType,
                                                     BetaDataType,
                                                     ComputeDataType,
                                                     YDataType,
                                                     SaveMeanInvStdDataType,
                                                     true>(true, 2, false, false, length);

            // Report the offending shape so a failure is attributable to one case.
            EXPECT_TRUE(success) << "lengths = " << ::testing::PrintToString(length);
        }
    }
};

// Type tuples instantiated by the TestGroupnorm typed test suite.
using KernelTypes = ::testing::Types<
    // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType
    std::tuple<F16, F16, F16, F32, F16, F32>>;
rocking5566's avatar
rocking5566 committed
52
53
54

TYPED_TEST_SUITE(TestGroupnorm, KernelTypes);

// Instantiated once per type tuple in KernelTypes; delegates to the fixture's Run().
TYPED_TEST(TestGroupnorm, Test_FP16)
{
    this->Run();
}