test_layernorm2d_util.hpp 8.42 KB
Newer Older
rocking5566's avatar
rocking5566 committed
1
2
3
4
5
6
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>
#include <string>
#include <tuple>
#include <type_traits>
7
8
#include <vector>

rocking5566's avatar
rocking5566 committed
9
10
11
12
#include <gtest/gtest.h>

#include "ck/ck.hpp"
#include "ck/utility/number.hpp"
rocking5566's avatar
rocking5566 committed
13
#include "ck/tensor_operation/gpu/device/device_layernorm_impl.hpp"
rocking5566's avatar
rocking5566 committed
14

15
#include "ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp"
rocking5566's avatar
rocking5566 committed
16
#include "ck/library/utility/check_err.hpp"
17
#include "ck/library/utility/device_memory.hpp"
18
19
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/ranges.hpp"
rocking5566's avatar
rocking5566 committed
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35

namespace ck {

// Formats every element of `range` with operator<< and joins them with ", ".
//
// The separator is written *between* elements rather than appended after each
// one and trimmed afterwards, so an empty range simply yields an empty string
// (trimming two characters from an empty string would form an invalid
// iterator range).
template <typename Range>
std::string serialize_range(const Range& range)
{
    std::stringstream ss;
    bool is_first = true;
    for(const auto& element : range)
    {
        if(!is_first)
        {
            ss << ", ";
        }
        ss << element;
        is_first = false;
    }
    return ss.str();
}

template <typename Tuple>
rocking5566's avatar
rocking5566 committed
36
class TestLayernorm2d : public ::testing::Test
rocking5566's avatar
rocking5566 committed
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
{
    protected:
    using XDataType                             = std::tuple_element_t<0, Tuple>;
    using GammaDataType                         = std::tuple_element_t<1, Tuple>;
    using BetaDataType                          = std::tuple_element_t<2, Tuple>;
    using AccDataType                           = std::tuple_element_t<3, Tuple>;
    using YDataType                             = std::tuple_element_t<4, Tuple>;
    static constexpr index_t Rank               = std::tuple_element_t<5, Tuple>{}.value;
    static constexpr index_t NumReduceDim       = std::tuple_element_t<6, Tuple>{}.value;
    static constexpr index_t BlockSize          = std::tuple_element_t<7, Tuple>{}.value;
    static constexpr index_t MThreadClusterSize = std::tuple_element_t<8, Tuple>{}.value;
    static constexpr index_t KThreadClusterSize = std::tuple_element_t<9, Tuple>{}.value;
    static constexpr index_t MThreadSliceSize   = std::tuple_element_t<10, Tuple>{}.value;
    static constexpr index_t KThreadSliceSize   = std::tuple_element_t<11, Tuple>{}.value;
    static constexpr index_t XYSrcVectorDim     = std::tuple_element_t<12, Tuple>{}.value;
    static constexpr index_t XSrcVectorSize     = std::tuple_element_t<13, Tuple>{}.value;
rocking5566's avatar
rocking5566 committed
53
54
55
56
57
    static constexpr index_t GammaSrcVectorDim  = std::tuple_element_t<14, Tuple>{}.value;
    static constexpr index_t GammaSrcVectorSize = std::tuple_element_t<15, Tuple>{}.value;
    static constexpr index_t BetaSrcVectorDim   = std::tuple_element_t<16, Tuple>{}.value;
    static constexpr index_t BetaSrcVectorSize  = std::tuple_element_t<17, Tuple>{}.value;
    static constexpr index_t YDstVectorSize     = std::tuple_element_t<18, Tuple>{}.value;
rocking5566's avatar
rocking5566 committed
58
59
60
61
62
63
64
65
66
67
68
69

    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    using ReferenceInstance = tensor_operation::host::ReferenceLayernorm<XDataType,
                                                                         GammaDataType,
                                                                         BetaDataType,
                                                                         YDataType,
                                                                         AccDataType,
                                                                         PassThrough,
                                                                         Rank,
                                                                         NumReduceDim>;

rocking5566's avatar
rocking5566 committed
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
    using DeviceInstance = tensor_operation::device::DeviceLayernormImpl<XDataType,
                                                                         GammaDataType,
                                                                         BetaDataType,
                                                                         AccDataType,
                                                                         YDataType,
                                                                         PassThrough,
                                                                         Rank,
                                                                         NumReduceDim,
                                                                         BlockSize,
                                                                         MThreadClusterSize,
                                                                         KThreadClusterSize,
                                                                         MThreadSliceSize,
                                                                         KThreadSliceSize,
                                                                         XYSrcVectorDim,
                                                                         XSrcVectorSize,
rocking5566's avatar
rocking5566 committed
85
                                                                         GammaSrcVectorDim,
rocking5566's avatar
rocking5566 committed
86
                                                                         GammaSrcVectorSize,
rocking5566's avatar
rocking5566 committed
87
                                                                         BetaSrcVectorDim,
rocking5566's avatar
rocking5566 committed
88
89
                                                                         BetaSrcVectorSize,
                                                                         YDstVectorSize>;
rocking5566's avatar
rocking5566 committed
90

rocking5566's avatar
rocking5566 committed
91
    TestLayernorm2d() : ref_instance_invoker_(ReferenceInstance{}.MakeInvoker()) {}
rocking5566's avatar
rocking5566 committed
92

rocking5566's avatar
rocking5566 committed
93
94
95
96
97
98
    void RunSingle(const std::vector<index_t>& lengths,
                   const std::vector<index_t>& reduceDims,
                   const std::vector<index_t>& GammaLength,
                   const std::vector<index_t>& GammaStride,
                   const std::vector<index_t>& BetaLength,
                   const std::vector<index_t>& BetaStride)
rocking5566's avatar
rocking5566 committed
99
100
    {
        Tensor<XDataType> x(lengths);
rocking5566's avatar
rocking5566 committed
101
102
        Tensor<GammaDataType> gamma(GammaLength);
        Tensor<BetaDataType> beta(BetaLength);
rocking5566's avatar
rocking5566 committed
103
104
105
106
107
108
109
        Tensor<YDataType> y(lengths);
        Tensor<YDataType> y_ref(lengths);

        x.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
        gamma.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{0.0, 1.0});
        beta.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{0.0, 1.0});

110
111
112
113
114
115
116
117
        DeviceMem x_dev(x.GetMemorySize());
        DeviceMem gamma_dev(gamma.GetMemorySize());
        DeviceMem beta_dev(beta.GetMemorySize());
        DeviceMem y_dev(y.GetMemorySize());

        x_dev.ToDevice(x.data());
        gamma_dev.ToDevice(gamma.data());
        beta_dev.ToDevice(beta.data());
rocking5566's avatar
rocking5566 committed
118

119
        using Indices = std::vector<ck::index_t>;
rocking5566's avatar
rocking5566 committed
120
121

        auto device_instance = DeviceInstance{};
122
123
124
125
126
127
128
129
130
131
132
133
134
        auto argument_ptr =
            device_instance.MakeArgumentPointer(lengths,
                                                ck::ranges::to<Indices>(x.GetStrides()),
                                                GammaStride,
                                                BetaStride,
                                                ck::ranges::to<Indices>(y.GetStrides()),
                                                reduceDims,
                                                1e-4,
                                                x_dev.GetDeviceBuffer(),
                                                gamma_dev.GetDeviceBuffer(),
                                                beta_dev.GetDeviceBuffer(),
                                                y_dev.GetDeviceBuffer(),
                                                PassThrough{});
rocking5566's avatar
rocking5566 committed
135
136
137
138
139
140
141
142
143
144
145
146

        if(!device_instance.IsSupportedArgument(argument_ptr.get()))
        {
            return;
        }

        auto invoker_ptr = device_instance.MakeInvokerPointer();
        invoker_ptr->Run(argument_ptr.get());

        ref_instance_invoker_.Run(
            {x, gamma, beta, y_ref, PassThrough{}, lengths, reduceDims, 1e-4});

147
        y_dev.FromDevice(y.data());
rocking5566's avatar
rocking5566 committed
148
149
150

        bool pass;

151
        if constexpr(std::is_same_v<XDataType, int8_t>)
rocking5566's avatar
rocking5566 committed
152
        {
153
            EXPECT_TRUE(pass = ck::utils::check_err(y, y_ref, "Error: Incorrect results!", 0, 1));
rocking5566's avatar
rocking5566 committed
154
155
156
        }
        else
        {
157
158
            EXPECT_TRUE(
                pass = ck::utils::check_err(y, y_ref, "Error: Incorrect results d1", 1e-3, 1e-3));
rocking5566's avatar
rocking5566 committed
159
160
161
162
163
164
165
166
167
168
169
        }

        if(!pass)
        {
            FAIL() << "Failure in input lengths = [" << serialize_range(lengths) << "], "
                   << "reduce dim = [" << serialize_range(reduceDims) << "].";
        }
    }

    void Run()
    {
rocking5566's avatar
rocking5566 committed
170
171
172
173
        std::vector<std::vector<index_t>> lengths = {
            {4, 256}, {8, 511}, {9, 1032}, {4, 2048}, {1, 8192}, {4000, 2000}};

        for(auto length : lengths)
rocking5566's avatar
rocking5566 committed
174
        {
rocking5566's avatar
rocking5566 committed
175
            this->RunSingle(length, {1}, {length[1]}, {0, 1}, {length[1]}, {0, 1});
rocking5566's avatar
rocking5566 committed
176
177
178
179
180
        }
    }

    typename ReferenceInstance::Invoker ref_instance_invoker_;
};
rocking5566's avatar
rocking5566 committed
181

rocking5566's avatar
rocking5566 committed
182
} // namespace ck