Commit 9e063018 authored by carlushuang's avatar carlushuang
Browse files

dynamic-quant ready

parent e2935465
...@@ -30,8 +30,7 @@ FUSED_ADD_ENUM_STR_MAP = [ ...@@ -30,8 +30,7 @@ FUSED_ADD_ENUM_STR_MAP = [
FUSED_FUSED_SWEEP_STR_MAP = [ FUSED_FUSED_SWEEP_STR_MAP = [
'no', 'no',
'renorm', 'dquant' ]
'dequant' ]
DATA_TYPE_MAP = {'fp16' : 'ck_tile::fp16_t', DATA_TYPE_MAP = {'fp16' : 'ck_tile::fp16_t',
'bf16' : 'ck_tile::bf16_t', 'bf16' : 'ck_tile::bf16_t',
...@@ -48,6 +47,7 @@ class layernorm_fwd_codegen: ...@@ -48,6 +47,7 @@ class layernorm_fwd_codegen:
// this is used to pattern-match internl kernel implementation, not to instantiate kernel // this is used to pattern-match internl kernel implementation, not to instantiate kernel
template <typename XDataType_, template <typename XDataType_,
typename YDataType_, typename YDataType_,
typename YScaleDataType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t Repeat_N_, // each thread repeat along N ck_tile::index_t Repeat_N_, // each thread repeat along N
ck_tile::index_t ThreadPerBlock_M_, // num threads along M ck_tile::index_t ThreadPerBlock_M_, // num threads along M
...@@ -62,6 +62,7 @@ struct layernorm2d_fwd_traits_ ...@@ -62,6 +62,7 @@ struct layernorm2d_fwd_traits_
{ {
using XDataType = ck_tile::remove_cvref_t<XDataType_>; using XDataType = ck_tile::remove_cvref_t<XDataType_>;
using YDataType = ck_tile::remove_cvref_t<YDataType_>; using YDataType = ck_tile::remove_cvref_t<YDataType_>;
using YScaleDataType = ck_tile::remove_cvref_t<YScaleDataType_>;
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize; static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0); static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0);
...@@ -121,6 +122,7 @@ struct layernorm2d_fwd_traits_ ...@@ -121,6 +122,7 @@ struct layernorm2d_fwd_traits_
template <typename XDataType_, template <typename XDataType_,
typename YDataType_, typename YDataType_,
typename YScaleDataType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t Repeat_N_, // each thread repeat along N ck_tile::index_t Repeat_N_, // each thread repeat along N
ck_tile::index_t ThreadPerBlock_M_, // num threads along M ck_tile::index_t ThreadPerBlock_M_, // num threads along M
...@@ -133,6 +135,7 @@ template <typename XDataType_, ...@@ -133,6 +135,7 @@ template <typename XDataType_,
int kFusedSweep_> int kFusedSweep_>
using traits_ = layernorm2d_fwd_traits_<XDataType_, using traits_ = layernorm2d_fwd_traits_<XDataType_,
YDataType_, YDataType_,
YScaleDataType_,
Repeat_M_, Repeat_M_,
Repeat_N_, Repeat_N_,
ThreadPerBlock_M_, ThreadPerBlock_M_,
...@@ -165,7 +168,8 @@ float layernorm2d_fwd_(const S& s, A a) ...@@ -165,7 +168,8 @@ float layernorm2d_fwd_(const S& s, A a)
{{ {{
using XDataType = typename Traits_::XDataType; using XDataType = typename Traits_::XDataType;
using YDataType = typename Traits_::YDataType; using YDataType = typename Traits_::YDataType;
using ComputeDataType = typename LayerNormTypeConfig<XDataType, YDataType>::ComputeDataType; using YScaleDataType = typename Traits_::YScaleDataType;
using ComputeDataType = typename LayerNormTypeConfig<XDataType, YDataType, YScaleDataType>::ComputeDataType;
using PipelineTraits = ck_tile::Layernorm2dFwdTraits<Traits_::kPadN, using PipelineTraits = ck_tile::Layernorm2dFwdTraits<Traits_::kPadN,
Traits_::kSaveMeanInvStd, Traits_::kSaveMeanInvStd,
...@@ -173,13 +177,14 @@ float layernorm2d_fwd_(const S& s, A a) ...@@ -173,13 +177,14 @@ float layernorm2d_fwd_(const S& s, A a)
static_cast<ck_tile::Layernorm2dFusedAddEnum>(Traits_::kFusedAdd), static_cast<ck_tile::Layernorm2dFusedAddEnum>(Traits_::kFusedAdd),
static_cast<ck_tile::Layernorm2dFusedSweepEnum>(Traits_::kFusedSweep)>; static_cast<ck_tile::Layernorm2dFusedSweepEnum>(Traits_::kFusedSweep)>;
using PipelineProblem = ck_tile::Layernorm2dFwdPipelineProblem< using PipelineProblem = ck_tile::Layernorm2dFwdPipelineProblem<
typename LayerNormTypeConfig<XDataType, YDataType>::XDataType, typename LayerNormTypeConfig<XDataType, YDataType, YScaleDataType>::XDataType,
typename LayerNormTypeConfig<XDataType, YDataType>::GammaDataType, typename LayerNormTypeConfig<XDataType, YDataType, YScaleDataType>::GammaDataType,
typename LayerNormTypeConfig<XDataType, YDataType>::BetaDataType, typename LayerNormTypeConfig<XDataType, YDataType, YScaleDataType>::BetaDataType,
typename LayerNormTypeConfig<XDataType, YDataType>::ComputeDataType, typename LayerNormTypeConfig<XDataType, YDataType, YScaleDataType>::ComputeDataType,
typename LayerNormTypeConfig<XDataType, YDataType>::YDataType, typename LayerNormTypeConfig<XDataType, YDataType, YScaleDataType>::YDataType,
typename LayerNormTypeConfig<XDataType, YDataType>::MeanDataType, typename LayerNormTypeConfig<XDataType, YDataType, YScaleDataType>::MeanDataType,
typename LayerNormTypeConfig<XDataType, YDataType>::InvStdDataType, typename LayerNormTypeConfig<XDataType, YDataType, YScaleDataType>::InvStdDataType,
typename LayerNormTypeConfig<XDataType, YDataType, YScaleDataType>::YScaleDataType,
typename Traits_::Shape, typename Traits_::Shape,
PipelineTraits>; PipelineTraits>;
...@@ -190,7 +195,12 @@ float layernorm2d_fwd_(const S& s, A a) ...@@ -190,7 +195,12 @@ float layernorm2d_fwd_(const S& s, A a)
using Default2DEpilogueProblem = ck_tile::Default2DEpilogueProblem<ComputeDataType, YDataType, false, Traits_::kPadN, false>; using Default2DEpilogueProblem = ck_tile::Default2DEpilogueProblem<ComputeDataType, YDataType, false, Traits_::kPadN, false>;
using Default2DEpilogue = ck_tile::Default2DEpilogue<Default2DEpilogueProblem>; using Default2DEpilogue = ck_tile::Default2DEpilogue<Default2DEpilogueProblem>;
using Epilogue = Default2DEpilogue; using DynamicQuantEpilogueProblem = ck_tile::DynamicQuantEpilogueProblem<ComputeDataType, YScaleDataType, YDataType,
ck_tile::DynamicQuantEpilogueTraits<false, Traits_::kPadN, false, true/*max3*/>>;
using DynamicQuantEpilogue = ck_tile::DynamicQuantEpilogue<DynamicQuantEpilogueProblem>;
using Epilogue = std::conditional_t<Traits_::kFusedSweep == 1, DynamicQuantEpilogue, Default2DEpilogue>;
using Kernel = ck_tile::Layernorm2dFwd<Pipeline, Epilogue>; using Kernel = ck_tile::Layernorm2dFwd<Pipeline, Epilogue>;
...@@ -247,7 +257,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, ...@@ -247,7 +257,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
#include "layernorm2d_fwd_api_common.hpp" #include "layernorm2d_fwd_api_common.hpp"
// clang-format off // clang-format off
// prec_i prec_o rm rn tm tn vn pd mv 2p add sweep // prec_i prec_o prec_s rm rn tm tn vn pd mv 2p add sweep
{F_instance_def} {F_instance_def}
// clang-format on // clang-format on
...@@ -325,6 +335,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, ...@@ -325,6 +335,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
class h_traits: class h_traits:
F_XDataType : str F_XDataType : str
F_YDataType : str F_YDataType : str
F_YScaleDataType : str
F_Repeat_M : int F_Repeat_M : int
F_Repeat_N : int F_Repeat_N : int
F_ThreadPerBlock_M : int F_ThreadPerBlock_M : int
...@@ -338,7 +349,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, ...@@ -338,7 +349,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
@property @property
def trait_name(self) ->str: def trait_name(self) ->str:
t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}' t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}'
t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveMeanInvStd_):5}' t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveMeanInvStd_):5}'
t_ += f', {BOOL_MAP(self.F_kTwoPass_):5}, {self.F_kFusedAdd:4}, {self.F_kFusedSweep:4}' t_ += f', {BOOL_MAP(self.F_kTwoPass_):5}, {self.F_kFusedAdd:4}, {self.F_kFusedSweep:4}'
return t_ return t_
...@@ -424,7 +435,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, ...@@ -424,7 +435,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
inner_str += self.API_INNER_CASE.format(F_if = get_if_str(idx_in_n, len_in_n, False), inner_str += self.API_INNER_CASE.format(F_if = get_if_str(idx_in_n, len_in_n, False),
F_VEC_COND = _cond, F_instance_func=ins.call_name) F_VEC_COND = _cond, F_instance_func=ins.call_name)
#inner_str = inner_str + vec_str #inner_str = inner_str + vec_str
n_cnd = f'(a.n <= {n_})' if n_ != 'big' else '' n_cnd = f'(a.n <= {n_})' if (i_n < len(blob_per_t) - 1) else ''
n_str += self.API_PER_N_CASE.format(F_if = get_if_str(i_n, len(blob_per_t)), F_N_COND=n_cnd, F_inner_dispatch=inner_str) n_str += self.API_PER_N_CASE.format(F_if = get_if_str(i_n, len(blob_per_t)), F_N_COND=n_cnd, F_inner_dispatch=inner_str)
prec_i, prec_o = dtype_.split(',') prec_i, prec_o = dtype_.split(',')
d_str += self.API_PER_DTYPE.format(F_if = get_if_str(i_d, len(t_dtype_dict), False), F_i_type=prec_i, F_o_type=prec_o, F_per_n_case=n_str) d_str += self.API_PER_DTYPE.format(F_if = get_if_str(i_d, len(t_dtype_dict), False), F_i_type=prec_i, F_o_type=prec_o, F_per_n_case=n_str)
...@@ -440,61 +451,76 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, ...@@ -440,61 +451,76 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
h_traits = layernorm_fwd_codegen.h_traits h_traits = layernorm_fwd_codegen.h_traits
h_instance = layernorm_fwd_codegen.h_instance h_instance = layernorm_fwd_codegen.h_instance
dynamic_quant_out_dtype = ['int8']
# some predefined support range # some predefined support range
# (prec_i,prec_o) for simplicity this string will be used as key for dict # (prec_i,prec_o) for simplicity this string will be used as key for dict
dtype_list = [('fp16,fp16'), ('bf16,bf16')] dtype_list = [('fp16,fp16'), ('bf16,bf16'),
('fp16,int8'), ('bf16,int8')] # NOTE: only fused-dynamic-quant use int8 out
fused_add_list = [0, 1, 2] fused_add_list = [0, 1, 2]
fused_sweep_list = [0] fused_sweep_list = [0, 1] # NOTE: only single pass can use fused dynamic quant
# rm rn tm tn vn pd mv 2p add sweep # rm rn tm tn vn pd mv 2p add sweep
h_trait_dict = {'64' : [ h_traits('x', 'y', 1, 1, 4, 64, 1, True, False, False, 0, 0)], h_trait_dict = {'64' : [ h_traits('x', 'y', 's', 1, 1, 4, 64, 1, True, False, False, 0, 0)],
'128' : [ h_traits('x', 'y', 1, 1, 4, 64, 2, True, False, False, 0, 0), '128' : [ h_traits('x', 'y', 's', 1, 1, 4, 64, 2, True, False, False, 0, 0),
h_traits('x', 'y', 1, 2, 4, 64, 1, True, False, False, 0, 0)], h_traits('x', 'y', 's', 1, 2, 4, 64, 1, True, False, False, 0, 0)],
'256' : [ h_traits('x', 'y', 1, 1, 4, 64, 4, True, False, False, 0, 0), '256' : [ h_traits('x', 'y', 's', 1, 1, 4, 64, 4, True, False, False, 0, 0),
h_traits('x', 'y', 1, 2, 4, 64, 2, True, False, False, 0, 0), h_traits('x', 'y', 's', 1, 2, 4, 64, 2, True, False, False, 0, 0),
h_traits('x', 'y', 1, 4, 4, 64, 1, True, False, False, 0, 0)], h_traits('x', 'y', 's', 1, 4, 4, 64, 1, True, False, False, 0, 0)],
'512' : [ h_traits('x', 'y', 1, 1, 4, 64, 8, True, False, False, 0, 0), '512' : [ h_traits('x', 'y', 's', 1, 1, 4, 64, 8, True, False, False, 0, 0),
h_traits('x', 'y', 1, 2, 4, 64, 4, True, False, False, 0, 0), h_traits('x', 'y', 's', 1, 2, 4, 64, 4, True, False, False, 0, 0),
h_traits('x', 'y', 1, 4, 4, 64, 2, True, False, False, 0, 0), h_traits('x', 'y', 's', 1, 4, 4, 64, 2, True, False, False, 0, 0),
h_traits('x', 'y', 1, 8, 4, 64, 1, True, False, False, 0, 0)], h_traits('x', 'y', 's', 1, 8, 4, 64, 1, True, False, False, 0, 0)],
'768' : [ h_traits('x', 'y', 1, 3, 4, 64, 4, True, False, False, 0, 0), '768' : [ h_traits('x', 'y', 's', 1, 3, 4, 64, 4, True, False, False, 0, 0),
h_traits('x', 'y', 1, 6, 4, 64, 2, True, False, False, 0, 0), h_traits('x', 'y', 's', 1, 6, 4, 64, 2, True, False, False, 0, 0),
h_traits('x', 'y', 1, 12, 4, 64, 1, True, False, False, 0, 0)], h_traits('x', 'y', 's', 1, 12, 4, 64, 1, True, False, False, 0, 0)],
'1024' :[ h_traits('x', 'y', 1, 1, 2, 128, 8, True, False, False, 0, 0), '1024' :[ h_traits('x', 'y', 's', 1, 1, 2, 128, 8, True, False, False, 0, 0),
h_traits('x', 'y', 1, 2, 2, 128, 4, True, False, False, 0, 0), h_traits('x', 'y', 's', 1, 2, 2, 128, 4, True, False, False, 0, 0),
h_traits('x', 'y', 1, 4, 2, 128, 2, True, False, False, 0, 0), h_traits('x', 'y', 's', 1, 4, 2, 128, 2, True, False, False, 0, 0),
h_traits('x', 'y', 1, 4, 1, 256, 1, True, False, False, 0, 0)], h_traits('x', 'y', 's', 1, 4, 1, 256, 1, True, False, False, 0, 0)],
'1536' :[ h_traits('x', 'y', 1, 3, 4, 64, 8, True, False, False, 0, 0), '1536' :[ h_traits('x', 'y', 's', 1, 3, 4, 64, 8, True, False, False, 0, 0),
h_traits('x', 'y', 1, 3, 2, 128, 4, True, False, False, 0, 0), h_traits('x', 'y', 's', 1, 3, 2, 128, 4, True, False, False, 0, 0),
h_traits('x', 'y', 1, 3, 1, 256, 2, True, False, False, 0, 0), h_traits('x', 'y', 's', 1, 3, 1, 256, 2, True, False, False, 0, 0),
h_traits('x', 'y', 1, 6, 1, 256, 1, True, False, False, 0, 0)], h_traits('x', 'y', 's', 1, 6, 1, 256, 1, True, False, False, 0, 0)],
'2048' :[ h_traits('x', 'y', 1, 1, 1, 256, 8, True, False, False, 0, 0), '2048' :[ h_traits('x', 'y', 's', 1, 1, 1, 256, 8, True, False, False, 0, 0),
h_traits('x', 'y', 1, 2, 1, 256, 4, True, False, False, 0, 0), h_traits('x', 'y', 's', 1, 2, 1, 256, 4, True, False, False, 0, 0),
h_traits('x', 'y', 1, 4, 1, 256, 2, True, False, False, 0, 0), h_traits('x', 'y', 's', 1, 4, 1, 256, 2, True, False, False, 0, 0),
h_traits('x', 'y', 1, 8, 1, 256, 1, True, False, False, 0, 0)], h_traits('x', 'y', 's', 1, 8, 1, 256, 1, True, False, False, 0, 0)],
'3072' :[ h_traits('x', 'y', 1, 3, 1, 128, 8, True, False, False, 0, 0), '3072' :[ h_traits('x', 'y', 's', 1, 3, 1, 128, 8, True, False, False, 0, 0),
h_traits('x', 'y', 1, 3, 1, 256, 4, True, False, False, 0, 0), h_traits('x', 'y', 's', 1, 3, 1, 256, 4, True, False, False, 0, 0),
h_traits('x', 'y', 1, 6, 1, 256, 2, True, False, False, 0, 0), h_traits('x', 'y', 's', 1, 6, 1, 256, 2, True, False, False, 0, 0),
h_traits('x', 'y', 1, 3, 1,1024, 1, True, False, False, 0, 0)], h_traits('x', 'y', 's', 1, 3, 1,1024, 1, True, False, False, 0, 0)],
'4096' :[ h_traits('x', 'y', 1, 1, 1, 512, 8, True, False, False, 0, 0), '4096' :[ h_traits('x', 'y', 's', 1, 1, 1, 512, 8, True, False, False, 0, 0),
h_traits('x', 'y', 1, 4, 1, 256, 4, True, False, False, 0, 0), h_traits('x', 'y', 's', 1, 4, 1, 256, 4, True, False, False, 0, 0),
h_traits('x', 'y', 1, 2, 1,1024, 2, True, False, False, 0, 0), h_traits('x', 'y', 's', 1, 2, 1,1024, 2, True, False, False, 0, 0),
h_traits('x', 'y', 1, 4, 1,1024, 1, True, False, False, 0, 0)], h_traits('x', 'y', 's', 1, 4, 1,1024, 1, True, False, False, 0, 0)],
'big' :[ h_traits('x', 'y', 1, 2, 1, 256, 8, True, False, True, 0, 0), '6144' :[ h_traits('x', 'y', 's', 1, 3, 1, 256, 8, True, False, False, 0, 0),
h_traits('x', 'y', 1, 4, 1, 256, 4, True, False, True, 0, 0), h_traits('x', 'y', 's', 1, 3, 1, 512, 4, True, False, False, 0, 0),
h_traits('x', 'y', 1, 2, 1,1024, 2, True, False, True, 0, 0), h_traits('x', 'y', 's', 1, 3, 1,1024, 2, True, False, False, 0, 0),
h_traits('x', 'y', 1, 4, 1,1024, 1, True, False, True, 0, 0)]} h_traits('x', 'y', 's', 1, 6, 1,1024, 1, True, False, False, 0, 0)],
'8192' :[ h_traits('x', 'y', 's', 1, 4, 1, 256, 8, True, False, False, 0, 0),
h_traits('x', 'y', 's', 1, 4, 1, 512, 4, True, False, False, 0, 0),
h_traits('x', 'y', 's', 1, 4, 1,1024, 2, True, False, False, 0, 0),
h_traits('x', 'y', 's', 1, 8, 1,1024, 1, True, False, False, 0, 0)],
'big' :[ h_traits('x', 'y', 's', 1, 2, 1, 256, 8, True, False, True, 0, 0),
h_traits('x', 'y', 's', 1, 4, 1, 256, 4, True, False, True, 0, 0),
h_traits('x', 'y', 's', 1, 2, 1,1024, 2, True, False, True, 0, 0),
h_traits('x', 'y', 's', 1, 4, 1,1024, 1, True, False, True, 0, 0)]}
total_blob = list() total_blob = list()
for hs_key in h_trait_dict: for hs_key in h_trait_dict:
hs = h_trait_dict[hs_key] hs = h_trait_dict[hs_key]
current_n = hs[0].F_Repeat_N * hs[0].F_ThreadPerBlock_N * hs[0].F_Vector_N current_n = hs[0].F_Repeat_N * hs[0].F_ThreadPerBlock_N * hs[0].F_Vector_N
for dtype, fused_add, fused_sweep in itertools.product(dtype_list, fused_add_list, fused_sweep_list): for dtype, fused_add, fused_sweep in itertools.product(dtype_list, fused_add_list, fused_sweep_list):
prec_i, prec_o = dtype.split(',') prec_i, prec_o = dtype.split(',')
if prec_o in dynamic_quant_out_dtype and fused_sweep != 1:
continue # skip non dynamic quant case
if fused_sweep == 1 and hs_key == 'big':
continue
current_hs = list() current_hs = list()
for chs_ in hs: for chs_ in hs:
h_ = copy.copy(chs_) # copy the base instance out h_ = copy.copy(chs_) # copy the base instance out
h_.F_XDataType = prec_i h_.F_XDataType = prec_i
h_.F_YDataType = prec_o h_.F_YDataType = prec_o
h_.F_YScaleDataType = prec_i
h_.F_kFusedAdd = fused_add h_.F_kFusedAdd = fused_add
h_.F_kFusedSweep = fused_sweep h_.F_kFusedSweep = fused_sweep
current_hs.append(h_) # + "\n" current_hs.append(h_) # + "\n"
......
...@@ -32,8 +32,11 @@ auto create_args(int argc, char* argv[]) ...@@ -32,8 +32,11 @@ auto create_args(int argc, char* argv[])
.insert("kname", "1", "print kernel name or not") .insert("kname", "1", "print kernel name or not")
.insert("prec_i", "fp16", "input precision") .insert("prec_i", "fp16", "input precision")
.insert("prec_o", "auto", "output precision, set auto will be the same as input") .insert("prec_o", "auto", "output precision, set auto will be the same as input")
.insert("prec_s",
"auto",
"output quant scale type, set auto will be the same as input. used when fsweep=1")
.insert("fadd", "0", "fused-add, 0:no fused add, 1:preadd+store, 2:preadd only") .insert("fadd", "0", "fused-add, 0:no fused add, 1:preadd+store, 2:preadd only")
.insert("fsweep", "0", "fused-sweep") .insert("fsweep", "0", "fused-sweep, 0:no, 1:fused-dynamic-quant")
.insert("warmup", "5", "cold iter") .insert("warmup", "5", "cold iter")
.insert("repeat", "20", "hot iter"); .insert("repeat", "20", "hot iter");
...@@ -41,7 +44,7 @@ auto create_args(int argc, char* argv[]) ...@@ -41,7 +44,7 @@ auto create_args(int argc, char* argv[])
return std::make_tuple(result, arg_parser); return std::make_tuple(result, arg_parser);
} }
template <typename InDataType, typename OutDataType, bool SaveMeanVar> template <typename InDataType, typename OutDataType, typename ScaleDataType, bool SaveMeanVar>
bool run(const ck_tile::ArgParser& arg_parser) bool run(const ck_tile::ArgParser& arg_parser)
{ {
ck_tile::index_t m = arg_parser.get_int("m"); ck_tile::index_t m = arg_parser.get_int("m");
...@@ -52,27 +55,38 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -52,27 +55,38 @@ bool run(const ck_tile::ArgParser& arg_parser)
float epsilon = arg_parser.get_float("e"); float epsilon = arg_parser.get_float("e");
std::string prec_i = arg_parser.get_str("prec_i"); std::string prec_i = arg_parser.get_str("prec_i");
std::string prec_o = arg_parser.get_str("prec_o"); std::string prec_o = arg_parser.get_str("prec_o");
std::string prec_s = arg_parser.get_str("prec_s");
if(prec_o == "auto") if(prec_o == "auto")
{ {
prec_o = prec_i; prec_o = prec_i;
} }
if(prec_s == "auto")
{
prec_s = prec_i;
}
int kname = arg_parser.get_int("kname"); int kname = arg_parser.get_int("kname");
int do_validation = arg_parser.get_int("v"); int do_validation = arg_parser.get_int("v");
int warmup = arg_parser.get_int("warmup"); int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat"); int repeat = arg_parser.get_int("repeat");
int fused_add = arg_parser.get_int("fadd"); int fused_add = arg_parser.get_int("fadd");
int fused_sweep = arg_parser.get_int("fsweep"); int fused_sweep = arg_parser.get_int("fsweep");
if(fused_sweep == 1 && prec_o != "int8")
{
std::cout << "if fused_sweep is 1, only support \"-prec_o=int8\" case" << std::endl;
return false;
}
assert(stride >= n); assert(stride >= n);
using TypeConfig = LayerNormTypeConfig<InDataType, OutDataType>; using TypeConfig = LayerNormTypeConfig<InDataType, OutDataType, ScaleDataType>;
using XDataType = typename TypeConfig::XDataType; using XDataType = typename TypeConfig::XDataType;
using YDataType = typename TypeConfig::YDataType; using YDataType = typename TypeConfig::YDataType;
using GammaDataType = typename TypeConfig::GammaDataType; using GammaDataType = typename TypeConfig::GammaDataType;
using BetaDataType = typename TypeConfig::BetaDataType; using BetaDataType = typename TypeConfig::BetaDataType;
using SXDataType = XDataType; using SXDataType = XDataType;
using SYDataType = YDataType; using SYDataType = XDataType;
using MeanDataType = using MeanDataType =
std::conditional_t<SaveMeanVar, typename TypeConfig::MeanDataType, ck_tile::null_type>; std::conditional_t<SaveMeanVar, typename TypeConfig::MeanDataType, ck_tile::null_type>;
...@@ -94,6 +108,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -94,6 +108,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<MeanDataType> mean_host_ref({m}); ck_tile::HostTensor<MeanDataType> mean_host_ref({m});
ck_tile::HostTensor<InvStdDataType> invStd_host_ref({m}); ck_tile::HostTensor<InvStdDataType> invStd_host_ref({m});
ck_tile::HostTensor<ScaleDataType> y_scale_host_ref({m});
ck_tile::HostTensor<ScaleDataType> y_scale_host_dev({m});
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host); ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host); ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host);
...@@ -103,6 +119,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -103,6 +119,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes()); ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_scale_buf(y_scale_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem sx_buf(sx_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem sx_buf(sx_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem sy_buf(sy_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem sy_buf(sy_host.get_element_space_size_in_bytes());
...@@ -112,10 +129,23 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -112,10 +129,23 @@ bool run(const ck_tile::ArgParser& arg_parser)
beta_buf.ToDevice(beta_host.data()); beta_buf.ToDevice(beta_host.data());
sx_buf.ToDevice(sx_host.data()); sx_buf.ToDevice(sx_host.data());
std::cout << "[" << prec_i << "]" auto prec_str = [&]() {
auto base_str = prec_i;
if(prec_i != prec_o)
{
base_str += "|" + prec_o;
}
if(fused_sweep == 1)
{
base_str += std::string("(") + prec_s + ")";
}
return base_str;
}();
std::cout << "[" << prec_str << "]"
<< " m:" << m << ", n:" << n << ", stride:" << stride << std::flush; << " m:" << m << ", n:" << n << ", stride:" << stride << std::flush;
layernorm2d_fwd_traits traits{prec_i, prec_o, SaveMeanVar, fused_add, fused_sweep}; layernorm2d_fwd_traits traits{prec_i, prec_o, prec_s, SaveMeanVar, fused_add, fused_sweep};
layernorm2d_fwd_args args{x_buf.GetDeviceBuffer(), layernorm2d_fwd_args args{x_buf.GetDeviceBuffer(),
fused_add != 0 ? sx_buf.GetDeviceBuffer() : nullptr, fused_add != 0 ? sx_buf.GetDeviceBuffer() : nullptr,
...@@ -125,6 +155,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -125,6 +155,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
fused_add == 1 ? sy_buf.GetDeviceBuffer() : nullptr, fused_add == 1 ? sy_buf.GetDeviceBuffer() : nullptr,
nullptr, nullptr,
nullptr, nullptr,
fused_sweep == 1 ? y_scale_buf.GetDeviceBuffer() : nullptr,
epsilon, epsilon,
m, m,
n, n,
...@@ -170,6 +201,50 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -170,6 +201,50 @@ bool run(const ck_tile::ArgParser& arg_parser)
InvStdDataType>( InvStdDataType>(
x_host, gamma_host, beta_host, y_host_ref, mean_host_ref, invStd_host_ref, epsilon); x_host, gamma_host, beta_host, y_host_ref, mean_host_ref, invStd_host_ref, epsilon);
if(fused_sweep == 1)
{
auto dquant_functor = [&](int m_, auto o_, auto acc_) {
int N_ = acc_.mDesc.get_lengths()[1];
ComputeDataType absmax = 0;
for(int n_ = 0; n_ < N_; n_++)
{
const auto a = abs(acc_(m_, n_));
absmax = a > absmax ? a : absmax;
}
y_scale_host_ref(m_) = absmax / 127.0;
for(int n_ = 0; n_ < N_; n_++)
{
o_(m_, n_) = static_cast<YDataType>(acc_(m_, n_) / y_scale_host_ref(m_));
}
};
ck_tile::reference_layernorm2d_fwd<XDataType,
GammaDataType,
BetaDataType,
ComputeDataType,
YDataType,
MeanDataType,
InvStdDataType>(x_host,
gamma_host,
beta_host,
y_host_ref,
mean_host_ref,
invStd_host_ref,
epsilon,
dquant_functor);
}
else
{
ck_tile::reference_layernorm2d_fwd<XDataType,
GammaDataType,
BetaDataType,
ComputeDataType,
YDataType,
MeanDataType,
InvStdDataType>(
x_host, gamma_host, beta_host, y_host_ref, mean_host_ref, invStd_host_ref, epsilon);
}
y_buf.FromDevice(y_host_dev.data()); y_buf.FromDevice(y_host_dev.data());
ck_tile::HostTensor<SYDataType> sy_host_dev({m, n}, {stride, 1}); ck_tile::HostTensor<SYDataType> sy_host_dev({m, n}, {stride, 1});
...@@ -179,6 +254,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -179,6 +254,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
} }
auto [rtol, atol] = get_elimit<InDataType>(); auto [rtol, atol] = get_elimit<InDataType>();
if(stride == n) if(stride == n)
{ {
pass = ck_tile::check_err( pass = ck_tile::check_err(
...@@ -218,6 +294,15 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -218,6 +294,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
} }
} }
} }
if(fused_sweep == 1)
{
y_scale_buf.FromDevice(y_scale_host_dev.data());
pass &= ck_tile::check_err(y_scale_host_dev,
y_scale_host_ref,
std::string("SCALE Error: Incorrect results!"),
rtol,
atol);
}
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
} }
...@@ -233,26 +318,44 @@ int main(int argc, char* argv[]) ...@@ -233,26 +318,44 @@ int main(int argc, char* argv[])
std::string prec_i = arg_parser.get_str("prec_i"); std::string prec_i = arg_parser.get_str("prec_i");
std::string prec_o = arg_parser.get_str("prec_o"); std::string prec_o = arg_parser.get_str("prec_o");
std::string prec_s = arg_parser.get_str("prec_s");
if(prec_o == "auto") if(prec_o == "auto")
{ {
prec_o = prec_i; prec_o = prec_i;
} }
if(prec_s == "auto")
{
prec_s = prec_i;
}
int save_mv = arg_parser.get_int("save_mv"); int save_mv = arg_parser.get_int("save_mv");
if(prec_i == "fp16" && prec_o == "fp16" && save_mv)
// no dynamic quant case
if(prec_i == "fp16" && prec_o == "fp16" && prec_s == "fp16" && save_mv)
{
return run<ck_tile::half_t, ck_tile::half_t, ck_tile::half_t, true>(arg_parser) ? 0 : -2;
}
else if(prec_i == "fp16" && prec_o == "fp16" && prec_s == "fp16" && !save_mv)
{
return run<ck_tile::half_t, ck_tile::half_t, ck_tile::half_t, false>(arg_parser) ? 0 : -2;
}
else if(prec_i == "bf16" && prec_o == "bf16" && prec_s == "bf16" && save_mv)
{ {
return run<ck_tile::half_t, ck_tile::half_t, true>(arg_parser) ? 0 : -2; return run<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, true>(arg_parser) ? 0 : -2;
} }
else if(prec_i == "fp16" && prec_o == "fp16" && !save_mv) else if(prec_i == "bf16" && prec_o == "bf16" && prec_s == "bf16" && !save_mv)
{ {
return run<ck_tile::half_t, ck_tile::half_t, false>(arg_parser) ? 0 : -2; return run<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, true>(arg_parser) ? 0 : -2;
} }
else if(prec_i == "bf16" && prec_o == "bf16" && save_mv)
// dynamic quant case, only in inference
else if(prec_i == "fp16" && prec_o == "int8" && prec_s == "fp16" && !save_mv)
{ {
return run<ck_tile::bf16_t, ck_tile::bf16_t, true>(arg_parser) ? 0 : -2; return run<ck_tile::half_t, ck_tile::int8_t, ck_tile::half_t, false>(arg_parser) ? 0 : -2;
} }
else if(prec_i == "bf16" && prec_o == "bf16" && !save_mv) else if(prec_i == "bf16" && prec_o == "int8" && prec_s == "bf16" && !save_mv)
{ {
return run<ck_tile::bf16_t, ck_tile::bf16_t, true>(arg_parser) ? 0 : -2; return run<ck_tile::bf16_t, ck_tile::int8_t, ck_tile::bf16_t, false>(arg_parser) ? 0 : -2;
} }
return -3; return -3;
......
...@@ -8,11 +8,11 @@ ...@@ -8,11 +8,11 @@
#include "ck_tile/ops/layernorm2d.hpp" #include "ck_tile/ops/layernorm2d.hpp"
#include <string> #include <string>
template <typename InType, typename OutType> template <typename InType, typename OutType, typename YScaleDataType_>
struct LayerNormTypeConfig; struct LayerNormTypeConfig;
template <typename OutType> template <typename OutType, typename YScaleDataType_>
struct LayerNormTypeConfig<ck_tile::half_t, OutType> struct LayerNormTypeConfig<ck_tile::half_t, OutType, YScaleDataType_>
{ {
using XDataType = ck_tile::half_t; using XDataType = ck_tile::half_t;
using YDataType = OutType; using YDataType = OutType;
...@@ -21,10 +21,11 @@ struct LayerNormTypeConfig<ck_tile::half_t, OutType> ...@@ -21,10 +21,11 @@ struct LayerNormTypeConfig<ck_tile::half_t, OutType>
using MeanDataType = ck_tile::half_t; using MeanDataType = ck_tile::half_t;
using InvStdDataType = ck_tile::half_t; using InvStdDataType = ck_tile::half_t;
using ComputeDataType = float; using ComputeDataType = float;
using YScaleDataType = YScaleDataType_;
}; };
template <typename OutType> template <typename OutType, typename YScaleDataType_>
struct LayerNormTypeConfig<ck_tile::bf16_t, OutType> struct LayerNormTypeConfig<ck_tile::bf16_t, OutType, YScaleDataType_>
{ {
using XDataType = ck_tile::bf16_t; using XDataType = ck_tile::bf16_t;
using YDataType = OutType; using YDataType = OutType;
...@@ -33,6 +34,7 @@ struct LayerNormTypeConfig<ck_tile::bf16_t, OutType> ...@@ -33,6 +34,7 @@ struct LayerNormTypeConfig<ck_tile::bf16_t, OutType>
using MeanDataType = ck_tile::bf16_t; using MeanDataType = ck_tile::bf16_t;
using InvStdDataType = ck_tile::bf16_t; using InvStdDataType = ck_tile::bf16_t;
using ComputeDataType = float; using ComputeDataType = float;
using YScaleDataType = YScaleDataType_;
}; };
// runtime args // runtime args
...@@ -48,9 +50,10 @@ struct layernorm2d_fwd_traits ...@@ -48,9 +50,10 @@ struct layernorm2d_fwd_traits
{ {
std::string prec_i; std::string prec_i;
std::string prec_o; std::string prec_o;
std::string prec_s; // scale value, used as scale factor store out when fused_sweep=1
bool save_mean_var; bool save_mean_var;
int fused_add; // 0:no-add, 1:pre-add, 2:post-add int fused_add; // 0:no-add, 1:pre-add-store, 2:pre-add
int fused_sweep; // 0:no-sweep, int fused_sweep; // 0:no-sweep, 1:dynamic-quant
}; };
float layernorm2d_fwd(layernorm2d_fwd_traits, layernorm2d_fwd_args, const ck_tile::stream_config&); float layernorm2d_fwd(layernorm2d_fwd_traits, layernorm2d_fwd_args, const ck_tile::stream_config&);
...@@ -2,32 +2,34 @@ ...@@ -2,32 +2,34 @@
# call from top of CK folder # call from top of CK folder
EXE=./build/bin/tile_example_layernorm2d_fwd EXE=./build/bin/tile_example_layernorm2d_fwd
for fsweep in "" "-fsweep=1 -prec_o=int8"; do
for pr_i in "fp16" "bf16" ; do for pr_i in "fp16" "bf16" ; do
for fadd in "0" "1" "2"; do for fadd in "0" "1" "2"; do
$EXE -prec_i=$pr_i -fadd=$fadd -m=99 -n=13 $EXE -prec_i=$pr_i -fadd=$fadd $fsweep -m=99 -n=13
$EXE -prec_i=$pr_i -fadd=$fadd -m=17 -n=16 $EXE -prec_i=$pr_i -fadd=$fadd $fsweep -m=17 -n=16
$EXE -prec_i=$pr_i -fadd=$fadd -m=1 -n=100 $EXE -prec_i=$pr_i -fadd=$fadd $fsweep -m=1 -n=100
$EXE -prec_i=$pr_i -fadd=$fadd -m=4 -n=128 $EXE -prec_i=$pr_i -fadd=$fadd $fsweep -m=4 -n=128
$EXE -prec_i=$pr_i -fadd=$fadd -m=80 -n=127 $EXE -prec_i=$pr_i -fadd=$fadd $fsweep -m=80 -n=127
$EXE -prec_i=$pr_i -fadd=$fadd -m=22 -n=255 -stride=256 $EXE -prec_i=$pr_i -fadd=$fadd $fsweep -m=22 -n=255 -stride=256
$EXE -prec_i=$pr_i -fadd=$fadd -m=7 -n=599 $EXE -prec_i=$pr_i -fadd=$fadd $fsweep -m=7 -n=599
$EXE -prec_i=$pr_i -fadd=$fadd -m=19 -n=512 $EXE -prec_i=$pr_i -fadd=$fadd $fsweep -m=19 -n=512
$EXE -prec_i=$pr_i -fadd=$fadd -m=33 -n=313 -stride=1000 $EXE -prec_i=$pr_i -fadd=$fadd $fsweep -m=33 -n=313 -stride=1000
$EXE -prec_i=$pr_i -fadd=$fadd -m=11 -n=510 $EXE -prec_i=$pr_i -fadd=$fadd $fsweep -m=11 -n=510
$EXE -prec_i=$pr_i -fadd=$fadd -m=171 -n=676 -stride=818 $EXE -prec_i=$pr_i -fadd=$fadd $fsweep -m=171 -n=676 -stride=818
$EXE -prec_i=$pr_i -fadd=$fadd -m=91 -n=636 $EXE -prec_i=$pr_i -fadd=$fadd $fsweep -m=91 -n=636
$EXE -prec_i=$pr_i -fadd=$fadd -m=12 -n=768 -stride=800 $EXE -prec_i=$pr_i -fadd=$fadd $fsweep -m=12 -n=768 -stride=800
$EXE -prec_i=$pr_i -fadd=$fadd -m=100 -n=766 -stride=812 $EXE -prec_i=$pr_i -fadd=$fadd $fsweep -m=100 -n=766 -stride=812
$EXE -prec_i=$pr_i -fadd=$fadd -m=31 -n=1024 $EXE -prec_i=$pr_i -fadd=$fadd $fsweep -m=31 -n=1024
$EXE -prec_i=$pr_i -fadd=$fadd -m=64 -n=1000 -stride=1004 $EXE -prec_i=$pr_i -fadd=$fadd $fsweep -m=64 -n=1000 -stride=1004
$EXE -prec_i=$pr_i -fadd=$fadd -m=8 -n=1501 $EXE -prec_i=$pr_i -fadd=$fadd $fsweep -m=8 -n=1501
$EXE -prec_i=$pr_i -fadd=$fadd -m=3 -n=1826 $EXE -prec_i=$pr_i -fadd=$fadd $fsweep -m=3 -n=1826
$EXE -prec_i=$pr_i -fadd=$fadd -m=5 -n=2040 $EXE -prec_i=$pr_i -fadd=$fadd $fsweep -m=5 -n=2040
$EXE -prec_i=$pr_i -fadd=$fadd -m=7 -n=2734 $EXE -prec_i=$pr_i -fadd=$fadd $fsweep -m=7 -n=2734
$EXE -prec_i=$pr_i -fadd=$fadd -m=1 -n=3182 $EXE -prec_i=$pr_i -fadd=$fadd $fsweep -m=1 -n=3182
$EXE -prec_i=$pr_i -fadd=$fadd -m=9 -n=4096 $EXE -prec_i=$pr_i -fadd=$fadd $fsweep -m=9 -n=4096
$EXE -prec_i=$pr_i -fadd=$fadd -m=3 -n=8192 $EXE -prec_i=$pr_i -fadd=$fadd $fsweep -m=3 -n=8192
$EXE -prec_i=$pr_i -fadd=$fadd -m=1 -n=10547 #$EXE -prec_i=$pr_i -fadd=$fadd $fsweep -m=1 -n=10547
$EXE -prec_i=$pr_i -fadd=$fadd -m=3 -n=17134 #$EXE -prec_i=$pr_i -fadd=$fadd $fsweep -m=3 -n=17134
done
done done
done done
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include "ck_tile/core/numeric/bfloat16.hpp" #include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/float8.hpp" #include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/half.hpp" #include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/int8.hpp"
#include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp" #include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp" #include "ck_tile/core/numeric/math.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/core/utility/random.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include <stdint.h>
#include <type_traits>
#pragma once
namespace ck_tile {
// use int8_t directly for int8 arithmetic
// here one can use ck_tile::int8_t to access original int8_t
// (re-exports the global ::int8_t into the ck_tile namespace so generic code
//  can spell every element type as ck_tile::T)
using int8_t = int8_t;
// limits
template <class T>
struct numeric;

// std::numeric_limits-style traits for int8_t, usable in both host and device
// code. Float-only quantities (epsilon, infinity, NaN, denorm_min, ...) have
// no meaning for an integer type; they return 1 as a placeholder and are
// marked "not used".
template <>
struct numeric<int8_t>
{
    // minimum finite value, or minimum positive normalized value for float
    CK_TILE_HOST_DEVICE static constexpr int8_t min() { return int8_t(-128); }

    // minimum finite value (same as min() for an integer type)
    CK_TILE_HOST_DEVICE static constexpr int8_t lowest() { return int8_t(-128); }

    // maximum finite value
    CK_TILE_HOST_DEVICE static constexpr int8_t max() { return int8_t(127); }

    // difference between 1.0 and next value representable by float
    CK_TILE_HOST_DEVICE static constexpr int8_t epsilon()
    {
        return 1; // not used
    }

    CK_TILE_HOST_DEVICE static constexpr int8_t round_error()
    {
        return 1; // not used
    }

    // positive infinity value
    CK_TILE_HOST_DEVICE static constexpr int8_t infinity()
    {
        return 1; // not used
    }

    // quiet NaN
    CK_TILE_HOST_DEVICE static constexpr int8_t quiet_NaN()
    {
        return 1; // not used
    }

    // signaling NaN
    CK_TILE_HOST_DEVICE static constexpr int8_t signaling_NaN()
    {
        return 1; // not used
    }

    // smallest positive subnormal value
    CK_TILE_HOST_DEVICE static constexpr int8_t denorm_min()
    {
        return 1; // not used
    }

    CK_TILE_HOST_DEVICE static constexpr int8_t zero() { return 0; }
};
#if 0
template <typename T>
struct numeric_traits;
template <>
struct numeric_traits<int8_t>
{
static constexpr int exp = 5;
static constexpr int mant = 10;
static constexpr int bias = 15;
static constexpr uint16_t nan_mask = 0x7C00;
static constexpr uint16_t head_mask = 0xFC00;
static constexpr uint16_t mant_mask = 0x3FF;
static constexpr uint16_t exp_mask = 0x1F;
static constexpr uint32_t Inf = 0x7C00;
static constexpr uint32_t NegInf = 0xFC00;
static constexpr uint32_t NaN = 0x7C01;
static constexpr uint32_t Neg0 = 0x8000;
using bitwise_type = uint16_t;
};
#endif
// widen int8 to float; exact for every representable int8 value
CK_TILE_HOST_DEVICE
constexpr float int8_to_float(const int8_t& x) { return static_cast<float>(x); }
// narrow float to int8 via static_cast (truncates toward zero).
// NOTE(review): no rounding and no saturation here — a value outside
// [-128, 127] is undefined behavior under float->integer conversion rules;
// presumably callers clamp/round before converting. TODO confirm.
CK_TILE_HOST_DEVICE
constexpr int8_t float_to_int8(const float& x) { return static_cast<int8_t>(x); }
} // namespace ck_tile
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include "ck_tile/core/numeric/half.hpp" #include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp" #include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/float8.hpp" #include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/int8.hpp"
namespace ck_tile { namespace ck_tile {
...@@ -60,6 +61,9 @@ CK_TILE_TYPE_CONVERT(bf16_t, bf16, float, float) ...@@ -60,6 +61,9 @@ CK_TILE_TYPE_CONVERT(bf16_t, bf16, float, float)
CK_TILE_TYPE_CONVERT(fp8_t, fp8, float, float) CK_TILE_TYPE_CONVERT(fp8_t, fp8, float, float)
CK_TILE_TYPE_CONVERT(bf8_t, bf8, float, float) CK_TILE_TYPE_CONVERT(bf8_t, bf8, float, float)
CK_TILE_TYPE_CONVERT(float, float, int8_t, int8)
CK_TILE_TYPE_CONVERT(int8_t, int8, float, float)
#undef CK_TILE_TYPE_CONVERT #undef CK_TILE_TYPE_CONVERT
#endif #endif
......
...@@ -396,4 +396,59 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val ...@@ -396,4 +396,59 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
return res; return res;
} }
#if 0
// TODO: Note, int8 validation is risky, need more check
// check_err specialization for int8 ranges: converts both sides to float
// element-wise and compares with the usual |o - r| > atol + rtol * |r|
// tolerance test, printing the first few mismatches plus a summary.
// Currently disabled (#if 0) pending more validation of int8 comparison.
template <typename Range, typename RefRange>
typename std::enable_if<
    std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
        std::is_same_v<ranges::range_value_t<Range>, int8_t>,
    bool>::type CK_TILE_HOST
check_err(const Range& out,
          const RefRange& ref,
          const std::string& msg = "Error[int8]: Incorrect results!",
          double rtol            = 1e-3,
          double atol            = 1e-3,
          bool /*allow_infinity_ref*/ = false)
{
    if(out.size() != ref.size())
    {
        std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
                  << std::endl;
        return false;
    }

    bool res{true};
    int err_count = 0;
    double err    = 0;
    // NOTE(review): seeded with numeric_limits<float>::min() (tiny positive),
    // not lowest()/0 — harmless here because err is always >= 0 when recorded.
    double max_err = std::numeric_limits<float>::min();
    for(std::size_t i = 0; i < ref.size(); ++i)
    {
        const double o = type_convert<float>(*std::next(std::begin(out), i));
        const double r = type_convert<float>(*std::next(std::begin(ref), i));
        err            = std::abs(o - r);
        if(err > atol + rtol * std::abs(r))
        {
            max_err = err > max_err ? err : max_err;
            err_count++;
            // only print the first few failing elements to keep output readable
            if(err_count < 5)
            {
                std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
                          << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
            }
            res = false;
        }
    }
    if(!res)
    {
        const float error_percent =
            static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
        std::cerr << "max err: " << max_err;
        std::cerr << ", number of errors: " << err_count;
        std::cerr << ", " << error_percent << "% wrong values" << std::endl;
    }
    return res;
}
#endif
} // namespace ck_tile } // namespace ck_tile
...@@ -8,20 +8,44 @@ ...@@ -8,20 +8,44 @@
namespace ck_tile { namespace ck_tile {
// Note: for simplicity, each functor only cares about a single row (index m).
// Default host-reference epilogue: a plain element-wise cast of one
// accumulator row into the output tensor, with no fusion applied.
struct reference_layernorm2d_default_epilogue
{
    // Cast row `m` of `acc` into the pre-allocated output tensor `o`.
    template <typename OutDataType, typename AccDataType>
    void operator()(int m, HostTensor<OutDataType>& o, const HostTensor<AccDataType>& acc)
    {
        const int row_length = acc.mDesc.get_lengths()[1];
        int col              = 0;
        while(col < row_length)
        {
            o(m, col) = ck_tile::type_convert<OutDataType>(acc(m, col));
            ++col;
        }
    }

