"vscode:/vscode.git/clone" did not exist on "271269a5caba877ab379afb2d1adf7224d3181d4"
test_softmax_rank3.cpp 841 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <algorithm>
#include <cstdint>
#include <stdexcept>
#include <vector>

#include "gtest/gtest.h"
#include "test_softmax_util.hpp"

// Type-level wrapper for a compile-time index value. It is used in KernelTypes
// to carry the tensor rank as a type next to the data types in each test tuple
// (ck::Number is presumably an integral-constant-style type — see CK headers).
template <ck::index_t N>
using I = ck::Number<N>;

// Short aliases for the element types exercised by this test.
using F16 = ck::half_t;
using F32 = float;
// Spelled via <cstdint> instead of relying on a transitive include for the
// global-namespace `int8_t`, whose presence there is not guaranteed.
using I8  = std::int8_t;

// Typed-test fixture for this file. It adds no state or helpers of its own:
// everything comes from ck::TestSoftmax<Tuple> (presumably declared in
// test_softmax_util.hpp). `Tuple` is one entry of the KernelTypes list
// registered with TYPED_TEST_SUITE in this file.
template <typename Tuple>
class TestSoftmax : public ck::TestSoftmax<Tuple> {};

// Parameter sets for the rank-3 softmax typed tests. Each std::tuple is
// (InDataType, AccDataType, OutDataType, Rank); every configuration here
// accumulates in F32. GoogleTest instantiates the typed tests once per tuple,
// in list order, so reordering entries changes the generated test names.
// clang-format off
using KernelTypes = ::testing::Types<
    //         InDataType, AccDataType, OutDataType, Rank
    std::tuple<       F16,         F32,         F16,    I<3>>,
    std::tuple<       F32,         F32,         F32,    I<3>>,
    std::tuple<        I8,         F32,          I8,    I<3>>
    >;
// clang-format on

// Bind the fixture to the type list. The TYPED_TEST bodies are presumably
// supplied by test_softmax_ut_cases.inc, which is included after this point.
TYPED_TEST_SUITE(TestSoftmax, KernelTypes);

#include "test_softmax_ut_cases.inc"