// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#include <algorithm>
#include <stdexcept>
#include <vector>

#include "gtest/gtest.h"
#include "test_softmax_util.hpp"

// I<N> wraps a compile-time index in ck::Number so that the tensor rank can be
// carried as a type in the KernelTypes tuples (the "Rank" column).
template <ck::index_t N>
using I = ck::Number<N>;

#ifdef __fp16__
// Half-precision element type. Only declared when the __fp16__ macro is defined
// (presumably set by the build configuration when FP16 kernels are enabled — confirm).
using F16 = ck::half_t;
#endif

// Single-precision element type; always available.
using F32 = float;

// Typed-test fixture for this translation unit. It adds no members of its own;
// all setup and helpers come from ck::TestSoftmax, instantiated with one
// (InDataType, AccDataType, OutDataType, Rank) tuple from KernelTypes.
template <class KernelTuple>
struct TestSoftmax : ck::TestSoftmax<KernelTuple>
{
};

// Kernel configurations exercised by the typed tests in this file; each entry is
// (InDataType, AccDataType, OutDataType, Rank), and every entry here is rank 4.
// The FP16 entry is deliberately listed first: when __fp16__ is not defined it is
// compiled out together with its trailing comma, leaving a well-formed list.
// clang-format off
using KernelTypes = ::testing::Types<
    //         InDataType, AccDataType, OutDataType, Rank
#ifdef __fp16__
    std::tuple<       F16,         F32,         F16,    I<4>>,
#endif
    std::tuple<       F32,         F32,         F32,    I<4>>
    >;
// clang-format on

// Register the TestSoftmax fixture so every typed test runs once per entry in KernelTypes.
TYPED_TEST_SUITE(TestSoftmax, KernelTypes);

// NOTE(review): this .inc presumably supplies the TYPED_TEST bodies shared across the
// softmax rank tests; typed tests must appear after TYPED_TEST_SUITE, so keep it last — confirm.
#include "test_softmax_ut_cases.inc"