    // Convenience overload: allocate an output tensor with the same
    // lengths/strides as `acc`, fill row `m`, and return it by value.
    // (Presumably rows other than `m` stay at their initial contents.)
    template <typename OutDataType, typename AccDataType>
    auto operator()(int m, const HostTensor<AccDataType>& acc)
    {
        HostTensor<OutDataType> o(acc.get_lengths(), acc.get_strides());
        operator()(m, o, acc);
        return o;
    }
};
template <typename XDataType, template <typename XDataType,
typename GammaDataType, typename GammaDataType,
typename BetaDataType, typename BetaDataType,
typename ComputeDataType, typename ComputeDataType,
typename YDataType, typename YDataType,
typename MeanDataType, typename MeanDataType,
typename InvStdDataType> typename InvStdDataType,
typename Epilogue = reference_layernorm2d_default_epilogue>
void reference_layernorm2d_fwd(const HostTensor<XDataType>& x_m_n, void reference_layernorm2d_fwd(const HostTensor<XDataType>& x_m_n,
const HostTensor<GammaDataType>& gamma_n, const HostTensor<GammaDataType>& gamma_n,
const HostTensor<BetaDataType>& beta_n, const HostTensor<BetaDataType>& beta_n,
HostTensor<YDataType>& y_m_n, HostTensor<YDataType>& y_m_n,
HostTensor<MeanDataType>& mean_m, HostTensor<MeanDataType>& mean_m,
HostTensor<InvStdDataType>& invStd_m, HostTensor<InvStdDataType>& invStd_m,
ComputeDataType epsilon) ComputeDataType epsilon,
Epilogue epilogue_functor = {})
{ {
auto layernorm2d_fwd_func = [&](auto m) { auto layernorm2d_fwd_func = [&](auto m) {
const int N = x_m_n.mDesc.get_lengths()[1]; const int N = x_m_n.mDesc.get_lengths()[1];
...@@ -51,16 +75,19 @@ void reference_layernorm2d_fwd(const HostTensor<XDataType>& x_m_n, ...@@ -51,16 +75,19 @@ void reference_layernorm2d_fwd(const HostTensor<XDataType>& x_m_n,
if constexpr(!std::is_same_v<InvStdDataType, ck_tile::null_type>) if constexpr(!std::is_same_v<InvStdDataType, ck_tile::null_type>)
invStd_m(m) = ck_tile::type_convert<InvStdDataType>(divisor); invStd_m(m) = ck_tile::type_convert<InvStdDataType>(divisor);
HostTensor<ComputeDataType> acc(x_m_n.get_lengths(), x_m_n.get_strides());
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
ComputeDataType x = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n)); ComputeDataType x = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
ComputeDataType gamma = ck_tile::type_convert<ComputeDataType>(gamma_n(n)); ComputeDataType gamma = ck_tile::type_convert<ComputeDataType>(gamma_n(n));
ComputeDataType beta = ck_tile::type_convert<ComputeDataType>(beta_n(n)); ComputeDataType beta = ck_tile::type_convert<ComputeDataType>(beta_n(n));
auto y = (x - mean) * divisor; auto a_ = (x - mean) * divisor;
y = y * gamma + beta; a_ = a_ * gamma + beta;
y_m_n(m, n) = ck_tile::type_convert<YDataType>(y); acc(m, n) = a_;
} }
epilogue_functor(m, y_m_n, acc);
}; };
make_ParallelTensorFunctor(layernorm2d_fwd_func, make_ParallelTensorFunctor(layernorm2d_fwd_func,
......
...@@ -5,4 +5,5 @@ ...@@ -5,4 +5,5 @@
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp" #include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" #include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
#include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/tensor_layout.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
// Compile-time switches for DynamicQuantEpilogue.
template <bool kPadM_, bool kPadN_, bool UseRawStore_ = true, bool UseMax3_ = false>
struct DynamicQuantEpilogueTraits
{
    static constexpr bool kPadM = kPadM_; // M dimension may be padded (needs OOB handling)
    static constexpr bool kPadN = kPadN_; // N dimension may be padded (needs OOB handling)
    // when padded, use raw store + explicit fence instead of the regular store
    static constexpr bool UseRawStore = UseRawStore_;
    // reserved: use a v_max3-style reduction for the row abs-max; the v_max3
    // path is currently commented out, so this behaves like the plain max path
    static constexpr bool UseMax3 = UseMax3_;
};
// Problem descriptor for DynamicQuantEpilogue, which stores out an M*N
// row-major matrix plus one scale value per row.
template <typename AccDataType_, typename YScaleDataType_, typename ODataType_, typename Traits_>
struct DynamicQuantEpilogueProblem
{
    using AccDataType    = remove_cvref_t<AccDataType_>;    // accumulator/compute element type
    using YScaleDataType = remove_cvref_t<YScaleDataType_>; // per-row scale store-out type
    using ODataType      = remove_cvref_t<ODataType_>;      // quantized output element type
    using Traits         = remove_cvref_t<Traits_>;         // DynamicQuantEpilogueTraits
};
// Epilogue performing per-row dynamic quantization:
//   1) reduce each row of the accumulator tile to its absolute maximum,
//   2) y_scale = row_absmax / numeric<ODataType>::max(), stored out per row,
//   3) store the accumulator tile cast to ODataType.
// NOTE(review): step 3 casts the accumulator directly, without dividing by
// y_scale here — presumably the value/scale division is applied elsewhere or
// the next layer consumes the raw values together with y_scale. TODO confirm.
template <typename Problem_, typename Policy_ = void>
struct DynamicQuantEpilogue
{
    using Problem        = remove_cvref_t<Problem_>;
    using AccDataType    = remove_cvref_t<typename Problem::AccDataType>;
    using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
    using ODataType      = remove_cvref_t<typename Problem::ODataType>;

    static constexpr bool kPadM       = Problem::Traits::kPadM;
    static constexpr bool kPadN       = Problem::Traits::kPadN;
    static constexpr bool UseRawStore = Problem::Traits::UseRawStore;
    static constexpr bool UseMax3     = Problem::Traits::UseMax3;

    // this epilogue needs no LDS
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }

    // TODO: this function assume store out vector size is the same as OAccTile last dimension size
    // how do we fix this ?
    template <typename ODramWindowTmp, typename YScaleWindow, typename OAccTile>
    CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
                                   YScaleWindow& y_scale_window,
                                   const OAccTile& o_acc_tile)
    {
        // compute per-row absolute maximum of the accumulator tile
        auto reduce_row_absmax = BlockReduce2D{o_acc_tile, type_convert<AccDataType>(0)};
        auto row_absmax        = [&]() {
            if constexpr(UseMax3 && std::is_same_v<AccDataType, float>)
            {
                // TODO: the v_max3-based reduction below is not enabled yet;
                // this branch currently falls back to the plain 2-operand
                // abs-max, identical to the else branch.
                const auto f_max = [](auto acc_, auto v_0_) { return max(acc_, abs(v_0_)); };

                // const auto f_max3 = [](auto acc_, auto v_0_, auto v_1_) {
                //     float rtn;
                //     asm volatile("v_max3_f32 %0, %1, abs(%2), abs(%3)"
                //                  : "=v"(rtn)
                //                  : "v"(acc_), "v"(v_0_), "v"(v_1_));
                //     return rtn;
                // };
                // return reduce_row_absmax(f_max3, f_max, sequence<1, 2>{});
                return reduce_row_absmax(f_max);
            }
            else
            {
                const auto f_max = [](auto acc_, auto v_0_) { return max(acc_, abs(v_0_)); };
                return reduce_row_absmax(f_max);
            }
        }();

        // here y_scale is Acc type, need convert to YScale type later
        auto y_scale = tile_elementwise_in(
            [&](const auto& v_) {
                return v_ / type_convert<AccDataType>(numeric<ODataType>::max());
            },
            row_absmax);

        store_tile(y_scale_window, cast_tile<YScaleDataType>(y_scale));

        // TODO: this is ugly
        if constexpr(UseRawStore && (kPadM || kPadN))
        {
            store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
            buffer_store_fence();
        }
        else
        {
            store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
        }
    }
};
} // namespace ck_tile
...@@ -21,6 +21,7 @@ struct Layernorm2dFwdHostArgs ...@@ -21,6 +21,7 @@ struct Layernorm2dFwdHostArgs
void* p_sy; // shortcut output, set to nullptr if no void* p_sy; // shortcut output, set to nullptr if no
void* p_mean; void* p_mean;
void* p_invStd; void* p_invStd;
void* p_y_scale; // store out a dynamic quant per row, used by next layer. nullptr if not used
float epsilon; float epsilon;
...@@ -44,6 +45,7 @@ struct Layernorm2dFwd ...@@ -44,6 +45,7 @@ struct Layernorm2dFwd
using YDataType = remove_cvref_t<typename Problem::YDataType>; using YDataType = remove_cvref_t<typename Problem::YDataType>;
using MeanDataType = remove_cvref_t<typename Problem::MeanDataType>; using MeanDataType = remove_cvref_t<typename Problem::MeanDataType>;
using InvStdDataType = remove_cvref_t<typename Problem::InvStdDataType>; using InvStdDataType = remove_cvref_t<typename Problem::InvStdDataType>;
using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
// for simplicity, shortcut input/output type is same as X // for simplicity, shortcut input/output type is same as X
using SXDataType = XDataType; using SXDataType = XDataType;
...@@ -81,6 +83,7 @@ struct Layernorm2dFwd ...@@ -81,6 +83,7 @@ struct Layernorm2dFwd
void* p_sy; // shortcut output, set to nullptr if no void* p_sy; // shortcut output, set to nullptr if no
void* p_mean; void* p_mean;
void* p_invStd; void* p_invStd;
void* p_y_scale; // store out a dynamic quant value, used in next layer
float epsilon; float epsilon;
...@@ -100,6 +103,7 @@ struct Layernorm2dFwd ...@@ -100,6 +103,7 @@ struct Layernorm2dFwd
hargs.p_sy, hargs.p_sy,
hargs.p_mean, hargs.p_mean,
hargs.p_invStd, hargs.p_invStd,
hargs.p_y_scale,
hargs.epsilon, hargs.epsilon,
hargs.m, hargs.m,
hargs.n, hargs.n,
...@@ -120,6 +124,7 @@ struct Layernorm2dFwd ...@@ -120,6 +124,7 @@ struct Layernorm2dFwd
template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; }; template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; }; template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; }; template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
template <> struct t2s<ck_tile::int8_t> { static constexpr const char * name = "int8"; };
// clang-format on // clang-format on
// in byte // in byte
...@@ -140,7 +145,18 @@ struct Layernorm2dFwd ...@@ -140,7 +145,18 @@ struct Layernorm2dFwd
if (kTwoPass) n += "_2p"; if (kTwoPass) n += "_2p";
return n; }(); return n; }();
return _SS_("layernorm2d_fwd_") + _SS_(t2s<XDataType>::name) + "_" + auto prec_str = [&] () {
std::string base_str = _SS_(t2s<XDataType>::name);
if (!std::is_same_v<XDataType, YDataType>) {
base_str += _SS_("_") + _SS_(t2s<YDataType>::name);
}
if (kFusedSweep == Layernorm2dFusedSweepEnum::DYNAMIC_QUANT) {
base_str += _SS_("_s") + _SS_(t2s<YScaleDataType>::name);
}
return base_str;
}();
return _SS_("layernorm2d_fwd_") + _SS_(prec_str) + "_" +
_TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" + _TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
_TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" + _TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" +
_SS_(Pipeline::name) + surfix; _SS_(Pipeline::name) + surfix;
...@@ -295,6 +311,24 @@ struct Layernorm2dFwd ...@@ -295,6 +311,24 @@ struct Layernorm2dFwd
return make_null_tile_window(make_tuple(number<Block_M>{})); return make_null_tile_window(make_tuple(number<Block_M>{}));
}(); }();
auto y_scale_window = [&]() {
if constexpr(kFusedSweep == Layernorm2dFusedSweepEnum::DYNAMIC_QUANT)
{
const auto win_ = [&]() {
const auto tmp_0_ = make_naive_tensor_view_packed<address_space_enum::global>(
static_cast<YScaleDataType*>(kargs.p_y_scale),
make_tuple(kargs.m),
number<1>{});
return pad_tensor_view(
tmp_0_, make_tuple(number<Block_M>{}), sequence<kPadM>{});
}();
return make_tile_window(win_, make_tuple(number<Block_M>{}), {iM});
}
else
return make_null_tile_window(make_tuple(number<Block_M>{}));
}();
__shared__ char smem[GetSmemSize()]; __shared__ char smem[GetSmemSize()];
Pipeline{}(x_window, Pipeline{}(x_window,
...@@ -305,6 +339,7 @@ struct Layernorm2dFwd ...@@ -305,6 +339,7 @@ struct Layernorm2dFwd
sy_window, sy_window,
mean_window, mean_window,
inv_std_window, inv_std_window,
y_scale_window,
static_cast<const ComputeDataType>(kargs.epsilon), static_cast<const ComputeDataType>(kargs.epsilon),
kargs.n, kargs.n,
smem, smem,
......
...@@ -59,6 +59,7 @@ struct Layernorm2dFwdPipelineOnePass ...@@ -59,6 +59,7 @@ struct Layernorm2dFwdPipelineOnePass
typename SYWindow, typename SYWindow,
typename MeanWindow, typename MeanWindow,
typename InvStdWindow, typename InvStdWindow,
typename YScaleWindow,
typename Epilogue> typename Epilogue>
CK_TILE_DEVICE auto operator()(const XWindow& x_window_, CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
const SXWindow& sx_window_, const SXWindow& sx_window_,
...@@ -68,6 +69,7 @@ struct Layernorm2dFwdPipelineOnePass ...@@ -68,6 +69,7 @@ struct Layernorm2dFwdPipelineOnePass
const SYWindow& sy_window_, const SYWindow& sy_window_,
MeanWindow& mean_window, MeanWindow& mean_window,
InvStdWindow& inv_std_window, InvStdWindow& inv_std_window,
YScaleWindow& y_scale_window,
ComputeDataType epsilon, ComputeDataType epsilon,
ck_tile::index_t row_size, ck_tile::index_t row_size,
void* smem, void* smem,
...@@ -143,7 +145,12 @@ struct Layernorm2dFwdPipelineOnePass ...@@ -143,7 +145,12 @@ struct Layernorm2dFwdPipelineOnePass
ln(idx) = ln_; ln(idx) = ln_;
}); });
Epilogue{}(y_window_, ln); if constexpr(kFusedSweep == Layernorm2dFusedSweepEnum::DYNAMIC_QUANT)
{
Epilogue{}(y_window_, y_scale_window, ln);
}
else
Epilogue{}(y_window_, ln);
} }
}; };
} // namespace ck_tile } // namespace ck_tile
...@@ -14,6 +14,7 @@ template <typename XDataType_, ...@@ -14,6 +14,7 @@ template <typename XDataType_,
typename YDataType_, typename YDataType_,
typename MeanDataType_, typename MeanDataType_,
typename InvStdDataType_, typename InvStdDataType_,
typename YScaleDataType_,
typename BlockShape_, typename BlockShape_,
typename Traits_> typename Traits_>
struct Layernorm2dFwdPipelineProblem struct Layernorm2dFwdPipelineProblem
...@@ -25,6 +26,7 @@ struct Layernorm2dFwdPipelineProblem ...@@ -25,6 +26,7 @@ struct Layernorm2dFwdPipelineProblem
using YDataType = remove_cvref_t<YDataType_>; using YDataType = remove_cvref_t<YDataType_>;
using MeanDataType = remove_cvref_t<MeanDataType_>; using MeanDataType = remove_cvref_t<MeanDataType_>;
using InvStdDataType = remove_cvref_t<InvStdDataType_>; using InvStdDataType = remove_cvref_t<InvStdDataType_>;
using YScaleDataType = remove_cvref_t<YScaleDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>; using BlockShape = remove_cvref_t<BlockShape_>;
static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1; static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
......
...@@ -58,6 +58,7 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -58,6 +58,7 @@ struct Layernorm2dFwdPipelineTwoPass
typename SYWindow, typename SYWindow,
typename MeanWindow, typename MeanWindow,
typename InvStdWindow, typename InvStdWindow,
typename YScaleWindow,
typename Epilogue> typename Epilogue>
CK_TILE_DEVICE auto operator()(const XWindow& x_window_, CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
const SXWindow& sx_window_, const SXWindow& sx_window_,
...@@ -67,6 +68,7 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -67,6 +68,7 @@ struct Layernorm2dFwdPipelineTwoPass
const SYWindow& sy_window_, const SYWindow& sy_window_,
MeanWindow& mean_window, MeanWindow& mean_window,
InvStdWindow& inv_std_window, InvStdWindow& inv_std_window,
YScaleWindow& /*y_scale_window*/,
ComputeDataType epsilon, ComputeDataType epsilon,
ck_tile::index_t row_size, ck_tile::index_t row_size,
void* smem, void* smem,
...@@ -189,6 +191,7 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -189,6 +191,7 @@ struct Layernorm2dFwdPipelineTwoPass
ln(idx) = ln_; ln(idx) = ln_;
}); });
static_assert(kFusedSweep != Layernorm2dFusedSweepEnum::DYNAMIC_QUANT);
Epilogue{}(y_window, ln); Epilogue{}(y_window, ln);
move_tile_window(x_window, {0, -Block_N}); move_tile_window(x_window, {0, -Block_N});
......
...@@ -26,15 +26,13 @@ template<> struct Layernorm2dFusedAddEnumName<Layernorm2dFusedAddEnum::PRE_ADD> ...@@ -26,15 +26,13 @@ template<> struct Layernorm2dFusedAddEnumName<Layernorm2dFusedAddEnum::PRE_ADD>
enum class Layernorm2dFusedSweepEnum enum class Layernorm2dFusedSweepEnum
{ {
NO_SWEEP = 0, NO_SWEEP = 0,
RENORM = 1, DYNAMIC_QUANT = 1,
DYNAMIC_QUANT = 2,
}; };
// clang-format off // clang-format off
template<Layernorm2dFusedSweepEnum E> struct Layernorm2dFusedSweepEnumName; template<Layernorm2dFusedSweepEnum E> struct Layernorm2dFusedSweepEnumName;
template<> struct Layernorm2dFusedSweepEnumName<Layernorm2dFusedSweepEnum::NO_SWEEP> { static constexpr const char * name = "no"; }; template<> struct Layernorm2dFusedSweepEnumName<Layernorm2dFusedSweepEnum::NO_SWEEP> { static constexpr const char * name = "no"; };
template<> struct Layernorm2dFusedSweepEnumName<Layernorm2dFusedSweepEnum::RENORM> { static constexpr const char * name = "renorm"; }; template<> struct Layernorm2dFusedSweepEnumName<Layernorm2dFusedSweepEnum::DYNAMIC_QUANT> { static constexpr const char * name = "dquant"; };
template<> struct Layernorm2dFusedSweepEnumName<Layernorm2dFusedSweepEnum::DYNAMIC_QUANT> { static constexpr const char * name = "dequant"; };
// clang-format on // clang-format on
template <bool kPadN_, template <bool kPadN_,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment