Commit 4e14a894 authored by rocking

refactor api

parent 8c3d43cf
......@@ -9,7 +9,8 @@ target_sources(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE layernorm2d_fwd_api.cpp ${INST
set(EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS)
# list(APPEND EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
# NOTE: we turn off undefined-func-template so the source compiles without explicitly declared function specializations
list(APPEND EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
target_compile_options(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS})
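For context, clang's -Wundefined-func-template fires when a translation unit calls a template function whose definition is not visible and no explicit instantiation declaration is in scope. A minimal sketch of the pattern this example relies on (illustrative names, not from this repo):

// api.hpp: declaration only; the definition lives in exactly one .cpp
template <typename T> float run(T x);

// instances.cpp: definition plus explicit instantiation definition
template <typename T> float run(T x) { return static_cast<float>(x); }
template float run<float>(float);

// caller.cpp: calls run<float>(1.0f) without ever seeing the definition.
// clang warns under -Wundefined-func-template here unless we either add
// "extern template float run<float>(float);" or suppress the warning,
// which is what the compile option above does.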
......
......@@ -127,10 +127,6 @@ int main(int argc, char* argv[])
{
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
}
if(data_type == "fp32")
{
return run<float>(arg_parser) ? 0 : -2;
}
return -3;
}
......@@ -3,26 +3,74 @@
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "layernorm2d_fwd.hpp"
#include "layernorm_dispatch.hpp"
// clang-format off
// template float run_layernorm<ck_tile::fp16_t, 1, 16, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 1, 32, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 1, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// template float run_layernorm<ck_tile::fp16_t, 1, 32, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 1, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 2, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// template float run_layernorm<ck_tile::fp16_t, 1, 64, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 2, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 4, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// template float run_layernorm<ck_tile::fp16_t, 2, 64, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 4, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 8, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// template float run_layernorm<ck_tile::fp16_t, 4, 64, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 8, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 16, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// template float run_layernorm<ck_tile::fp16_t, 4, 64, 8, false, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 8, 64, 4, false, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 16, 64, 2, false, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// clang-format on
template <typename Traits_>
float layernorm2d_fwd_(const ck_tile::stream_config& s, layernorm2d_fwd_args a)
{
using DataType = typename Traits_::DataType;
using PipelineProblem =
ck_tile::BlockLayernorm2dFwdProblem<typename LayerNormTypeConfig<DataType>::XDataType,
typename LayerNormTypeConfig<DataType>::GammaDataType,
typename LayerNormTypeConfig<DataType>::BetaDataType,
typename LayerNormTypeConfig<DataType>::ComputeDataType,
typename LayerNormTypeConfig<DataType>::YDataType,
typename LayerNormTypeConfig<DataType>::MeanDataType,
typename LayerNormTypeConfig<DataType>::InvStdDataType,
typename Traits_::Shape,
Traits_::kPadN,
Traits_::kSaveMeanInvStd,
Traits_::kTwoPass>;
using Kernel = ck_tile::Layernorm2dFwd<PipelineProblem>;
const dim3 grids = Kernel::GridSize(a.M);
constexpr dim3 blocks = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;
return ck_tile::launch_kernel(s,
ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{},
grids,
blocks,
0,
a.p_x,
a.p_gamma,
a.p_beta,
a.p_y,
a.p_mean,
a.p_invStd,
a.epsilon,
a.M,
a.N));
}
template <ck_tile::index_t NRepeat,
ck_tile::index_t NThread,
ck_tile::index_t VectorAccessSize,
bool kTwoPass>
using t = layernorm2d_fwd_traits_<ck_tile::fp16_t,
NRepeat,
NThread,
VectorAccessSize,
false,
false,
kTwoPass>;
using S = ck_tile::stream_config;
using A = layernorm2d_fwd_args;
// Disable all vectorized 8x fp16 read/write instances: they currently have a compiler-related performance issue
// template float layernorm2d_fwd_<t<1, 16, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<1, 32, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<1, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<2, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<4, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<4, 64, 8, true>>(const S&, A);
template float layernorm2d_fwd_<t<1, 32, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<1, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<2, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<4, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<8, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<8, 64, 4, true>>(const S&, A);
......@@ -3,26 +3,84 @@
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "layernorm2d_fwd.hpp"
#include "layernorm_dispatch.hpp"
// clang-format off
// template float run_layernorm<ck_tile::fp16_t, 1, 16, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 1, 32, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 1, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// template float run_layernorm<ck_tile::fp16_t, 1, 32, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 1, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 2, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// template float run_layernorm<ck_tile::fp16_t, 1, 64, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 2, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 4, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// template float run_layernorm<ck_tile::fp16_t, 2, 64, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 4, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 8, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// template float run_layernorm<ck_tile::fp16_t, 4, 64, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 8, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 16, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// template float run_layernorm<ck_tile::fp16_t, 4, 64, 8, true, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 8, 64, 4, true, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 16, 64, 2, true, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// clang-format on
template <typename Traits_>
float layernorm2d_fwd_(const ck_tile::stream_config& s, layernorm2d_fwd_args a)
{
using DataType = typename Traits_::DataType;
using PipelineProblem =
ck_tile::BlockLayernorm2dFwdProblem<typename LayerNormTypeConfig<DataType>::XDataType,
typename LayerNormTypeConfig<DataType>::GammaDataType,
typename LayerNormTypeConfig<DataType>::BetaDataType,
typename LayerNormTypeConfig<DataType>::ComputeDataType,
typename LayerNormTypeConfig<DataType>::YDataType,
typename LayerNormTypeConfig<DataType>::MeanDataType,
typename LayerNormTypeConfig<DataType>::InvStdDataType,
typename Traits_::Shape,
Traits_::kPadN,
Traits_::kSaveMeanInvStd,
Traits_::kTwoPass>;
using Kernel = ck_tile::Layernorm2dFwd<PipelineProblem>;
const dim3 grids = Kernel::GridSize(a.M);
constexpr dim3 blocks = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;
return ck_tile::launch_kernel(s,
ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{},
grids,
blocks,
0,
a.p_x,
a.p_gamma,
a.p_beta,
a.p_y,
a.p_mean,
a.p_invStd,
a.epsilon,
a.M,
a.N));
}
template <ck_tile::index_t NRepeat,
ck_tile::index_t NThread,
ck_tile::index_t VectorAccessSize,
bool kTwoPass>
using t = layernorm2d_fwd_traits_<ck_tile::fp16_t,
NRepeat,
NThread,
VectorAccessSize,
true,
false,
kTwoPass>;
using S = ck_tile::stream_config;
using A = layernorm2d_fwd_args;
// Disable all vectorized 8x fp16 read/write instances: they currently have a compiler-related performance issue
// template float layernorm2d_fwd_<t<1, 16, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<1, 32, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<1, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<2, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<4, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<4, 64, 8, true>>(const S&, A);
template float layernorm2d_fwd_<t<1, 32, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<1, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<2, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<4, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<8, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<8, 64, 4, true>>(const S&, A);
template float layernorm2d_fwd_<t<1, 64, 2, false>>(const S&, A);
template float layernorm2d_fwd_<t<2, 64, 2, false>>(const S&, A);
template float layernorm2d_fwd_<t<4, 64, 2, false>>(const S&, A);
template float layernorm2d_fwd_<t<8, 64, 2, false>>(const S&, A);
template float layernorm2d_fwd_<t<16, 64, 2, false>>(const S&, A);
template float layernorm2d_fwd_<t<16, 64, 2, true>>(const S&, A);
template float layernorm2d_fwd_<t<32, 64, 1, false>>(const S&, A);
template float layernorm2d_fwd_<t<32, 64, 1, true>>(const S&, A);
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "layernorm_dispatch.hpp"
// clang-format off
#ifdef CK_TILE_LAYERNORM2D_FWD_FP32_DEFAULT
template float run_layernorm<float, 1, 32, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 1, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 1, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 2, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 2, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 4, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 4, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 8, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 8, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 16, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 8, 64, 4, false, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 16, 64, 2, false, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 1, 32, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 1, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 1, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 2, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 2, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 4, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 4, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 8, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 8, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 16, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 8, 64, 4, true, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 16, 64, 2, true, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
#endif
// clang-format on
......@@ -49,6 +49,38 @@ struct layernorm2d_fwd_args
ck_tile::index_t N;
};
// this is used to pattern-match the internal kernel implementation, not to instantiate the kernel
template <typename DataType_,
ck_tile::index_t NRepeat,
ck_tile::index_t NThread,
ck_tile::index_t VectorAccessSize,
bool kPadN_,
bool kSaveMeanInvStd_,
bool kTwoPass_>
struct layernorm2d_fwd_traits_
{
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr ck_tile::index_t MRepeat = 1;
static_assert(NThread <= 64, "We only support intra-wave reduction");
static constexpr ck_tile::index_t WaveNum = NThread / 16;
using thread_tile = ck_tile::sequence<MRepeat, NRepeat, VectorAccessSize>;
using warp_tile =
ck_tile::sequence<MRepeat * 64 / NThread, NRepeat * NThread * VectorAccessSize>;
using block_tile =
ck_tile::sequence<MRepeat * WaveNum * 64 / NThread, NRepeat * NThread * VectorAccessSize>;
using Shape = ck_tile::TileLayernorm2dShape<thread_tile, warp_tile, block_tile>;
static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
static constexpr bool kTwoPass = kTwoPass_;
};
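As a worked example of the shape arithmetic above (a sketch; the numbers follow directly from the constants in layernorm2d_fwd_traits_): for the fp16 instance with NRepeat = 2, NThread = 64, VectorAccessSize = 4, which the dispatcher selects for N <= 512, we get MRepeat = 1 and WaveNum = 64 / 16 = 4, hence

thread_tile = sequence<1, 2, 4>   // each thread handles 2 x 4 = 8 elements of one row
warp_tile   = sequence<1, 512>    // 1*64/64 rows, 2*64*4 columns per wave
block_tile  = sequence<4, 512>    // 4 waves -> 4 rows of up to 512 elements per block

so a single one-pass launch covers rows up to 512 elements wide; wider rows go to larger NRepeat or, beyond 2048, to the kTwoPass variants.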
template <typename Traits_>
float layernorm2d_fwd_(const ck_tile::stream_config& s, layernorm2d_fwd_args a);
// This is the public API; it will be generated by a script
struct layernorm2d_fwd_traits
{
......
......@@ -2,7 +2,16 @@
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "layernorm_dispatch.hpp"
#include "layernorm2d_fwd.hpp"
template <typename DataType,
ck_tile::index_t NRepeat,
ck_tile::index_t NThread,
ck_tile::index_t VectorAccessSize,
bool kPadN,
bool kTwoPass = false>
using trait_ =
layernorm2d_fwd_traits_<DataType, NRepeat, NThread, VectorAccessSize, kPadN, false, kTwoPass>;
float layernorm2d_fwd(layernorm2d_fwd_traits t,
layernorm2d_fwd_args a,
......@@ -11,182 +20,79 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
float r = -1;
if(t.data_type.compare("fp16") == 0)
{
// Disable all vectorized 8x fp16 read/write instances: they currently have a
// compiler-related performance issue
#if 0
if(a.N % 8 == 0)
{
if(a.N <= 128)
{
return a.N == 128 ? run_layernorm<ck_tile::fp16_t, 1, 16, 8, false>(a, s)
: run_layernorm<ck_tile::fp16_t, 1, 16, 8, true>(a, s);
}
else if(a.N <= 256)
{
return a.N == 256 ? run_layernorm<ck_tile::fp16_t, 1, 32, 8, false>(a, s)
: run_layernorm<ck_tile::fp16_t, 1, 32, 8, true>(a, s);
}
else if(a.N <= 512)
{
return a.N == 512 ? run_layernorm<ck_tile::fp16_t, 1, 64, 8, false>(a, s)
: run_layernorm<ck_tile::fp16_t, 1, 64, 8, true>(a, s);
}
else if(a.N <= 1024)
{
return a.N == 1024 ? run_layernorm<ck_tile::fp16_t, 2, 64, 8, false>(a, s)
: run_layernorm<ck_tile::fp16_t, 2, 64, 8, true>(a, s);
}
else
{
return a.N == 2048 ? run_layernorm<ck_tile::fp16_t, 4, 64, 8, false>(a, s)
: run_layernorm<ck_tile::fp16_t, 4, 64, 8, true>(a, s);
}
}
else if(a.N % 4 == 0)
#endif
if(a.N % 4 == 0)
{
if(a.N <= 128)
{
return a.N == 128 ? run_layernorm<ck_tile::fp16_t, 1, 32, 4, false>(a, s)
: run_layernorm<ck_tile::fp16_t, 1, 32, 4, true>(a, s);
}
else if(a.N <= 256)
{
return a.N == 256 ? run_layernorm<ck_tile::fp16_t, 1, 64, 4, false>(a, s)
: run_layernorm<ck_tile::fp16_t, 1, 64, 4, true>(a, s);
}
else if(a.N <= 512)
{
return a.N == 512 ? run_layernorm<ck_tile::fp16_t, 2, 64, 4, false>(a, s)
: run_layernorm<ck_tile::fp16_t, 2, 64, 4, true>(a, s);
}
else if(a.N <= 1024)
{
return a.N == 1024 ? run_layernorm<ck_tile::fp16_t, 4, 64, 4, false>(a, s)
: run_layernorm<ck_tile::fp16_t, 4, 64, 4, true>(a, s);
}
else if(a.N <= 2048)
{
return a.N == 2048 ? run_layernorm<ck_tile::fp16_t, 8, 64, 4, false>(a, s)
: run_layernorm<ck_tile::fp16_t, 8, 64, 4, true>(a, s);
}
else
{
return a.N % 2048 == 0 ? run_layernorm<ck_tile::fp16_t, 8, 64, 4, false, true>(a, s)
: run_layernorm<ck_tile::fp16_t, 8, 64, 4, true, true>(a, s);
}
}
else if(a.N % 2 == 0)
{
if(a.N <= 128)
{
return a.N == 128 ? run_layernorm<ck_tile::fp16_t, 1, 64, 2, false>(a, s)
: run_layernorm<ck_tile::fp16_t, 1, 64, 2, true>(a, s);
return a.N == 128 ? layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 32, 4, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 32, 4, true>>(s, a);
}
else if(a.N <= 256)
{
return a.N == 256 ? run_layernorm<ck_tile::fp16_t, 2, 64, 2, false>(a, s)
: run_layernorm<ck_tile::fp16_t, 2, 64, 2, true>(a, s);
return a.N == 256 ? layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 64, 4, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 64, 4, true>>(s, a);
}
else if(a.N <= 512)
{
return a.N == 512 ? run_layernorm<ck_tile::fp16_t, 4, 64, 2, false>(a, s)
: run_layernorm<ck_tile::fp16_t, 4, 64, 2, true>(a, s);
return a.N == 512 ? layernorm2d_fwd_<trait_<ck_tile::fp16_t, 2, 64, 4, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::fp16_t, 2, 64, 4, true>>(s, a);
}
else if(a.N <= 1024)
{
return a.N == 1024 ? run_layernorm<ck_tile::fp16_t, 8, 64, 2, false>(a, s)
: run_layernorm<ck_tile::fp16_t, 8, 64, 2, true>(a, s);
return a.N == 1024
? layernorm2d_fwd_<trait_<ck_tile::fp16_t, 4, 64, 4, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::fp16_t, 4, 64, 4, true>>(s, a);
}
else if(a.N <= 2048)
{
return a.N == 2048 ? run_layernorm<ck_tile::fp16_t, 16, 64, 2, false>(a, s)
: run_layernorm<ck_tile::fp16_t, 16, 64, 2, true>(a, s);
return a.N == 2048
? layernorm2d_fwd_<trait_<ck_tile::fp16_t, 8, 64, 4, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::fp16_t, 8, 64, 4, true>>(s, a);
}
else
{
return a.N % 2048 == 0
? run_layernorm<ck_tile::fp16_t, 16, 64, 2, false, true>(a, s)
: run_layernorm<ck_tile::fp16_t, 16, 64, 2, true, true>(a, s);
? layernorm2d_fwd_<trait_<ck_tile::fp16_t, 8, 64, 4, false, true>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::fp16_t, 8, 64, 4, true, true>>(s, a);
}
}
}
#ifdef CK_TILE_LAYERNORM2D_FWD_FP32_DEFAULT
else if(t.data_type.compare("fp32") == 0)
{
if(a.N % 4 == 0)
else if(a.N % 2 == 0)
{
if(a.N <= 128)
{
return a.N == 128 ? run_layernorm<float, 1, 32, 4, false>(a, s)
: run_layernorm<float, 1, 32, 4, true>(a, s);
return layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 64, 2, true>>(s, a);
}
else if(a.N <= 256)
{
return a.N == 256 ? run_layernorm<float, 1, 64, 4, false>(a, s)
: run_layernorm<float, 1, 64, 4, true>(a, s);
return layernorm2d_fwd_<trait_<ck_tile::fp16_t, 2, 64, 2, true>>(s, a);
}
else if(a.N <= 512)
{
return a.N == 512 ? run_layernorm<float, 2, 64, 4, false>(a, s)
: run_layernorm<float, 2, 64, 4, true>(a, s);
return layernorm2d_fwd_<trait_<ck_tile::fp16_t, 4, 64, 2, true>>(s, a);
}
else if(a.N <= 1024)
{
return a.N == 1024 ? run_layernorm<float, 4, 64, 4, false>(a, s)
: run_layernorm<float, 4, 64, 4, true>(a, s);
return layernorm2d_fwd_<trait_<ck_tile::fp16_t, 8, 64, 2, true>>(s, a);
}
else if(a.N <= 2048)
{
return a.N == 2048 ? run_layernorm<float, 8, 64, 4, false>(a, s)
: run_layernorm<float, 8, 64, 4, true>(a, s);
return layernorm2d_fwd_<trait_<ck_tile::fp16_t, 16, 64, 2, true>>(s, a);
}
else
{
return a.N % 2048 == 0 ? run_layernorm<float, 8, 64, 4, false, true>(a, s)
: run_layernorm<float, 8, 64, 4, true, true>(a, s);
return layernorm2d_fwd_<trait_<ck_tile::fp16_t, 16, 64, 2, true, true>>(s, a);
}
}
else if(a.N % 2 == 0)
else
{
if(a.N <= 128)
{
return a.N == 128 ? run_layernorm<float, 1, 64, 2, false>(a, s)
: run_layernorm<float, 1, 64, 2, true>(a, s);
}
else if(a.N <= 256)
{
return a.N == 256 ? run_layernorm<float, 2, 64, 2, false>(a, s)
: run_layernorm<float, 2, 64, 2, true>(a, s);
}
else if(a.N <= 512)
{
return a.N == 512 ? run_layernorm<float, 4, 64, 2, false>(a, s)
: run_layernorm<float, 4, 64, 2, true>(a, s);
}
else if(a.N <= 1024)
{
return a.N == 1024 ? run_layernorm<float, 8, 64, 2, false>(a, s)
: run_layernorm<float, 8, 64, 2, true>(a, s);
}
else if(a.N <= 2048)
{
return a.N == 2048 ? run_layernorm<float, 16, 64, 2, false>(a, s)
: run_layernorm<float, 16, 64, 2, true>(a, s);
}
else
{
return a.N % 2048 == 0 ? run_layernorm<float, 16, 64, 2, false, true>(a, s)
: run_layernorm<float, 16, 64, 2, true, true>(a, s);
}
return a.N <= 2048
? layernorm2d_fwd_<trait_<ck_tile::fp16_t, 32, 64, 1, true, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::fp16_t, 32, 64, 1, true, true>>(s, a);
}
}
#endif
if(r < 0)
throw std::runtime_error("No supported instance found!");
return r;
}
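A minimal caller sketch for the public entry point (the field names come from the structs in this diff; the pointer values and the default-constructed stream_config are placeholders):

layernorm2d_fwd_traits t;
t.data_type = "fp16";

layernorm2d_fwd_args a;
a.p_x      = p_x;     // [M, N] input, fp16
a.p_gamma  = p_gamma; // [N]
a.p_beta   = p_beta;  // [N]
a.p_y      = p_y;     // [M, N] output
a.p_mean   = nullptr; // unused here: kSaveMeanInvStd is false in these instances
a.p_invStd = nullptr;
a.epsilon  = 1e-5f;
a.M = 1024;
a.N = 512;            // N % 4 == 0 and N == 512 -> trait_<fp16_t, 2, 64, 4, false>

float ms = layernorm2d_fwd(t, a, ck_tile::stream_config{});

Reading the branches above, the dispatch rule is: divisibility of N picks the vector width (4, then 2, then 1), the magnitude of N picks NRepeat/NThread, an exact match on the tile width in the N % 4 branch picks the non-padded (kPadN = false) instance, and N beyond 2048 falls through to the kTwoPass kernels.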
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <ck_tile/core/numeric/integer.hpp>
#include <ck_tile/host.hpp>
#include <ck_tile/ops/epilogue.hpp>
#include "layernorm2d_fwd.hpp"
template <typename InOutDataType,
ck_tile::index_t NRepeat,
ck_tile::index_t NThread,
ck_tile::index_t VectorAccessSize,
bool kPadN,
bool kSaveMeanInvStd,
bool kTwoPass>
struct layernorm_dispatch
{
static constexpr ck_tile::index_t MRepeat = 1;
static_assert(NThread <= 64, "We only support intra-wave reduction");
static constexpr ck_tile::index_t WaveNum = NThread / 16;
// clang-format off
using thread_tile = ck_tile::sequence<MRepeat, NRepeat, VectorAccessSize>;
using warp_tile = ck_tile::sequence<MRepeat*64/NThread, NRepeat * NThread*VectorAccessSize>;
using block_tile = ck_tile::sequence<MRepeat*WaveNum*64/NThread, NRepeat * NThread*VectorAccessSize>;
// clang-format on
using Shape = ck_tile::TileLayernorm2dShape<thread_tile, warp_tile, block_tile>;
using PipelineProblem = ck_tile::BlockLayernorm2dFwdProblem<
typename LayerNormTypeConfig<InOutDataType>::XDataType,
typename LayerNormTypeConfig<InOutDataType>::GammaDataType,
typename LayerNormTypeConfig<InOutDataType>::BetaDataType,
typename LayerNormTypeConfig<InOutDataType>::ComputeDataType,
typename LayerNormTypeConfig<InOutDataType>::YDataType,
typename LayerNormTypeConfig<InOutDataType>::MeanDataType,
typename LayerNormTypeConfig<InOutDataType>::InvStdDataType,
Shape,
kPadN,
kSaveMeanInvStd,
kTwoPass>;
using Kernel = ck_tile::Layernorm2dFwd<PipelineProblem>;
static float Run(const layernorm2d_fwd_args& param, ck_tile::stream_config stream)
{
using k_ = Kernel;
const dim3 grids = k_::GridSize(param.M);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;
return ck_tile::launch_kernel(stream,
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{},
grids,
blocks,
0,
param.p_x,
param.p_gamma,
param.p_beta,
param.p_y,
param.p_mean,
param.p_invStd,
param.epsilon,
param.M,
param.N));
};
};
template <typename InOutDataType,
ck_tile::index_t NRepeat,
ck_tile::index_t NThread,
ck_tile::index_t VectorAccessSize,
bool kPadN,
bool kTwoPass = false>
float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream)
{
// TODO - Add SaveMeanInvStd instance
constexpr bool kSaveMeanInvStd = false;
return layernorm_dispatch<InOutDataType,
NRepeat,
NThread,
VectorAccessSize,
kPadN,
kSaveMeanInvStd,
kTwoPass>::Run(param, stream);
};