test_softmax_util.hpp 5.24 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
#include <iostream>
#include <numeric>
#include <string>
#include <tuple>
#include <vector>

#include "gtest/gtest.h"

#include "config.hpp"
#include "host_tensor.hpp"
#include "check_err.hpp"
#include "number.hpp"
#include "reference_softmax.hpp"
#include "device_softmax.hpp"

namespace ck {

/// Typed GoogleTest fixture that runs the device softmax kernel and compares
/// its output with the host reference implementation.
///
/// `Tuple` supplies, by position:
///   0 InDataType, 1 AccDataType, 2 OutDataType,
///   3 Rank, 4 NumReduceDim, 5 BlockSize,
///   6 MThreadClusterSize, 7 KThreadClusterSize,
///   8 MThreadSliceSize, 9 KThreadSliceSize,
///   10 InSrcVectorDim, 11 InSrcVectorSize, 12 OutDstVectorSize.
/// Elements 3..12 are value-initialized and their `.value` is read as a
/// compile-time constant. NOTE(review): presumably `ck::Number<N>` — confirm
/// against the test instantiations.
template <typename Tuple>
class TestSoftmax : public ::testing::Test
{
    protected:
    using InDataType                            = std::tuple_element_t<0, Tuple>;
    using AccDataType                           = std::tuple_element_t<1, Tuple>;
    using OutDataType                           = std::tuple_element_t<2, Tuple>;
    static constexpr index_t Rank               = std::tuple_element_t<3, Tuple>{}.value;
    static constexpr index_t NumReduceDim       = std::tuple_element_t<4, Tuple>{}.value;
    static constexpr index_t BlockSize          = std::tuple_element_t<5, Tuple>{}.value;
    static constexpr index_t MThreadClusterSize = std::tuple_element_t<6, Tuple>{}.value;
    static constexpr index_t KThreadClusterSize = std::tuple_element_t<7, Tuple>{}.value;
    static constexpr index_t MThreadSliceSize   = std::tuple_element_t<8, Tuple>{}.value;
    static constexpr index_t KThreadSliceSize   = std::tuple_element_t<9, Tuple>{}.value;
    static constexpr index_t InSrcVectorDim     = std::tuple_element_t<10, Tuple>{}.value;
    static constexpr index_t InSrcVectorSize    = std::tuple_element_t<11, Tuple>{}.value;
    static constexpr index_t OutDstVectorSize   = std::tuple_element_t<12, Tuple>{}.value;

    using ReferenceInstance =
        tensor_operation::host::ReferenceSoftmax<InDataType, OutDataType, AccDataType>;

    using DeviceInstance = tensor_operation::device::DeviceSoftmax<InDataType,
                                                                   AccDataType,
                                                                   OutDataType,
                                                                   Rank,
                                                                   NumReduceDim,
                                                                   BlockSize,
                                                                   MThreadClusterSize,
                                                                   KThreadClusterSize,
                                                                   MThreadSliceSize,
                                                                   KThreadSliceSize,
                                                                   InSrcVectorDim,
                                                                   InSrcVectorSize,
                                                                   OutDstVectorSize>;

    TestSoftmax() : ref_instance_invoker_(ReferenceInstance{}.MakeInvoker()) {}

    /// Runs one softmax case on the device and checks it against the host reference.
    ///
    /// @param in_length Tensor shape; the last `NumReduceDim` dimensions are reduced.
    /// @param alpha     Scale forwarded unchanged to both the device and reference ops.
    /// @param beta      Scale forwarded unchanged to both the device and reference ops.
    ///
    /// The output tensor is pre-filled with random values, identical on host and
    /// device. Any dependence on its prior contents (presumably scaled by `beta`)
    /// is therefore compared too. An unsupported configuration is reported as a
    /// test failure, and only this case is abandoned.
    void RunSingle(const std::vector<index_t>& in_length, AccDataType alpha, AccDataType beta)
    {
        // Reduce over the innermost dims: {Rank - NumReduceDim, ..., Rank - 1}.
        std::vector<index_t> reduce_dims(NumReduceDim);
        std::iota(reduce_dims.begin(), reduce_dims.end(), Rank - NumReduceDim);

        Tensor<InDataType> in(in_length);
        Tensor<OutDataType> out(in_length);

        in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
        out.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});

        // The host reference starts from the same initial output as the device.
        Tensor<OutDataType> out_ref(out);

        DeviceMem in_dev(sizeof(InDataType) * in.mDesc.GetElementSpace());
        DeviceMem out_dev(sizeof(OutDataType) * out.mDesc.GetElementSpace());
        in_dev.ToDevice(in.mData.data());
        out_dev.ToDevice(out.mData.data());

        // Bind each accessor result once so begin()/end() always come from the
        // same container (a const& also safely extends a by-value return).
        const auto& lengths = in.mDesc.GetLengths();
        const auto& strides = in.mDesc.GetStrides();
        std::vector<index_t> i_in_lengths(lengths.begin(), lengths.end());
        std::vector<index_t> i_in_strides(strides.begin(), strides.end());

        auto device_instance = DeviceInstance{};
        auto argument_ptr    = device_instance.MakeArgumentPointer(i_in_lengths,
                                                                i_in_strides,
                                                                reduce_dims,
                                                                alpha,
                                                                beta,
                                                                in_dev.GetDeviceBuffer(),
                                                                out_dev.GetDeviceBuffer());

        if(!device_instance.IsSupportedArgument(argument_ptr.get()))
        {
            FAIL() << "Unsupported argument" << DescribeCase(in_length, alpha, beta);
        }

        auto invoker_ptr = device_instance.MakeInvokerPointer();
        invoker_ptr->Run(argument_ptr.get());

        ref_instance_invoker_.Run({in, out_ref, alpha, beta, Rank, reduce_dims});

        out_dev.FromDevice(out.mData.data());
        EXPECT_TRUE(ck::utils::check_err(out.mData, out_ref.mData))
            << DescribeCase(in_length, alpha, beta);
    }

    /// Calls RunSingle for every combination of `in_lengths_` and `scales_`.
    void Run()
    {
        // Iterate by const reference: the shapes are vectors and need not be copied.
        for(const auto& in_length : this->in_lengths_)
        {
            for(const auto& scale : this->scales_)
            {
                this->RunSingle(in_length, std::get<0>(scale), std::get<1>(scale));
            }
        }
    }

    /// Shapes swept by Run(). Each is expected to have `Rank` entries, because
    /// `Rank` is passed to the kernel alongside these lengths.
    std::vector<std::vector<index_t>> in_lengths_ = {{1, 8, 128}, {2, 128, 1024}, {3, 9, 1032}};
    /// (alpha, beta) pairs swept by Run().
    std::vector<std::tuple<AccDataType, AccDataType>> scales_ = {{1, 0}, {2, 2}, {0, 1}};

    typename ReferenceInstance::Invoker ref_instance_invoker_;

    private:
    /// Formats one test case, e.g. " [in_length={1, 8, 128}, alpha=1.000000, beta=0.000000]",
    /// so a failure message identifies which of the swept cases broke.
    static std::string
    DescribeCase(const std::vector<index_t>& in_length, AccDataType alpha, AccDataType beta)
    {
        std::string desc = " [in_length={";
        for(std::size_t i = 0; i < in_length.size(); ++i)
        {
            if(i != 0)
            {
                desc += ", ";
            }
            desc += std::to_string(in_length[i]);
        }
        desc += "}, alpha=" + std::to_string(static_cast<double>(alpha)) +
                ", beta=" + std::to_string(static_cast<double>(beta)) + "]";
        return desc;
    }
};
} // namespace ck