// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "gtest/gtest.h"
#include "profiler/profile_max_pool2d_bwd_impl.hpp"
#include "test_pool_fwd_common.hpp"

/// Shared fixture for the max-pool 2D backward tests.
///
/// @tparam T a std::tuple<DOutDataType, DInDataType, IndexDataType>; the elements are
///           unpacked below by position.
template <typename T>
class MaxPool2dBWDTest : public ::testing::Test
{
    protected:
    using DOutDataType  = std::tuple_element_t<0, T>;
    using DInDataType   = std::tuple_element_t<1, T>;
    using IndexDataType = std::tuple_element_t<2, T>;

    // The forward-pass input/output types are taken to be the same as the gradient types.
    using InDataType  = DInDataType;
    using OutDataType = DOutDataType;

    // Configurations swept by Run(); defined out of class for every T.
    static std::vector<PoolingParam> params;

    /// Runs the profiler once per entry of `params` and records a non-fatal
    /// expectation failure (tagged with the entry's index) for every case that fails.
    void Run()
    {
        std::size_t case_idx = 0;
        // Bind by const reference: the original `auto param` copied every PoolingParam.
        for(const auto& param : this->params)
        {
            // NOTE(review): the trailing `false` template argument and the four leading
            // positional arguments (true, 2, false, false) look like PropagateNan and
            // verification / init-method / logging / timing switches — confirm against
            // profile_max_pool2d_bwd_impl.hpp.
            const bool success =
                ck::profiler::profile_max_pool2d_bwd_impl<InDataType,
                                                          OutDataType,
                                                          IndexDataType,
                                                          DOutDataType,
                                                          DInDataType,
                                                          false>(true,
                                                                 2,
                                                                 false,
                                                                 false,
                                                                 param.length_,
                                                                 param.window_spatial_lengths_,
                                                                 param.window_strides_,
                                                                 param.window_dilations_,
                                                                 param.input_left_pads_,
                                                                 param.input_right_pads_);
            // Report which configuration failed so a red test is actionable.
            EXPECT_TRUE(success) << "profile_max_pool2d_bwd_impl failed for params[" << case_idx
                                 << "]";
            ++case_idx;
        }
    }
};

// Pooling configurations swept by MaxPool2dBWDTest<T>::Run() for every data-type suite.
// NOTE(review): each entry presumably lists, in order, the fields Run() reads:
// length_ (looks like N, C, H, W), window_spatial_lengths_, window_strides_,
// window_dilations_, input_left_pads_, input_right_pads_ — confirm against
// PoolingParam in test_pool_fwd_common.hpp.
template <typename T>
std::vector<PoolingParam> MaxPool2dBWDTest<T>::params = {
    {{1, 1, 1, 1}, {1, 1}, {1, 1}, {1, 1}, {0, 0}, {0, 0}},
    {{2, 16, 64, 64}, {64, 64}, {1, 1}, {1, 1}, {0, 0}, {0, 0}},
    {{2, 16, 64, 64}, {4, 4}, {4, 4}, {2, 2}, {0, 0}, {0, 0}},
    {{2, 32, 30, 30}, {2, 2}, {2, 2}, {1, 1}, {1, 1}, {1, 1}},
    {{2, 2, 30, 30}, {2, 2}, {2, 2}, {1, 1}, {1, 1}, {1, 1}}};

// One type list per suite. Each tuple is <DOutDataType, DInDataType, IndexDataType>,
// matching the positions MaxPool2dBWDTest unpacks with std::tuple_element_t.
using Max_Pool_2D_f32_types  = ::testing::Types<std::tuple<F32, F32, I32>>;
using Max_Pool_2D_int8_types = ::testing::Types<std::tuple<I8, I8, I32>>;
using Max_Pool_2D_f16_types  = ::testing::Types<std::tuple<F16, F16, I32>>;
using Max_Pool_2D_bf16_types = ::testing::Types<std::tuple<BF16, BF16, I32>>;
using Max_Pool_2D_f8_types   = ::testing::Types<std::tuple<F8, F8, I32>>;

