test_avg_pool2d_bwd.cpp 3.86 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "gtest/gtest.h"
#include "profiler/profile_avg_pool2d_bwd_impl.hpp"
#include "test_pool_fwd_common.hpp"

template <typename T>
class AvgPool2dBWDTest : public ::testing::Test
{
    protected:
    using InDataType  = std::tuple_element_t<0, T>;
    using OutDataType = std::tuple_element_t<1, T>;

    static std::vector<PoolingParam> params;

    void Run()
    {
        for(auto param : this->params)
        {
            bool success =
                ck::profiler::profile_avg_pool2d_bwd_impl<InDataType, OutDataType, NHWC, NHWC>(
                    true,
                    2,
                    false,
                    false,
                    param.length_,
                    param.window_spatial_lengths_,
                    param.window_strides_,
                    param.window_dilations_,
                    param.input_left_pads_,
                    param.input_right_pads_);
            EXPECT_TRUE(success);
        }
    }
};

template <typename T>
std::vector<PoolingParam> AvgPool2dBWDTest<T>::params = {
    {{1, 1, 1, 1}, {1, 1}, {1, 1}, {1, 1}, {0, 0}, {0, 0}},
    {{1, 1, 64, 64}, {64, 64}, {1, 1}, {1, 1}, {0, 0}, {0, 0}},
    {{1, 5, 7, 7}, {2, 2}, {2, 2}, {1, 1}, {2, 2}, {0, 0}},
    {{1, 1, 8, 8}, {2, 2}, {2, 2}, {1, 1}, {2, 2}, {0, 0}},
    {{1, 1, 8, 8}, {2, 2}, {1, 1}, {1, 1}, {1, 1}, {0, 0}},
    {{2, 32, 30, 30}, {2, 2}, {2, 2}, {1, 1}, {1, 1}, {1, 1}},
    {{1, 2, 30, 30}, {2, 2}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}};

using Avg_Pool_2D_f32_types  = ::testing::Types<std::tuple<F32, F32>>;
using Avg_Pool_2D_int8_types = ::testing::Types<std::tuple<I8, I8>>;
using Avg_Pool_2D_f16_types  = ::testing::Types<std::tuple<F16, F16>>;
using Avg_Pool_2D_bf16_types = ::testing::Types<std::tuple<BF16, BF16>>;
using Avg_Pool_2D_f8_types   = ::testing::Types<std::tuple<F8, F8>>;

template <typename TType>
class AvgPool2D_f32 : public AvgPool2dBWDTest<TType>
{
    protected:
    void SetUp() override
    {
        if(!CK_ENABLE_FP32)
        {
            GTEST_SKIP() << "Skipping AvgPool2D_f32 tests because CK_ENABLE_FP32 is not enabled";
        }
    }
};

template <typename TType>
class AvgPool2D_int8 : public AvgPool2dBWDTest<TType>
{
    protected:
    void SetUp() override
    {
        if(!CK_ENABLE_INT8)
        {
            GTEST_SKIP() << "Skipping AvgPool2D_int8 tests because CK_ENABLE_INT8 is not enabled";
        }
    }
};

template <typename TType>
class AvgPool2D_f16 : public AvgPool2dBWDTest<TType>
{
    protected:
    void SetUp() override
    {
        if(!CK_ENABLE_FP16)
        {
            GTEST_SKIP() << "Skipping AvgPool2D_f16 because CK_ENABLE_FP16 is not enabled";
        }
    }
};

template <typename TType>
class AvgPool2D_bf16 : public AvgPool2dBWDTest<TType>
{
    protected:
    void SetUp() override
    {
        if(!CK_ENABLE_BF16)
        {
            GTEST_SKIP() << "Skipping AvgPool2D_bf16 tests because CK_ENABLE_BF16 is not enabled";
        }
    }
};

template <typename TType>
class AvgPool2D_f8 : public AvgPool2dBWDTest<TType>
{
    protected:
    void SetUp() override
    {
        if(!CK_ENABLE_FP8)
        {
            GTEST_SKIP() << "Skipping AvgPool2D_f8 tests because CK_ENABLE_FP8 is not enabled";
        }
    }
};

TYPED_TEST_SUITE(AvgPool2D_f32, Avg_Pool_2D_f32_types);
TYPED_TEST_SUITE(AvgPool2D_int8, Avg_Pool_2D_int8_types);
TYPED_TEST_SUITE(AvgPool2D_f16, Avg_Pool_2D_f16_types);
TYPED_TEST_SUITE(AvgPool2D_bf16, Avg_Pool_2D_bf16_types);
TYPED_TEST_SUITE(AvgPool2D_f8, Avg_Pool_2D_f8_types);

TYPED_TEST(AvgPool2D_f32, AvgPool2DTest_f32) { this->Run(); }

TYPED_TEST(AvgPool2D_int8, AvgPool2DTest_int8) { this->Run(); }

TYPED_TEST(AvgPool2D_f16, AvgPool2DTest_f16) { this->Run(); }

TYPED_TEST(AvgPool2D_bf16, AvgPool2DTest_bf16) { this->Run(); }

TYPED_TEST(AvgPool2D_f8, AvgPool2DTest_f8) { this->Run(); }