// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include "gtest/gtest.h"
#include "profiler/profile_permute_scale_impl.hpp"

// Short element-type names used in the typed-test parameter list below.
using F32 = float;
using F16 = ck::half_t;
using ck::index_t;

/// Typed GoogleTest fixture that drives the permute/scale profiler for one
/// element-type pair.
///
/// @tparam Tuple  std::tuple<ADataType, BDataType>: element type of the input
///                tensor (index 0) and of the output tensor (index 1).
template <typename Tuple>
class TestPermute : public ::testing::Test
{
    protected:
    using ADataType = std::tuple_element_t<0, Tuple>;
    using BDataType = std::tuple_element_t<1, Tuple>;

    /// Returns true when either element type is F16 and CK_ENABLE_FP16 is not
    /// defined, or is F32 and CK_ENABLE_FP32 is not defined. Such cases are skipped.
    /// Static because it depends only on the fixture's type parameters.
    static constexpr bool skip_case()
    {
#ifndef CK_ENABLE_FP16
        if constexpr(ck::is_same_v<ADataType, F16> || ck::is_same_v<BDataType, F16>)
        {
            return true;
        }
#endif
#ifndef CK_ENABLE_FP32
        if constexpr(ck::is_same_v<ADataType, F32> || ck::is_same_v<BDataType, F32>)
        {
            return true;
        }
#endif
        return false;
    }

    /// Runs the profiler on one permutation case and expects it to report success.
    ///
    /// @tparam NDims          Tensor rank; each vector below is expected to hold NDims entries.
    /// @param lengths         Extent of every dimension.
    /// @param input_strides   Strides of the input (A) tensor.
    /// @param output_strides  Strides of the output (B) tensor.
    template <ck::index_t NDims>
    void Run(std::vector<ck::index_t> lengths,
             std::vector<ck::index_t> input_strides,
             std::vector<ck::index_t> output_strides)
    {
        if(skip_case())
        {
            // Report compiled-out data types as skipped instead of as a vacuous pass.
            GTEST_SKIP() << "element type disabled in this build";
        }

        // NOTE(review): the leading (true, 2, false, false) arguments are forwarded
        // unchanged; presumably verify / init-method / log / time-kernel — confirm
        // against profile_permute_scale_impl.
        const bool success = ck::profiler::profile_permute_scale_impl<ADataType, BDataType, NDims>(
            true, 2, false, false, lengths, input_strides, output_strides);
        EXPECT_TRUE(success);
    }
};

// Each entry is <ADataType, BDataType>; every TYPED_TEST below runs once per entry.
using KernelTypes = ::testing::Types<std::tuple<F16, F16>, // half in, half out
                                     std::tuple<F32, F32>  // float in, float out
                                     >;
TYPED_TEST_SUITE(TestPermute, KernelTypes);
TYPED_TEST(TestPermute, Test1D)
{
    // Rank 1: unit stride <-> stride 2 in both directions, then a one-element tensor.
    constexpr ck::index_t Rank = 1;
    using Dims                 = std::vector<ck::index_t>;

    const Dims length{8};
    const Dims unit_stride{1};
    const Dims double_stride{2};
    const Dims single{1};

    this->template Run<Rank>(length, unit_stride, double_stride);
    this->template Run<Rank>(length, double_stride, unit_stride);
    this->template Run<Rank>(single, single, single);
}

TYPED_TEST(TestPermute, Test2D)
{
    // Rank 2: packed row-major <-> packed column-major, then an all-ones tensor.
    constexpr ck::index_t Rank = 2;
    using Dims                 = std::vector<ck::index_t>;

    const Dims lengths{8, 4};
    const Dims row_major{4, 1};
    const Dims col_major{1, 8};
    const Dims ones{1, 1};

    this->template Run<Rank>(lengths, row_major, col_major);
    this->template Run<Rank>(lengths, col_major, row_major);
    this->template Run<Rank>(ones, ones, ones);
}

TYPED_TEST(TestPermute, Test3D)
{
    // Rank 3: packed row-major <-> packed column-major, then an all-ones tensor.
    constexpr ck::index_t Rank = 3;
    using Dims                 = std::vector<ck::index_t>;

    const Dims lengths{2, 4, 4};
    const Dims row_major{16, 4, 1};
    const Dims col_major{1, 2, 8};
    const Dims ones{1, 1, 1};

    this->template Run<Rank>(lengths, row_major, col_major);
    this->template Run<Rank>(lengths, col_major, row_major);
    this->template Run<Rank>(ones, ones, ones);
}

TYPED_TEST(TestPermute, Test4D)
{
    // Rank 4: packed row-major <-> packed column-major, then an all-ones tensor.
    constexpr ck::index_t Rank = 4;
    using Dims                 = std::vector<ck::index_t>;

    const Dims lengths{2, 4, 4, 4};
    const Dims row_major{64, 16, 4, 1};
    const Dims col_major{1, 2, 8, 32};
    const Dims ones{1, 1, 1, 1};

    this->template Run<Rank>(lengths, row_major, col_major);
    this->template Run<Rank>(lengths, col_major, row_major);
    this->template Run<Rank>(ones, ones, ones);
}

TYPED_TEST(TestPermute, Test5D)
{
    // Rank 5: packed row-major <-> packed column-major, then an all-ones tensor.
    constexpr ck::index_t Rank = 5;
    using Dims                 = std::vector<ck::index_t>;

    const Dims lengths{2, 4, 4, 4, 4};
    const Dims row_major{256, 64, 16, 4, 1};
    const Dims col_major{1, 2, 8, 32, 128};
    const Dims ones{1, 1, 1, 1, 1};

    this->template Run<Rank>(lengths, row_major, col_major);
    this->template Run<Rank>(lengths, col_major, row_major);
    this->template Run<Rank>(ones, ones, ones);
}

TYPED_TEST(TestPermute, Test6D)
{
    // Rank 6: packed row-major <-> packed column-major, then an all-ones tensor.
    constexpr ck::index_t Rank = 6;
    using Dims                 = std::vector<ck::index_t>;

    const Dims lengths{2, 4, 4, 4, 4, 4};
    const Dims row_major{1024, 256, 64, 16, 4, 1};
    const Dims col_major{1, 2, 8, 32, 128, 512};
    const Dims ones{1, 1, 1, 1, 1, 1};

    this->template Run<Rank>(lengths, row_major, col_major);
    this->template Run<Rank>(lengths, col_major, row_major);
    this->template Run<Rank>(ones, ones, ones);
}