"vscode:/vscode.git/clone" did not exist on "753724d867f0938dd8d020d51410e0393bfacf11"
layernorm_fp16.cpp 2.07 KB
Newer Older
rocking's avatar
rocking committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include "common.hpp"

// Element types of the four tensors involved: input X, gamma, beta and
// output Y are all stored in half precision.
using XDataType     = ck::half_t;
using GammaDataType = ck::half_t;
using BetaDataType  = ck::half_t;
using YDataType     = ck::half_t;

// Type passed to the device instance as its ComputeDataType (single precision).
using ComputeDataType = float;

// Elementwise operator passed to the device instance.
// NOTE(review): by its name this presumably leaves values unchanged — confirm
// in ck's element_wise headers.
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// Problem shape: a rank-2 tensor with 1 dimension normalized over.
// NOTE(review): presumably the innermost dimension is the reduced one — confirm
// against run_layernorm_example.inc.
constexpr int Rank{2};
constexpr int NumReduceDim{1};

// Compile-time tuning values for the fp16 normalization instance below. They
// live in their own namespace so these short names cannot clash with anything
// declared by common.hpp or run_layernorm_example.inc.
namespace layernorm_fp16_tuning {

constexpr int BlockSize = 256;

// Thread cluster shape; ClusterM x ClusterK = 8 x 32 = 256 = BlockSize.
constexpr int ClusterM = 8;
constexpr int ClusterK = 32;

// Per-thread slice shape.
constexpr int SliceM = 1;
constexpr int SliceK = 8;

// Vector access configuration. The *Dim values select the vectorized
// dimension (0 = M, 1 = K); the *ScalarPerVector values give the vector width.
constexpr int XYVectorDim          = 1;
constexpr int SrcScalarPerVector   = 8;
constexpr int GammaVecDim          = 1;
constexpr int GammaScalarPerVector = 8;
constexpr int BetaVecDim           = 1;
constexpr int BetaScalarPerVector  = 8;
constexpr int OutScalarPerVector   = 8;

} // namespace layernorm_fp16_tuning

// Normalization device instance built from the element types, the problem
// shape (Rank / NumReduceDim), the PassThrough elementwise operator and the
// tuning values declared above.
using DeviceInstance = ck::tensor_operation::device::DeviceNormalizationImpl<
    XDataType,
    GammaDataType,
    BetaDataType,
    ComputeDataType,
    YDataType,
    PassThrough,
    Rank,
    NumReduceDim,
    layernorm_fp16_tuning::BlockSize,
    layernorm_fp16_tuning::ClusterM,
    layernorm_fp16_tuning::ClusterK,
    layernorm_fp16_tuning::SliceM,
    layernorm_fp16_tuning::SliceK,
    layernorm_fp16_tuning::XYVectorDim,
    layernorm_fp16_tuning::SrcScalarPerVector,
    layernorm_fp16_tuning::GammaVecDim,
    layernorm_fp16_tuning::GammaScalarPerVector,
    layernorm_fp16_tuning::BetaVecDim,
    layernorm_fp16_tuning::BetaScalarPerVector,
    layernorm_fp16_tuning::OutScalarPerVector>;
#include "run_layernorm_example.inc"

// Entry point: runs the example harness pulled in from
// run_layernorm_example.inc with DeviceInstance and returns its result as the
// process exit status.
// NOTE(review): the harness is named run_groupnorm_example although this is a
// layernorm example. The call must match whatever the .inc actually declares,
// so confirm there before renaming either side.
int main()
{
    const int exit_status = run_groupnorm_example<DeviceInstance>();
    return exit_status;
}