Commit cdfceb0a authored by Astha Rai
Browse files

Merge branch 'codegen_hiprtc' of github.com:ROCm/composable_kernel into codegen_hiprtc

parents b46349df 3b9a77df
...@@ -97,6 +97,10 @@ if(DL_KERNELS) ...@@ -97,6 +97,10 @@ if(DL_KERNELS)
add_definitions(-DDL_KERNELS) add_definitions(-DDL_KERNELS)
set(CK_ENABLE_DL_KERNELS "ON") set(CK_ENABLE_DL_KERNELS "ON")
endif() endif()
# Opt-in switch for the DPP kernel instances: when DPP_KERNELS is set, add the
# -DDPP_KERNELS compile definition and set CK_ENABLE_DPP_KERNELS to "ON".
if(DPP_KERNELS)
add_definitions(-DDPP_KERNELS)
set(CK_ENABLE_DPP_KERNELS "ON")
endif()
option(CK_USE_CODEGEN "Enable codegen library" OFF) option(CK_USE_CODEGEN "Enable codegen library" OFF)
if(CK_USE_CODEGEN) if(CK_USE_CODEGEN)
add_definitions(-DCK_USE_CODEGEN) add_definitions(-DCK_USE_CODEGEN)
......
...@@ -7,7 +7,7 @@ Copyright (c) 2020 , Advanced Micro Devices, Inc. (Xiaoyan Zhou) ...@@ -7,7 +7,7 @@ Copyright (c) 2020 , Advanced Micro Devices, Inc. (Xiaoyan Zhou)
Copyright (c) 2021-2022, Advanced Micro Devices, Inc. (Jianfeng Yan) Copyright (c) 2021-2022, Advanced Micro Devices, Inc. (Jianfeng Yan)
SPDX-License-Identifier: MIT SPDX-License-Identifier: MIT
Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
Permission is hereby granted, free of charge, to any person obtaining a copy Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal of this software and associated documentation files (the "Software"), to deal
......
...@@ -153,6 +153,9 @@ Additional cmake flags can be used to significantly speed-up the build: ...@@ -153,6 +153,9 @@ Additional cmake flags can be used to significantly speed-up the build:
`batched_gemm_multi_d_dl`. These instances are useful on architectures like the NAVI2x, as most `batched_gemm_multi_d_dl`. These instances are useful on architectures like the NAVI2x, as most
other platforms have faster instances, such as `xdl` or `wmma`, available. other platforms have faster instances, such as `xdl` or `wmma`, available.
* `DPP_KERNELS` (default is OFF) must be set to ON in order to build instances, such as `gemm_dpp`.
These instances are useful on architectures like the NAVI2x, as most other platforms have faster instances, such as `xdl` or `wmma`, available.
* `CK_USE_FP8_ON_UNSUPPORTED_ARCH` (default is OFF) must be set to ON in order to build instances, * `CK_USE_FP8_ON_UNSUPPORTED_ARCH` (default is OFF) must be set to ON in order to build instances,
such as `gemm_universal`, `gemm_universal_streamk` and `gemm_multiply_multiply` for fp8 data type for GPU targets which do not have native support for fp8 data type, such as gfx908 or gfx90a. These instances are useful on such as `gemm_universal`, `gemm_universal_streamk` and `gemm_multiply_multiply` for fp8 data type for GPU targets which do not have native support for fp8 data type, such as gfx908 or gfx90a. These instances are useful on
architectures like the MI100/MI200 for the functional support only. architectures like the MI100/MI200 for the functional support only.
......
rocm-docs-core==1.12.1 rocm-docs-core==1.13.0
sphinxcontrib-bibtex==2.6.3 sphinxcontrib-bibtex==2.6.3
...@@ -103,7 +103,7 @@ requests==2.32.3 ...@@ -103,7 +103,7 @@ requests==2.32.3
# via # via
# pygithub # pygithub
# sphinx # sphinx
rocm-docs-core==1.12.1 rocm-docs-core==1.13.0
# via -r requirements.in # via -r requirements.in
six==1.16.0 six==1.16.0
# via pybtex # via pybtex
......
...@@ -54,9 +54,9 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) ...@@ -54,9 +54,9 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
list(REMOVE_ITEM FILE_NAME "${source}") list(REMOVE_ITEM FILE_NAME "${source}")
endif() endif()
endforeach() endforeach()
#Do not build any DPP examples if DL_KERNELS not set #Do not build any DPP examples if DPP_KERNELS not set
foreach(source IN LISTS FILE_NAME) foreach(source IN LISTS FILE_NAME)
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dpp") if(NOT DEFINED DPP_KERNELS AND source MATCHES "_dpp")
message("removing dpp example ${source} ") message("removing dpp example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}") list(REMOVE_ITEM FILE_NAME "${source}")
endif() endif()
......
...@@ -48,8 +48,8 @@ using fmha_dtype_{F_idx} = {F_dtype}; ...@@ -48,8 +48,8 @@ using fmha_dtype_{F_idx} = {F_dtype};
using fmha_mask_{F_idx} = {F_mask}; using fmha_mask_{F_idx} = {F_mask};
namespace {{ namespace {{
template <bool kHasUnevenSplits> template <bool kHasUnevenSplits, bool kMergeNumHeadGroupsSeqLenQ = false>
struct kernel_runner {{ struct instance {{
using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
using fmha_shape = ck_tile::TileFmhaShape<fmha_block_tile, using fmha_shape = ck_tile::TileFmhaShape<fmha_block_tile,
...@@ -64,11 +64,12 @@ using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad}, ...@@ -64,11 +64,12 @@ using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad},
{F_dpad}, {F_dpad},
{F_dvpad}, {F_dvpad},
{F_bias}, {F_bias},
false, /*kHasBiasGrad=*/false,
{F_lse}, {F_lse},
{F_squant}, {F_squant},
{F_pagedkv}, {F_pagedkv},
kHasUnevenSplits, kHasUnevenSplits,
kMergeNumHeadGroupsSeqLenQ,
{F_occupancy}>; {F_occupancy}>;
using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem< using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem<
...@@ -115,28 +116,50 @@ using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F ...@@ -115,28 +116,50 @@ using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F
#include <iostream> #include <iostream>
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wtautological-compare"
namespace {{
template <bool kHasUnevenSplits>
void run_instance(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{
if constexpr ({F_hdim} == 128 && {F_bias} == ck_tile::BlockAttentionBiasEnum::NO_BIAS
&& (std::is_same_v<{F_mask}, ck_tile::SimplifiedGenericAttentionMask<false>>
|| std::is_same_v<{F_mask}, FmhaMasks::NoMask>)) {{
if (a.max_seqlen_q == 1 && a.nhead_k < a.nhead_q) {{
instance<kHasUnevenSplits, /*kMergeNumHeadGroupsSeqLenQ=*/true>::run(s, a);
}} else {{
instance<kHasUnevenSplits>::run(s, a);
}}
}} else {{
instance<kHasUnevenSplits>::run(s, a);
}}
}}
}} // anonymous namespace
#pragma clang diagnostic pop
template<> template<>
void fmha_fwd_splitkv_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) void fmha_fwd_splitkv_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{ {{
if constexpr({F_mode} == false) {{ // batch mode if constexpr({F_mode} == false) {{ // batch mode
// we don't check every seqlen_k values for kvcache // we don't check every seqlen_k values for kvcache
if (a.seqlen_k_ptr != nullptr) {{ if (a.seqlen_k_ptr != nullptr) {{
kernel_runner<true>::run(s, a); run_instance</*kHasUnevenSplits=*/true>(s, a);
// make sure F_bn0 is divisible by F_bk1 // make sure F_bn0 is divisible by F_bk1
}} else if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{ }} else if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{
kernel_runner<false>::run(s, a); run_instance</*kHasUnevenSplits=*/false>(s, a);
}} else {{ }} else {{
kernel_runner<true>::run(s, a); run_instance</*kHasUnevenSplits=*/true>(s, a);
}} }}
}} else {{ }} else {{
kernel_runner<true>::run(s, a); run_instance</*kHasUnevenSplits=*/true>(s, a);
}} }}
}} }}
template<> template<>
std::string fmha_fwd_splitkv_get_name_<trait_{F_idx}>() std::string fmha_fwd_splitkv_get_name_<trait_{F_idx}>()
{{ {{
using k_ = kernel_runner<true>::fmha_kernel; /// FIXME: choose real kernel type using k_ = instance<true>::fmha_kernel; /// FIXME: choose real kernel type
return k_::GetName(); return k_::GetName();
}} }}
""" """
...@@ -146,7 +169,7 @@ using fmha_dtype_{F_idx} = {F_dtype}; ...@@ -146,7 +169,7 @@ using fmha_dtype_{F_idx} = {F_dtype};
namespace {{ namespace {{
template <ck_tile::index_t kLogMaxSplits> template <ck_tile::index_t kLogMaxSplits>
struct kernel_runner {{ struct instance {{
using fmha_trait = ck_tile::TileFmhaFwdSplitKVCombineTraits<{F_spad}, using fmha_trait = ck_tile::TileFmhaFwdSplitKVCombineTraits<{F_spad},
{F_dvpad}, {F_dvpad},
{F_lse}, {F_lse},
...@@ -196,22 +219,22 @@ template<> ...@@ -196,22 +219,22 @@ template<>
void fmha_fwd_splitkv_combine_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) void fmha_fwd_splitkv_combine_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{ {{
if (a.num_splits <= 8) {{ if (a.num_splits <= 8) {{
kernel_runner<3>::run(s, a); instance<3>::run(s, a);
}} else if (a.num_splits <= 16) {{ }} else if (a.num_splits <= 16) {{
kernel_runner<4>::run(s, a); instance<4>::run(s, a);
}} else if (a.num_splits <= 32) {{ }} else if (a.num_splits <= 32) {{
kernel_runner<5>::run(s, a); instance<5>::run(s, a);
}} else if (a.num_splits <= 64) {{ }} else if (a.num_splits <= 64) {{
kernel_runner<6>::run(s, a); instance<6>::run(s, a);
}} else if (a.num_splits <= 128) {{ }} else if (a.num_splits <= 128) {{
kernel_runner<7>::run(s, a); instance<7>::run(s, a);
}} }}
}} }}
template<> template<>
std::string fmha_fwd_splitkv_combine_get_name_<trait_{F_idx}>() std::string fmha_fwd_splitkv_combine_get_name_<trait_{F_idx}>()
{{ {{
using k_ = kernel_runner<6>::fmha_kernel; /// FIXME: choose real kernel type using k_ = instance<6>::fmha_kernel; /// FIXME: choose real kernel type
return k_::GetName(); return k_::GetName();
}} }}
""" """
......
...@@ -510,8 +510,8 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) ...@@ -510,8 +510,8 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
} }
}(); }();
dim3 grids = dim3 grids = Kernel::GridSize(
Kernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.num_splits); args.batch, args.nhead_q, args.nhead_k, args.max_seqlen_q, args.hdim_v, args.num_splits);
return ck_tile::make_tuple(kargs, grids); return ck_tile::make_tuple(kargs, grids);
} }
......
...@@ -23,6 +23,10 @@ def get_if_str(idx, total, lase_else = True): ...@@ -23,6 +23,10 @@ def get_if_str(idx, total, lase_else = True):
else: else:
return 'else if' return 'else if'
# Name suffixes indexed by the x-bias code (0 or 1); the entry for a non-zero
# code is appended to the generated instance name.
XBIAS_ENUM_STR_MAP = [
'no',
'xbias'] # pre-norm add bias
FUSED_ADD_ENUM_STR_MAP = [ FUSED_ADD_ENUM_STR_MAP = [
'no', 'no',
'pras', # pre-norm 'pras', # pre-norm
...@@ -60,6 +64,7 @@ template <typename XDataType_, ...@@ -60,6 +64,7 @@ template <typename XDataType_,
bool kFastFDiv_, bool kFastFDiv_,
bool kWelford_, bool kWelford_,
bool kTwoPass_, bool kTwoPass_,
ck_tile::index_t kXbias_ = 0,
ck_tile::index_t kFusedAdd_ = 0, ck_tile::index_t kFusedAdd_ = 0,
ck_tile::index_t kFusedQuant_ = 0> ck_tile::index_t kFusedQuant_ = 0>
struct layernorm2d_fwd_traits_ struct layernorm2d_fwd_traits_
...@@ -123,6 +128,7 @@ struct layernorm2d_fwd_traits_ ...@@ -123,6 +128,7 @@ struct layernorm2d_fwd_traits_
static constexpr bool kFastFDiv = kFastFDiv_; static constexpr bool kFastFDiv = kFastFDiv_;
static constexpr bool kWelford = kWelford_; static constexpr bool kWelford = kWelford_;
static constexpr bool kTwoPass = kTwoPass_; static constexpr bool kTwoPass = kTwoPass_;
static constexpr ck_tile::index_t kXbias = kXbias_;
static constexpr ck_tile::index_t kFusedAdd = kFusedAdd_; static constexpr ck_tile::index_t kFusedAdd = kFusedAdd_;
static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_; static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_;
}; };
...@@ -141,6 +147,7 @@ template <typename XDataType_, ...@@ -141,6 +147,7 @@ template <typename XDataType_,
bool kFastFDiv_, bool kFastFDiv_,
bool kWelford_, bool kWelford_,
bool kTwoPass_, bool kTwoPass_,
int kXbias_,
int kFusedAdd_, int kFusedAdd_,
int kFusedQuant_> int kFusedQuant_>
using traits_ = layernorm2d_fwd_traits_<XDataType_, using traits_ = layernorm2d_fwd_traits_<XDataType_,
...@@ -157,6 +164,7 @@ using traits_ = layernorm2d_fwd_traits_<XDataType_, ...@@ -157,6 +164,7 @@ using traits_ = layernorm2d_fwd_traits_<XDataType_,
kFastFDiv_, kFastFDiv_,
kWelford_, kWelford_,
kTwoPass_, kTwoPass_,
kXbias_,
kFusedAdd_, kFusedAdd_,
kFusedQuant_>; kFusedQuant_>;
""" """
...@@ -190,10 +198,12 @@ float layernorm2d_fwd_(const S& s, A a) ...@@ -190,10 +198,12 @@ float layernorm2d_fwd_(const S& s, A a)
Traits_::kFastFDiv, Traits_::kFastFDiv,
Traits_::kWelford, Traits_::kWelford,
Traits_::kTwoPass, Traits_::kTwoPass,
static_cast<ck_tile::Layernorm2dXBiasEnum>(Traits_::kXbias),
static_cast<ck_tile::Layernorm2dFusedAddEnum>(Traits_::kFusedAdd), static_cast<ck_tile::Layernorm2dFusedAddEnum>(Traits_::kFusedAdd),
static_cast<ck_tile::Layernorm2dFusedQuantEnum>(Traits_::kFusedQuant)>; static_cast<ck_tile::Layernorm2dFusedQuantEnum>(Traits_::kFusedQuant)>;
using PipelineProblem = ck_tile::Layernorm2dFwdPipelineProblem< using PipelineProblem = ck_tile::Layernorm2dFwdPipelineProblem<
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::XDataType, typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::XDataType,
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::XBiasDataType,
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::GammaDataType, typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::GammaDataType,
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::BetaDataType, typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::BetaDataType,
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::ComputeDataType, typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::ComputeDataType,
...@@ -280,7 +290,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, ...@@ -280,7 +290,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
#include "layernorm2d_fwd_api_common.hpp" #include "layernorm2d_fwd_api_common.hpp"
// clang-format off // clang-format off
// prec_i prec_o prec_sy rm rn tm tn vn pd mv rpcf welford 2p add sweep // prec_i prec_o prec_sy rm rn tm tn vn pd mv rpcf welford 2p xbias add sweep
{F_instance_def} {F_instance_def}
// clang-format on // clang-format on
...@@ -290,6 +300,10 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, ...@@ -290,6 +300,10 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
self.working_path = working_path self.working_path = working_path
self.kernel_filter = kernel_filter self.kernel_filter = kernel_filter
# Integer codes for the x-bias option: F_NO_XBIAS (0) and F_ADD_XBIAS (1).
class k_xbias_enum(IntEnum):
F_NO_XBIAS = 0
F_ADD_XBIAS = 1
class k_fuesd_add_enum(IntEnum): class k_fuesd_add_enum(IntEnum):
F_NO_ADD = 0 F_NO_ADD = 0
F_PRE_ADD = 1 F_PRE_ADD = 1
...@@ -305,6 +319,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, ...@@ -305,6 +319,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
F_kPadN : bool F_kPadN : bool
F_kSaveMeanInvStd : bool F_kSaveMeanInvStd : bool
F_kTwoPass : bool F_kTwoPass : bool
F_kXbias : Any #: layernorm_fwd_codegen.k_xbias_enum
F_kFusedAdd : Any #: layernorm_fwd_codegen.k_fuesd_add_enum F_kFusedAdd : Any #: layernorm_fwd_codegen.k_fuesd_add_enum
F_kFusedQuant : Any #: layernorm_fwd_codegen.k_fused_sweep_enum F_kFusedQuant : Any #: layernorm_fwd_codegen.k_fused_sweep_enum
...@@ -321,6 +336,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, ...@@ -321,6 +336,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
@dataclass @dataclass
class k_problem: class k_problem:
F_XDataType : str F_XDataType : str
F_XBiasDataType : str
F_GammaDataType : str F_GammaDataType : str
F_BetaDataType : str F_BetaDataType : str
F_ComputeDataType : str F_ComputeDataType : str
...@@ -370,6 +386,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, ...@@ -370,6 +386,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
F_kFastFDiv_ : bool F_kFastFDiv_ : bool
F_kWelford_ : bool F_kWelford_ : bool
F_kTwoPass_ : bool F_kTwoPass_ : bool
F_kXbias_ : int
F_kFusedAdd : int F_kFusedAdd : int
F_kFusedQuant : int F_kFusedQuant : int
...@@ -377,7 +394,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, ...@@ -377,7 +394,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
def trait_name(self) ->str: def trait_name(self) ->str:
t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_XScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}' t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_XScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}'
t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveMeanInvStd_):5}, {BOOL_MAP(self.F_kFastFDiv_):5}, {BOOL_MAP(self.F_kWelford_):5}' t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveMeanInvStd_):5}, {BOOL_MAP(self.F_kFastFDiv_):5}, {BOOL_MAP(self.F_kWelford_):5}'
t_ += f', {BOOL_MAP(self.F_kTwoPass_):5}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}' t_ += f', {BOOL_MAP(self.F_kTwoPass_):5}, {self.F_kXbias:4}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}'
return t_ return t_
# string when calling this kernel # string when calling this kernel
...@@ -395,6 +412,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, ...@@ -395,6 +412,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
class h_instance: class h_instance:
F_DataTypePair : str F_DataTypePair : str
F_N : str F_N : str
F_xbias : int
F_add : int F_add : int
F_sweep : int F_sweep : int
instance_list : List[Any] # List[h_traits] instance_list : List[Any] # List[h_traits]
...@@ -404,6 +422,8 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, ...@@ -404,6 +422,8 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
prec_i, prec_o = self.F_DataTypePair.split(',') prec_i, prec_o = self.F_DataTypePair.split(',')
dtype_str = f'{prec_i}' if prec_i == prec_o else f'{prec_i}_{prec_o}' dtype_str = f'{prec_i}' if prec_i == prec_o else f'{prec_i}_{prec_o}'
nnn = f'layernorm2d_fwd_{dtype_str}_n{self.F_N}' nnn = f'layernorm2d_fwd_{dtype_str}_n{self.F_N}'
if self.F_xbias != 0:
nnn = nnn + '_' + XBIAS_ENUM_STR_MAP[self.F_xbias]
if self.F_add != 0: if self.F_add != 0:
nnn = nnn + '_' + FUSED_ADD_ENUM_STR_MAP[self.F_add] nnn = nnn + '_' + FUSED_ADD_ENUM_STR_MAP[self.F_add]
if self.F_sweep != 0: if self.F_sweep != 0:
...@@ -462,8 +482,8 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, ...@@ -462,8 +482,8 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
elif ins.F_kFusedQuant == 2: elif ins.F_kFusedQuant == 2:
_sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sy == \"{f_sy_type}\")'.format( _sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sy == \"{f_sy_type}\")'.format(
f_fused_sweep = ins.F_kFusedQuant, f_sy_type=ins.F_YScaleDataType) f_fused_sweep = ins.F_kFusedQuant, f_sy_type=ins.F_YScaleDataType)
_cond = '((a.n % {f_vec_n} == 0) && (t.fused_add == {f_fused_add}) && ({f_sweep_cond}))'.format( _cond = '((a.n % {f_vec_n} == 0) && (t.xbias == {f_xbias}) && (t.fused_add == {f_fused_add}) && ({f_sweep_cond}))'.format(
f_vec_n = ins.F_Vector_N, f_fused_add = ins.F_kFusedAdd, f_vec_n = ins.F_Vector_N, f_xbias = ins.F_kXbias, f_fused_add = ins.F_kFusedAdd,
f_sweep_cond = _sweep_cond) f_sweep_cond = _sweep_cond)
inner_str += self.API_INNER_CASE.format(F_if = get_if_str(idx_in_n, len_in_n, False), inner_str += self.API_INNER_CASE.format(F_if = get_if_str(idx_in_n, len_in_n, False),
F_VEC_COND = _cond, F_instance_func=ins.call_name) F_VEC_COND = _cond, F_instance_func=ins.call_name)
...@@ -494,62 +514,63 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, ...@@ -494,62 +514,63 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
types_16bit = ('int16', 'fp16', 'bf16') types_16bit = ('int16', 'fp16', 'bf16')
#fused_add_list = [0, 1, 2] #fused_add_list = [0, 1, 2]
#fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused dynamic quant #fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused dynamic quant
xbias_list = [0, 1]
fused_add_list = [0, 1] fused_add_list = [0, 1]
fused_sweep_list = [0, 1] # NOTE: only single pass can use fused dynamic quant fused_sweep_list = [0, 1] # NOTE: only single pass can use fused dynamic quant
# rm rn tm tn vn pd mv fdiv welford 2p add sweep # rm rn tm tn vn pd mv fdiv welford 2p xbias add sweep
h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 8, 8, 8, True, False, True, True, False, 0, 0), h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 8, 8, 8, True, False, True, True, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 4, True, False, True, True, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 4, True, False, True, True, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 1, True, False, True, True, False, 0, 0)], h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 1, True, False, True, True, False, 0, 0, 0)],
'128' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 8, True, False, True, True, False, 0, 0), '128' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 8, True, False, True, True, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 2, True, False, True, True, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 2, True, False, True, True, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 1, True, False, True, True, False, 0, 0)], h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 1, True, False, True, True, False, 0, 0, 0)],
'256' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 4, True, False, True, True, False, 0, 0), '256' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 4, True, False, True, True, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 2, True, False, True, True, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 2, True, False, True, True, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 1, True, False, True, True, False, 0, 0)], h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 1, True, False, True, True, False, 0, 0, 0)],
'512' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 8, True, False, True, True, False, 0, 0), '512' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 8, True, False, True, True, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 4, True, False, True, True, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 4, True, False, True, True, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 2, True, False, True, True, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 2, True, False, True, True, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 8, 4, 64, 1, True, False, True, True, False, 0, 0)], h_traits('x', 'y', 'xs', 'ys', 1, 8, 4, 64, 1, True, False, True, True, False, 0, 0, 0)],
'768' : [ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 4, True, False, True, True, False, 0, 0), '768' : [ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 4, True, False, True, True, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 6, 4, 64, 2, True, False, True, True, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 6, 4, 64, 2, True, False, True, True, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 12, 4, 64, 1, True, False, True, True, False, 0, 0)], h_traits('x', 'y', 'xs', 'ys', 1, 12, 4, 64, 1, True, False, True, True, False, 0, 0, 0)],
'1024' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 2, 128, 8, True, False, True, True, False, 0, 0), '1024' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 2, 128, 8, True, False, True, True, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 2, 128, 4, True, False, True, True, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 2, 2, 128, 4, True, False, True, True, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 2, 128, 2, True, False, True, True, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 4, 2, 128, 2, True, False, True, True, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 1, True, False, True, True, False, 0, 0)], h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 1, True, False, True, True, False, 0, 0, 0)],
'1536' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 8, True, False, True, True, False, 0, 0), '1536' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 8, True, False, True, True, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 2, 128, 4, True, False, True, True, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 3, 2, 128, 4, True, False, True, True, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 2, True, False, True, True, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 2, True, False, True, True, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 1, True, False, True, True, False, 0, 0)], h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 1, True, False, True, True, False, 0, 0, 0)],
'2048' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 256, 8, True, False, True, True, False, 0, 0), '2048' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 256, 8, True, False, True, True, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 4, True, False, True, True, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 4, True, False, True, True, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 2, True, False, True, True, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 2, True, False, True, True, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 8, 1, 256, 1, True, False, True, True, False, 0, 0)], h_traits('x', 'y', 'xs', 'ys', 1, 8, 1, 256, 1, True, False, True, True, False, 0, 0, 0)],
'3072' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 128, 8, True, False, True, True, False, 0, 0), '3072' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 128, 8, True, False, True, True, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 4, True, False, True, True, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 4, True, False, True, True, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 2, True, False, True, True, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 2, True, False, True, True, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 1, True, False, True, True, False, 0, 0)], h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 1, True, False, True, True, False, 0, 0, 0)],
'4096' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, True, False, 0, 0), '4096' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, True, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, True, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, True, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, False, 0, 0)], h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, False, 0, 0, 0)],
'6144' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 8, True, False, True, True, False, 0, 0), '6144' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 8, True, False, True, True, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 512, 4, True, False, True, True, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 512, 4, True, False, True, True, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 2, True, False, True, True, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 2, True, False, True, True, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1,1024, 1, True, False, True, True, False, 0, 0)], h_traits('x', 'y', 'xs', 'ys', 1, 6, 1,1024, 1, True, False, True, True, False, 0, 0, 0)],
'8192' :[ h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 8, True, False, True, True, False, 0, 0), '8192' :[ h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 8, True, False, True, True, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, True, True, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, True, True, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, True, True, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, True, True, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, True, True, False, 0, 0)], h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, True, True, False, 0, 0, 0)],
'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, True, True, 0, 0), 'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, True, True, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, True, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, True, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, True, True, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, True, True, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, True, 0, 0)]} h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, True, 0, 0, 0)]}
total_blob = list() total_blob = list()
for hs_key in h_trait_dict: for hs_key in h_trait_dict:
hs = h_trait_dict[hs_key] hs = h_trait_dict[hs_key]
current_n = hs[0].F_Repeat_N * hs[0].F_ThreadPerBlock_N * hs[0].F_Vector_N current_n = hs[0].F_Repeat_N * hs[0].F_ThreadPerBlock_N * hs[0].F_Vector_N
for dtype, scale_type, fused_add, fused_quant in itertools.product(dtype_list, scale_list, fused_add_list, fused_sweep_list): for dtype, scale_type, xbias, fused_add, fused_quant in itertools.product(dtype_list, scale_list, xbias_list, fused_add_list, fused_sweep_list):
prec_i, prec_o = dtype.split(',') prec_i, prec_o = dtype.split(',')
scale_x, scale_y = scale_type.split(',') scale_x, scale_y = scale_type.split(',')
if prec_o in dynamic_quant_out_dtype and fused_quant != 1: if prec_o in dynamic_quant_out_dtype and fused_quant != 1:
...@@ -563,6 +584,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, ...@@ -563,6 +584,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
h_.F_YDataType = prec_o h_.F_YDataType = prec_o
h_.F_XScaleDataType = scale_y h_.F_XScaleDataType = scale_y
h_.F_YScaleDataType = scale_x h_.F_YScaleDataType = scale_x
h_.F_kXbias = xbias
h_.F_kFusedAdd = fused_add h_.F_kFusedAdd = fused_add
h_.F_kFusedQuant = fused_quant h_.F_kFusedQuant = fused_quant
# disable welford update for 8bit and 16 bit smallN # disable welford update for 8bit and 16 bit smallN
...@@ -579,7 +601,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, ...@@ -579,7 +601,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
current_hs.append(h_) # + "\n" current_hs.append(h_) # + "\n"
#f.write(str(f.parent / GEN_DIR / (blobs.api_common_header_ #f.write(str(f.parent / GEN_DIR / (blobs.api_common_header_
current_n_str = 'big' if hs_key == 'big' else current_n current_n_str = 'big' if hs_key == 'big' else current_n
total_blob.append(h_instance(dtype, current_n_str, fused_add, fused_quant, current_hs)) total_blob.append(h_instance(dtype, current_n_str, xbias, fused_add, fused_quant, current_hs))
return total_blob return total_blob
def list_blobs(self, args) -> None: def list_blobs(self, args) -> None:
......
...@@ -41,6 +41,7 @@ auto create_args(int argc, char* argv[]) ...@@ -41,6 +41,7 @@ auto create_args(int argc, char* argv[])
.insert("prec_sy", .insert("prec_sy",
"auto", "auto",
"output quant scale type, set auto will use fp32. used when fquant=1 or 2") "output quant scale type, set auto will use fp32. used when fquant=1 or 2")
.insert("xbias", "0", "add bias, 0:no add, 1:add bias before fadd")
.insert("fadd", "0", "fused-add, 0:no fused add, 1:preadd+store, 2:preadd only") .insert("fadd", "0", "fused-add, 0:no fused add, 1:preadd+store, 2:preadd only")
.insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant") .insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant")
.insert("warmup", "5", "cold iter") .insert("warmup", "5", "cold iter")
...@@ -93,6 +94,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -93,6 +94,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
int do_validation = arg_parser.get_int("v"); int do_validation = arg_parser.get_int("v");
int warmup = arg_parser.get_int("warmup"); int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat"); int repeat = arg_parser.get_int("repeat");
int xbias = arg_parser.get_int("xbias");
int fused_add = arg_parser.get_int("fadd"); int fused_add = arg_parser.get_int("fadd");
int fused_quant = arg_parser.get_int("fquant"); int fused_quant = arg_parser.get_int("fquant");
if(fused_quant == 1 && prec_o != "int8") if(fused_quant == 1 && prec_o != "int8")
...@@ -107,6 +109,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -107,6 +109,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
using XDataType = typename TypeConfig::XDataType; using XDataType = typename TypeConfig::XDataType;
using YDataType = typename TypeConfig::YDataType; using YDataType = typename TypeConfig::YDataType;
using XBiasDataType = typename TypeConfig::XBiasDataType;
using GammaDataType = typename TypeConfig::GammaDataType; using GammaDataType = typename TypeConfig::GammaDataType;
using BetaDataType = typename TypeConfig::BetaDataType; using BetaDataType = typename TypeConfig::BetaDataType;
using XResidualDataType = XDataType; using XResidualDataType = XDataType;
...@@ -121,6 +124,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -121,6 +124,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
// host verify // host verify
ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1}); ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1});
ck_tile::HostTensor<XBiasDataType> x_bias_host({n});
ck_tile::HostTensor<GammaDataType> gamma_host({n}); ck_tile::HostTensor<GammaDataType> gamma_host({n});
ck_tile::HostTensor<BetaDataType> beta_host({n}); ck_tile::HostTensor<BetaDataType> beta_host({n});
...@@ -141,10 +145,12 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -141,10 +145,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host); ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
ck_tile::FillUniformDistribution<XResidualDataType>{-.5f, .5f}(x_residual_host); ck_tile::FillUniformDistribution<XResidualDataType>{-.5f, .5f}(x_residual_host);
ck_tile::FillUniformDistribution<XScaleDataType>{-1.f, 1.f}(x_scale_host); ck_tile::FillUniformDistribution<XScaleDataType>{-1.f, 1.f}(x_scale_host);
ck_tile::FillUniformDistribution<XBiasDataType>{-.5f, .5f}(x_bias_host);
ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host); ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host);
ck_tile::FillUniformDistribution<BetaDataType>{-.5f, .5f}(beta_host); ck_tile::FillUniformDistribution<BetaDataType>{-.5f, .5f}(beta_host);
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem x_bias_buf(x_bias_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes()); ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes());
...@@ -155,6 +161,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -155,6 +161,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::DeviceMem y_residual_buf(y_residual_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem y_residual_buf(y_residual_host.get_element_space_size_in_bytes());
x_buf.ToDevice(x_host.data()); x_buf.ToDevice(x_host.data());
x_bias_buf.ToDevice(x_bias_host.data());
gamma_buf.ToDevice(gamma_host.data()); gamma_buf.ToDevice(gamma_host.data());
beta_buf.ToDevice(beta_host.data()); beta_buf.ToDevice(beta_host.data());
x_residual_buf.ToDevice(x_residual_host.data()); x_residual_buf.ToDevice(x_residual_host.data());
...@@ -179,11 +186,12 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -179,11 +186,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
<< ", yr_stride:" << yr_stride << std::flush; << ", yr_stride:" << yr_stride << std::flush;
layernorm2d_fwd_traits traits{ layernorm2d_fwd_traits traits{
prec_i, prec_o, prec_sx, prec_sy, SaveMeanVar, fused_add, fused_quant}; prec_i, prec_o, prec_sx, prec_sy, SaveMeanVar, xbias, fused_add, fused_quant};
layernorm2d_fwd_args args{x_buf.GetDeviceBuffer(), layernorm2d_fwd_args args{x_buf.GetDeviceBuffer(),
fused_add != 0 ? x_residual_buf.GetDeviceBuffer() : nullptr, fused_add != 0 ? x_residual_buf.GetDeviceBuffer() : nullptr,
fused_quant == 1 ? x_scale_buf.GetDeviceBuffer() : nullptr, fused_quant == 1 ? x_scale_buf.GetDeviceBuffer() : nullptr,
x_bias_buf.GetDeviceBuffer(),
gamma_buf.GetDeviceBuffer(), gamma_buf.GetDeviceBuffer(),
beta_buf.GetDeviceBuffer(), beta_buf.GetDeviceBuffer(),
...@@ -210,8 +218,9 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -210,8 +218,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
return false; return false;
} }
std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(GammaDataType) * n + std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(XBiasDataType) * n +
sizeof(BetaDataType) * n + sizeof(YDataType) * m * n; sizeof(GammaDataType) * n + sizeof(BetaDataType) * n +
sizeof(YDataType) * m * n;
float gb_per_sec = num_byte / 1.E6 / ave_time; float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush; std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush;
...@@ -221,6 +230,22 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -221,6 +230,22 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(do_validation) if(do_validation)
{ {
// reference // reference
// Host reference for the optional x-bias step, taken only when xbias is non-zero.
// Each element x_host(idx_m, idx_n) has x_bias_host(idx_n) added to it and the result
// is written back into x_host in place, ahead of the fused-add reference branch below.
if(xbias != 0)
{
// add bias before fadd
// M and N are the first and second lengths of x_host.
int M = x_host.mDesc.get_lengths()[0];
int N = x_host.mDesc.get_lengths()[1];
for(int idx_m = 0; idx_m < M; ++idx_m)
{
for(int idx_n = 0; idx_n < N; ++idx_n)
{
// Both operands are converted to ComputeDataType for the addition; the sum is
// converted back to XDataType before it is stored.
x_host(idx_m, idx_n) = ck_tile::type_convert<XDataType>(
ck_tile::type_convert<ComputeDataType>(x_host(idx_m, idx_n)) +
ck_tile::type_convert<ComputeDataType>(x_bias_host(idx_n)));
}
}
}
if(fused_add != 0) if(fused_add != 0)
{ {
// fused pre_add/pre_add_store // fused pre_add/pre_add_store
......
...@@ -16,6 +16,7 @@ struct LayerNormTypeConfig<ck_tile::half_t, OutType, XScaleDataType_, YScaleData ...@@ -16,6 +16,7 @@ struct LayerNormTypeConfig<ck_tile::half_t, OutType, XScaleDataType_, YScaleData
{ {
using XDataType = ck_tile::half_t; using XDataType = ck_tile::half_t;
using YDataType = OutType; using YDataType = OutType;
using XBiasDataType = ck_tile::half_t;
using GammaDataType = ck_tile::half_t; using GammaDataType = ck_tile::half_t;
using BetaDataType = ck_tile::half_t; using BetaDataType = ck_tile::half_t;
using MeanDataType = ck_tile::half_t; using MeanDataType = ck_tile::half_t;
...@@ -30,6 +31,7 @@ struct LayerNormTypeConfig<ck_tile::bf16_t, OutType, XScaleDataType_, YScaleData ...@@ -30,6 +31,7 @@ struct LayerNormTypeConfig<ck_tile::bf16_t, OutType, XScaleDataType_, YScaleData
{ {
using XDataType = ck_tile::bf16_t; using XDataType = ck_tile::bf16_t;
using YDataType = OutType; using YDataType = OutType;
using XBiasDataType = ck_tile::bf16_t;
using GammaDataType = ck_tile::bf16_t; using GammaDataType = ck_tile::bf16_t;
using BetaDataType = ck_tile::bf16_t; using BetaDataType = ck_tile::bf16_t;
using MeanDataType = ck_tile::bf16_t; using MeanDataType = ck_tile::bf16_t;
...@@ -57,6 +59,7 @@ struct layernorm2d_fwd_traits ...@@ -57,6 +59,7 @@ struct layernorm2d_fwd_traits
std::string prec_sy; // y-scale, used for [M*1] output for next layer std::string prec_sy; // y-scale, used for [M*1] output for next layer
bool save_mean_var; // bool save_mean_var; //
int xbias; // 0:no-bias, 1:add bias
int fused_add; // 0:no-add, 1:pre-add-store, 2:pre-add int fused_add; // 0:no-add, 1:pre-add-store, 2:pre-add
int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
}; };
......
...@@ -97,6 +97,10 @@ ...@@ -97,6 +97,10 @@
#cmakedefine CK_ENABLE_DL_KERNELS @CK_ENABLE_DL_KERNELS@ #cmakedefine CK_ENABLE_DL_KERNELS @CK_ENABLE_DL_KERNELS@
#endif #endif
#ifndef CK_ENABLE_DPP_KERNELS
#cmakedefine CK_ENABLE_DPP_KERNELS @CK_ENABLE_DPP_KERNELS@
#endif
// //
// CK kernels which support XDL (MI series) // CK kernels which support XDL (MI series)
// //
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment