// layernorm_blockwise.cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <vector>
#include <getopt.h>

#include "ck/ck.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"

#include "ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp"

// Element types used by the device kernel and the host reference.
// x (input), gamma, beta and y (output) are stored as fp16 (ck::half_t);
// ComputeDataType (float) is the compute type passed to both instances.
// PassThrough is the element-wise op applied to the output (presumably an
// identity op, judging by its name — confirm in element_wise_operation.hpp).
using XDataType       = ck::half_t;
using GammaDataType   = ck::half_t;
using BetaDataType    = ck::half_t;
using YDataType       = ck::half_t;
using ComputeDataType = float;
using PassThrough     = ck::tensor_operation::element_wise::PassThrough;

// Tensor rank of x / y, and how many of those dimensions are normalized over.
constexpr int Rank         = 2;
constexpr int NumReduceDim = 1;
static_assert(NumReduceDim >= 1 && NumReduceDim <= Rank, "NumReduceDim must be in [1, Rank]");

using DeviceInstance =
34
35
36
    ck::tensor_operation::device::DeviceNormalizationImpl<XDataType,
                                                          GammaDataType,
                                                          BetaDataType,
rocking5566's avatar
rocking5566 committed
37
                                                          ComputeDataType,
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
                                                          YDataType,
                                                          PassThrough,
                                                          Rank,
                                                          NumReduceDim,
                                                          256, // BlockSize
                                                          8,   // ClusterM
                                                          32,  // ClusterK
                                                          1,   // SliceM
                                                          8,   // SliceK
                                                          1,   // SrcVecDim (0=M, 1=K)
                                                          8,   // SrcScalarPerVector
                                                          1,   // GammaVecDim (0=M, 1=K)
                                                          8,   // GammaScalarPerVector
                                                          1,   // BetaVecDim (0=M, 1=K)
                                                          8,   // BetaScalarPerVector
                                                          8>;  // OutScalarPerVector
rocking5566's avatar
rocking5566 committed
54
55
56
57
58
59
60
61
62
63

int main()
{
    bool time_kernel = false;

    ck::index_t M      = 1024;
    ck::index_t N      = 1024;
    ck::index_t Stride = N;

    auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
64
        return HostTensorDescriptor({len}, {stride});
rocking5566's avatar
rocking5566 committed
65
66
67
    };

    auto f_host_tensor_descriptor2d = [](std::size_t row, std::size_t col, std::size_t stride) {
68
69
70
        using namespace ck::literals;

        return HostTensorDescriptor({row, col}, {stride, 1_uz});
rocking5566's avatar
rocking5566 committed
71
72
73
74
75
76
77
78
79
80
81
    };

    Tensor<XDataType> x(f_host_tensor_descriptor2d(M, N, Stride));
    Tensor<GammaDataType> gamma(f_host_tensor_descriptor1d(N, 1));
    Tensor<BetaDataType> beta(f_host_tensor_descriptor1d(N, 1));
    Tensor<YDataType> y(f_host_tensor_descriptor2d(M, N, Stride));

    x.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
    gamma.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{0.0, 1.0});
    beta.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{0.0, 1.0});

82
83
84
85
    DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize());
    DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize());
    DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize());
    DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize());
rocking5566's avatar
rocking5566 committed
86
87
88
89
90
91
92
93
94

    x_dev.ToDevice(x.mData.data());
    gamma_dev.ToDevice(gamma.mData.data());
    beta_dev.ToDevice(beta.mData.data());

    auto device_instance = DeviceInstance{};
    auto argument_ptr    = device_instance.MakeArgumentPointer(
        {M, N},
        std::vector<ck::index_t>{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()},
rocking5566's avatar
rocking5566 committed
95
96
        {0, 1},
        {0, 1},
rocking5566's avatar
rocking5566 committed
97
        std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()},
rocking5566's avatar
rocking5566 committed
98
99
100
101
102
103
        {1},
        1e-4,
        x_dev.GetDeviceBuffer(),
        gamma_dev.GetDeviceBuffer(),
        beta_dev.GetDeviceBuffer(),
        y_dev.GetDeviceBuffer(),
104
105
        nullptr,
        nullptr,
rocking5566's avatar
rocking5566 committed
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
        PassThrough{});

    if(!device_instance.IsSupportedArgument(argument_ptr.get()))
    {
        std::cout << "The runtime parameters are not supported" << std::endl;
        return 1;
    };

    auto invoker_ptr = device_instance.MakeInvokerPointer();
    invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});

    bool pass = true;
    {
        Tensor<YDataType> host_y(f_host_tensor_descriptor2d(M, N, Stride));
        using ReferenceInstance = ck::tensor_operation::host::ReferenceLayernorm<XDataType,
                                                                                 GammaDataType,
                                                                                 BetaDataType,
                                                                                 YDataType,
rocking5566's avatar
rocking5566 committed
124
                                                                                 ComputeDataType,
rocking5566's avatar
rocking5566 committed
125
126
127
128
129
130
131
132
133
134
135
                                                                                 PassThrough,
                                                                                 Rank,
                                                                                 NumReduceDim>;

        ReferenceInstance ref;
        auto ref_argument =
            ref.MakeArgument(x, gamma, beta, host_y, PassThrough{}, {M, N}, {1}, 1e-4);
        auto ref_invoker = ref.MakeInvoker();
        ref_invoker.Run(ref_argument);

        y_dev.FromDevice(y.mData.data());
136
        pass &= ck::utils::check_err(y, host_y, "Error: Incorrect results d1", 1e-3, 1e-3);
rocking5566's avatar
rocking5566 committed
137
138
139
    }
    return (pass ? 0 : 1);
}