/// FP32 instantiation of the max-pool 2D backward fixture.
template <class TType>
class MaxPool2D_f32 : public MaxPool2dBWDTest<TType>
{
    protected:
    /// Called before each test; skips it when CK_ENABLE_FP32 is not enabled.
    void SetUp() override
    {
        const bool fp32_disabled = !CK_ENABLE_FP32;
        if(fp32_disabled)
        {
            GTEST_SKIP() << "Skipping MaxPool2D_f32 tests because CK_ENABLE_FP32 is not enabled";
        }
    }
};

/// INT8 instantiation of the max-pool 2D backward fixture.
template <class TType>
class MaxPool2D_int8 : public MaxPool2dBWDTest<TType>
{
    protected:
    /// Called before each test; skips it when CK_ENABLE_INT8 is not enabled.
    void SetUp() override
    {
        const bool int8_disabled = !CK_ENABLE_INT8;
        if(int8_disabled)
        {
            GTEST_SKIP() << "Skipping MaxPool2D_int8 tests because CK_ENABLE_INT8 is not enabled";
        }
    }
};

/// FP16 instantiation of the max-pool 2D backward fixture.
template <typename TType>
class MaxPool2D_f16 : public MaxPool2dBWDTest<TType>
{
    protected:
    /// Called before each test; skips it when CK_ENABLE_FP16 is not enabled.
    void SetUp() override
    {
        if(!CK_ENABLE_FP16)
        {
            // Message wording now matches the other dtype fixtures ("... tests because ...").
            GTEST_SKIP() << "Skipping MaxPool2D_f16 tests because CK_ENABLE_FP16 is not enabled";
        }
    }
};

/// BF16 instantiation of the max-pool 2D backward fixture.
template <class TType>
class MaxPool2D_bf16 : public MaxPool2dBWDTest<TType>
{
    protected:
    /// Called before each test; skips it when CK_ENABLE_BF16 is not enabled.
    void SetUp() override
    {
        const bool bf16_disabled = !CK_ENABLE_BF16;
        if(bf16_disabled)
        {
            GTEST_SKIP() << "Skipping MaxPool2D_bf16 tests because CK_ENABLE_BF16 is not enabled";
        }
    }
};

/// FP8 instantiation of the max-pool 2D backward fixture.
template <class TType>
class MaxPool2D_f8 : public MaxPool2dBWDTest<TType>
{
    protected:
    /// Called before each test; skips it when CK_ENABLE_FP8 is not enabled.
    void SetUp() override
    {
        const bool fp8_disabled = !CK_ENABLE_FP8;
        if(fp8_disabled)
        {
            GTEST_SKIP() << "Skipping MaxPool2D_f8 tests because CK_ENABLE_FP8 is not enabled";
        }
    }
};

// Bind each dtype-specific fixture to its one-element type list.
TYPED_TEST_SUITE(MaxPool2D_f32, Max_Pool_2D_f32_types);
TYPED_TEST_SUITE(MaxPool2D_int8, Max_Pool_2D_int8_types);
TYPED_TEST_SUITE(MaxPool2D_f16, Max_Pool_2D_f16_types);
TYPED_TEST_SUITE(MaxPool2D_bf16, Max_Pool_2D_bf16_types);
TYPED_TEST_SUITE(MaxPool2D_f8, Max_Pool_2D_f8_types);

// Each typed test runs the shared parameter sweep from MaxPool2dBWDTest::Run().
TYPED_TEST(MaxPool2D_f32, MaxPool2DTest_f32) { this->Run(); }

TYPED_TEST(MaxPool2D_int8, MaxPool2DTest_int8) { this->Run(); }

TYPED_TEST(MaxPool2D_f16, MaxPool2DTest_f16) { this->Run(); }

TYPED_TEST(MaxPool2D_bf16, MaxPool2DTest_bf16) { this->Run(); }

TYPED_TEST(MaxPool2D_f8, MaxPool2DTest_f8) { this->Run(); }