// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#include "common.hpp"

// Tensor element types for this fp16 example: X, Gamma, Beta and Y are all
// stored as ck::half_t.
using XDataType     = ck::half_t;
using GammaDataType = ck::half_t;
using BetaDataType  = ck::half_t;
using YDataType     = ck::half_t;

// float is used for the compute type and for the saved mean / inverse-std
// outputs (both are passed to DeviceInstance further down this file).
using ComputeDataType        = float;
using SaveMeanInvStdDataType = float;

// Elementwise operation handed to DeviceInstance.
// NOTE(review): presumably applied to Y as an identity op — confirm in CK.
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// Not referenced elsewhere in this file; presumably tested with #ifdef inside
// run_layernorm_example.inc (included below) to enable writing the saved
// mean / inverse standard deviation. NOTE(review): confirm in the .inc file.
#define SAVE_MEAN_INV_STD

// Tensor rank and number of normalized (reduced) dimensions for this example:
// 2-D tensors with 1 dimension reduced.
// NOTE(review): presumably the reduced dimension is the "K" dimension that the
// tuning parameters below refer to — confirm against the example driver.
constexpr int Rank         = 2;
constexpr int NumReduceDim = 1;

// Device normalization instance that main() hands to the example driver.
//
// Relationships visible from the tuning values below:
//  - BlockSize (256) == ClusterM * ClusterK (8 * 32).
//  - Per-thread slice SliceM x SliceK is 1 x 8.
//  - X/Y, Gamma and Beta all use vector dim 1 (K) with 8 elements per vector,
//    equal to SliceK.
//  - SaveMeanInvStdScalarPerVector is 1.
// The argument order is positional and must match the CK device-op template.
using DeviceInstance =
    ck::tensor_operation::device::DeviceNormalizationImpl<XDataType,
                                                          GammaDataType,
                                                          BetaDataType,
                                                          ComputeDataType,
                                                          YDataType,
                                                          SaveMeanInvStdDataType,
                                                          PassThrough,
                                                          Rank,
                                                          NumReduceDim,
                                                          256, // BlockSize
                                                          8,   // ClusterM
                                                          32,  // ClusterK
                                                          1,   // SliceM
                                                          8,   // SliceK
                                                          1,   // XYVectorDim (0=M, 1=K)
                                                          8,   // SrcScalarPerVector
                                                          1,   // GammaVecDim (0=M, 1=K)
                                                          8,   // GammaScalarPerVector
                                                          1,   // BetaVecDim (0=M, 1=K)
                                                          8,   // BetaScalarPerVector
                                                          8,   // YScalarPerVector
                                                          1>;  // SaveMeanInvStdScalarPerVector
#include "run_layernorm_example.inc"

// Entry point: runs the example driver with the fp16 DeviceInstance configured
// above and returns the driver's result as the process exit code.
// NOTE(review): the driver is named run_groupnorm_example although this is a
// layernorm example; it presumably comes from run_layernorm_example.inc, so any
// rename belongs there — the call is kept as-is here.
int main()
{
    const int exit_code = run_groupnorm_example<DeviceInstance>();
    return exit_code;
}