"src/include/ConstantTensorDescriptor.cuh" did not exist on "3dbd47252c860af79aa93f66ac19c891cb347dec"
test_groupnorm_fp16.cpp 2.23 KB
Newer Older
rocking's avatar
rocking committed
1
2
3
4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include "gtest/gtest.h"
5
#include "profiler/include/profile_groupnorm_impl.hpp"
rocking's avatar
rocking committed
6

7
8
9
10
using F16 = ck::half_t;
using F32 = float;
using ck::index_t;
using ck::profiler::ElementwiseOpEnum;
rocking's avatar
rocking committed
11
12

template <typename Tuple>
13
class TestGroupnorm : public ::testing::Test
rocking's avatar
rocking committed
14
{
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
    protected:
    using XDataType     = std::tuple_element_t<0, Tuple>;
    using GammaDataType = std::tuple_element_t<1, Tuple>;
    using BetaDataType  = std::tuple_element_t<2, Tuple>;
    using AccDataType   = std::tuple_element_t<3, Tuple>;
    using YDataType     = std::tuple_element_t<4, Tuple>;

    void Run()
    {
        // N, H, W, G, C
        std::vector<std::vector<ck::index_t>> lengths = {{1, 1, 1, 1, 1},
                                                         {1, 2, 3, 4, 5},
                                                         {256, 9, 9, 9, 9},
                                                         {1, 64, 64, 32, 10},
                                                         {1, 32, 32, 32, 20},
                                                         {1, 16, 16, 32, 40}};

        for(auto length : lengths)
        {
            bool success = ck::profiler::profile_groupnorm_impl<XDataType,
                                                                GammaDataType,
                                                                BetaDataType,
                                                                AccDataType,
                                                                YDataType>(
                true, 2, false, false, length, ElementwiseOpEnum::eSigmoid);
            EXPECT_TRUE(success);
        }
    }
rocking's avatar
rocking committed
43
44
45
};

// Type parameters for TestGroupnorm, one std::tuple per typed-test instantiation,
// in the order <XDataType, GammaDataType, BetaDataType, AccDataType, YDataType>.
// fp16 input, gamma, beta and output with an fp32 accumulation type.
// A single entry is enough: listing the same tuple repeatedly only re-runs
// identical work under different instantiation indices.
using KernelTypes = ::testing::Types<std::tuple<F16, F16, F16, F32, F16>>;

// Instantiate the TestGroupnorm fixture once per tuple in KernelTypes.
TYPED_TEST_SUITE(TestGroupnorm, KernelTypes);

// fp16 groupnorm: delegate to the fixture, which sweeps all problem sizes.
TYPED_TEST(TestGroupnorm, Test_FP16)
{
    this->Run();
}