test_softmax_util.hpp 6.78 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
2
3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

4
5
#pragma once

6
#include <cstdint>
#include <iostream>
#include <numeric>
#include <sstream>
#include <string>
#include <tuple>
#include <type_traits>
7
8
#include <vector>

Chao Liu's avatar
Chao Liu committed
9
10
11
#include <gtest/gtest.h>

#include "ck/ck.hpp"
Adam Osewski's avatar
Adam Osewski committed
12
13
#include "ck/tensor_operation/gpu/device/impl/device_softmax_impl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
14
#include "ck/utility/number.hpp"
Chao Liu's avatar
Chao Liu committed
15

16
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
Chao Liu's avatar
Chao Liu committed
17
#include "ck/library/utility/check_err.hpp"
18
#include "ck/library/utility/device_memory.hpp"
19
#include "ck/library/utility/host_tensor.hpp"
20
21
22

namespace ck {

23
24
25
26
27
28
29
30
31
32
33
34
/// Joins the elements of `range` into one string, separated by ", ".
///
/// Each element is written with `operator<<` into a string stream. The separator is
/// emitted *between* elements rather than appended and trimmed afterwards, so an empty
/// range yields an empty string (trimming two characters off an empty string would
/// move an iterator before `begin()`, which is undefined behaviour).
///
/// @param range Any type usable in a range-based for loop whose elements are streamable.
/// @return The joined text, e.g. "1, 2, 3"; "" for an empty range.
template <typename Range>
std::string serialize_range(const Range& range)
{
    std::stringstream ss;
    bool is_first = true;
    for(const auto& r : range)
    {
        if(!is_first)
        {
            ss << ", ";
        }
        ss << r;
        is_first = false;
    }
    return ss.str();
}

35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
/// gtest fixture that runs one DeviceSoftmaxImpl configuration and checks its output
/// against the host ReferenceSoftmax.
///
/// `Tuple` is a tuple-like type list read with std::tuple_element_t:
///   - 0, 1, 2 : input, accumulator and output data types;
///   - 3 .. 12 : default-constructible types exposing a constant `value` member
///               (presumably ck::Number<> — confirm at the instantiation site), giving
///               Rank, NumReduceDim, BlockSize, M/K thread cluster sizes, M/K thread
///               slice sizes, InSrcVectorDim, InSrcVectorSize and OutDstVectorSize.
template <typename Tuple>
class TestSoftmax : public ::testing::Test
{
    protected:
    using InDataType                            = std::tuple_element_t<0, Tuple>;
    using AccDataType                           = std::tuple_element_t<1, Tuple>;
    using OutDataType                           = std::tuple_element_t<2, Tuple>;
    static constexpr index_t Rank               = std::tuple_element_t<3, Tuple>{}.value;
    static constexpr index_t NumReduceDim       = std::tuple_element_t<4, Tuple>{}.value;
    static constexpr index_t BlockSize          = std::tuple_element_t<5, Tuple>{}.value;
    static constexpr index_t MThreadClusterSize = std::tuple_element_t<6, Tuple>{}.value;
    static constexpr index_t KThreadClusterSize = std::tuple_element_t<7, Tuple>{}.value;
    static constexpr index_t MThreadSliceSize   = std::tuple_element_t<8, Tuple>{}.value;
    static constexpr index_t KThreadSliceSize   = std::tuple_element_t<9, Tuple>{}.value;
    static constexpr index_t InSrcVectorDim     = std::tuple_element_t<10, Tuple>{}.value;
    static constexpr index_t InSrcVectorSize    = std::tuple_element_t<11, Tuple>{}.value;
    static constexpr index_t OutDstVectorSize   = std::tuple_element_t<12, Tuple>{}.value;

    /// Host implementation used to produce the expected output.
    using ReferenceInstance =
        tensor_operation::host::ReferenceSoftmax<InDataType, OutDataType, AccDataType>;

    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    /// Device implementation under test, configured from the tuple parameters above.
    using DeviceInstance = tensor_operation::device::DeviceSoftmaxImpl<InDataType,
                                                                       AccDataType,
                                                                       OutDataType,
                                                                       PassThrough,
                                                                       PassThrough,
                                                                       Rank,
                                                                       NumReduceDim,
                                                                       BlockSize,
                                                                       MThreadClusterSize,
                                                                       KThreadClusterSize,
                                                                       MThreadSliceSize,
                                                                       KThreadSliceSize,
                                                                       InSrcVectorDim,
                                                                       InSrcVectorSize,
                                                                       OutDstVectorSize>;

    TestSoftmax() : ref_instance_invoker_(ReferenceInstance{}.MakeInvoker()) {}

