"vscode:/vscode.git/clone" did not exist on "7f3ee861aee6b0b1d23fac1a80abfd18c9d94229"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <getopt.h>

#include "ck/ck.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/tensor_operation/gpu/device/device_layernorm.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_common_util.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp"

// Element types used to instantiate both the device layernorm instance and the
// host reference below. X/Gamma/Beta/Y are half precision; AccDataType is float.
using XDataType     = ck::half_t;
using GammaDataType = ck::half_t;
using BetaDataType  = ck::half_t;
using YDataType     = ck::half_t;
using AccDataType   = float;
// Element-wise operation type passed to both the device and reference instances.
using PassThrough   = ck::tensor_operation::element_wise::PassThrough;

// Compile-time shape parameters handed to the device and reference instances:
// tensor rank and the number of reduced dimensions.
constexpr int Rank{2};
constexpr int NumReduceDim{1};

// Concrete device layernorm instance exercised by main(). The trailing numeric
// arguments are tuning parameters; their names are given in the per-line comments.
using DeviceInstance = ck::tensor_operation::device::DeviceLayernorm<XDataType,
                                                                     GammaDataType,
                                                                     BetaDataType,
                                                                     AccDataType,
                                                                     YDataType,
                                                                     PassThrough,
                                                                     Rank,
                                                                     NumReduceDim,
                                                                     256, // BlockSize
                                                                     8,   // ClusterM
                                                                     32,  // ClusterK
                                                                     1,   // SliceM
                                                                     8,   // SliceK
                                                                     1,   // SrcVecDim (0=M, 1=K)
                                                                     8,   // SrcScalarPerVector
                                                                     8,   // GammaScalarPerVector
                                                                     8,   // BetaScalarPerVector
                                                                     1>;  // OutScalarPerVector

int main()
{
    bool time_kernel = false;

    ck::index_t M      = 1024;
    ck::index_t N      = 1024;
rocking's avatar
rocking committed
57
    ck::index_t Stride = N;
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83

    auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
        return HostTensorDescriptor(std::vector<std::size_t>({len}),
                                    std::vector<std::size_t>({stride}));
    };

    auto f_host_tensor_descriptor2d = [](std::size_t row, std::size_t col, std::size_t stride) {
        return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                    std::vector<std::size_t>({stride, 1}));
    };

    Tensor<XDataType> x(f_host_tensor_descriptor2d(M, N, Stride));
    Tensor<GammaDataType> gamma(f_host_tensor_descriptor1d(N, 1));
    Tensor<BetaDataType> beta(f_host_tensor_descriptor1d(N, 1));
    Tensor<YDataType> y(f_host_tensor_descriptor2d(M, N, Stride));

    x.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
    gamma.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{0.0, 1.0});
    beta.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{0.0, 1.0});

    DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpace());
    DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpace());
    DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpace());
    DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpace());

    x_dev.ToDevice(x.mData.data());
rocking's avatar
rocking committed
84
85
    gamma_dev.ToDevice(gamma.mData.data());
    beta_dev.ToDevice(beta.mData.data());
86
87

    auto device_instance = DeviceInstance{};
rocking's avatar
rocking committed
88
89
90
91
92
93
94
95
96
97
    auto argument_ptr    = device_instance.MakeArgumentPointer(
        {M, N},
        std::vector<ck::index_t>{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()},
        std::vector<ck::index_t>{gamma.mDesc.GetStrides().begin(), gamma.mDesc.GetStrides().end()},
        std::vector<ck::index_t>{beta.mDesc.GetStrides().begin(), beta.mDesc.GetStrides().end()},
        {1},
        1e-4,
        x_dev.GetDeviceBuffer(),
        gamma_dev.GetDeviceBuffer(),
        beta_dev.GetDeviceBuffer(),
rocking's avatar
rocking committed
98
99
        y_dev.GetDeviceBuffer(),
        PassThrough{});
100
101
102
103
104
105
106
107
108
109
110

    if(!device_instance.IsSupportedArgument(argument_ptr.get()))
    {
        std::cout << "The runtime parameters are not supported" << std::endl;
        return 1;
    };

    auto invoker_ptr = device_instance.MakeInvokerPointer();
    invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});

    bool pass = true;
rocking's avatar
rocking committed
111
112
    {
        Tensor<YDataType> host_y(f_host_tensor_descriptor2d(M, N, Stride));
rocking's avatar
rocking committed
113
114
115
116
117
118
119
120
121
122
123
124
125
126
        using ReferenceInstance = ck::tensor_operation::host::ReferenceLayernorm<XDataType,
                                                                                 GammaDataType,
                                                                                 BetaDataType,
                                                                                 YDataType,
                                                                                 AccDataType,
                                                                                 PassThrough,
                                                                                 Rank,
                                                                                 NumReduceDim>;

        ReferenceInstance ref;
        auto ref_argument =
            ref.MakeArgument(x, gamma, beta, host_y, PassThrough{}, {M, N}, {1}, 1e-4);
        auto ref_invoker = ref.MakeInvoker();
        ref_invoker.Run(ref_argument);
rocking's avatar
rocking committed
127
128
129
130
131

        y_dev.FromDevice(y.mData.data());
        pass &=
            ck::utils::check_err(y.mData, host_y.mData, "Error: Incorrect results d1", 1e-3, 1e-3);
    }
132
133
    return (pass ? 0 : 1);
}