    /// Runs the device softmax and the host reference on one input shape and one
    /// (alpha, beta) pair, and records a gtest failure if the results differ.
    ///
    /// The reduced dimensions are the last `NumReduceDim` of the `Rank` dimensions.
    /// If the device instance reports the argument as unsupported, the function returns
    /// without running or checking anything (the case is not reported as a gtest skip).
    ///
    /// @param in_length Lengths of the input tensor; the output tensor uses the same lengths.
    /// @param alpha     First scaling factor forwarded to both implementations.
    /// @param beta      Second scaling factor forwarded to both implementations.
    void RunSingle(const std::vector<index_t>& in_length, AccDataType alpha, AccDataType beta)
    {
        // Reduce over dimensions [Rank - NumReduceDim, Rank).
        std::vector<index_t> reduce_dims(NumReduceDim);
        std::iota(reduce_dims.begin(), reduce_dims.end(), Rank - NumReduceDim);

        Tensor<InDataType> in(in_length);
        Tensor<OutDataType> out(in_length);

        // Both tensors are filled by GeneratorTensor_2 constructed with {-5, 5}.
        in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
        out.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});

        // The reference starts from the same initial output contents as the device run.
        Tensor<OutDataType> out_ref(out);

        DeviceMem in_dev(in.GetMemorySize());
        DeviceMem out_dev(out.GetMemorySize());
        in_dev.ToDevice(in.data());
        // The initial output is uploaded too, presumably because it is read when
        // beta != 0 — confirm against the kernel.
        out_dev.ToDevice(out.data());

        // Bind each accessor result once so begin()/end() always refer to the same object.
        const auto& host_lengths = in.GetLengths();
        const auto& host_strides = in.GetStrides();
        std::vector<index_t> i_in_lengths(host_lengths.begin(), host_lengths.end());
        std::vector<index_t> i_in_strides(host_strides.begin(), host_strides.end());

        auto device_instance = DeviceInstance{};
        auto argument_ptr    = device_instance.MakeArgumentPointer(i_in_lengths,
                                                                i_in_strides,
                                                                reduce_dims,
                                                                &alpha,
                                                                &beta,
                                                                in_dev.GetDeviceBuffer(),
                                                                out_dev.GetDeviceBuffer(),
                                                                PassThrough{},
                                                                PassThrough{});

        if(!device_instance.IsSupportedArgument(argument_ptr.get()))
        {
            // Diagnostic intentionally left disabled:
            // std::cout << "Skipped due to unsupported argument: "
            //           << "input lengths = [" << serialize_range(in_length) << "], "
            //           << "scaler = [" << alpha << ", " << beta << "]." << std::endl;
            return;
        }

        auto invoker_ptr = device_instance.MakeInvokerPointer();
        invoker_ptr->Run(argument_ptr.get());

        ref_instance_invoker_.Run({in, out_ref, alpha, beta, reduce_dims});

        out_dev.FromDevice(out.data());

        bool pass = false;

        if constexpr(std::is_same_v<InDataType, int8_t>)
        {
            // int8 uses an explicit message and the trailing arguments 0, 1
            // (presumably relative / absolute tolerance — confirm against check_err).
            pass = ck::utils::check_err(out, out_ref, "Error: Incorrect results!", 0, 1);
        }
        else
        {
            pass = ck::utils::check_err(out, out_ref);
        }

        EXPECT_TRUE(pass);

        if(!pass)
        {
            // Add the failing configuration to the report and stop this case.
            FAIL() << "Failure in input lengths = [" << serialize_range(in_length) << "], "
                   << "scaler = [" << alpha << ", " << beta << "].";
        }
    }

    /// Runs RunSingle for every combination of `in_lengths_` and `scales_`.
    void Run()
    {
        for(const auto& in_length : this->in_lengths_)
        {
            for(const auto& scale : this->scales_)
            {
                // scale holds {alpha, beta}.
                this->RunSingle(in_length, scale[0], scale[1]);
            }
        }
    }

    /// Input shapes exercised by Run().
    std::vector<std::vector<index_t>> in_lengths_ = {
        {1, 8, 128}, {2, 128, 1024}, {3, 9, 1032}, {4, 4, 2048}, {8, 1, 8192}};
    /// {alpha, beta} pairs exercised by Run().
    std::vector<std::vector<AccDataType>> scales_ = {{1, 0}, {1, 1}, {0, 1}, {2, 2}};

    /// Invoker of the host reference, created once per fixture instance.
    typename ReferenceInstance::Invoker ref_instance_invoker_;
};
} // namespace ck