Commit 018e939f authored by Jiming Ruan's avatar Jiming Ruan
Browse files

Modify tests and bug fix

parent 54617a85
...@@ -61,6 +61,8 @@ template <typename XDataType_, ...@@ -61,6 +61,8 @@ template <typename XDataType_,
ck_tile::index_t Vector_N_, // vector size along N ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_, bool kPadN_,
bool kSaveInvRms_, bool kSaveInvRms_,
bool kFastFDiv_,
bool kWelford_,
bool kTwoPass_, bool kTwoPass_,
ck_tile::index_t kFusedAdd_ = 0, ck_tile::index_t kFusedAdd_ = 0,
ck_tile::index_t kFusedQuant_ = 0> ck_tile::index_t kFusedQuant_ = 0>
...@@ -122,6 +124,8 @@ struct rmsnorm2d_fwd_traits_ ...@@ -122,6 +124,8 @@ struct rmsnorm2d_fwd_traits_
static constexpr bool kPadN = kPadN_; static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveInvRms = kSaveInvRms_; static constexpr bool kSaveInvRms = kSaveInvRms_;
static constexpr bool kFastFDiv = kFastFDiv_;
static constexpr bool kWelford = kWelford_;
static constexpr bool kTwoPass = kTwoPass_; static constexpr bool kTwoPass = kTwoPass_;
static constexpr ck_tile::index_t kFusedAdd = kFusedAdd_; static constexpr ck_tile::index_t kFusedAdd = kFusedAdd_;
static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_; static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_;
...@@ -138,6 +142,8 @@ template <typename XDataType_, ...@@ -138,6 +142,8 @@ template <typename XDataType_,
ck_tile::index_t Vector_N_, // vector size along N ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_, bool kPadN_,
bool kSaveInvRms_, bool kSaveInvRms_,
bool kFastFDiv_,
bool kWelford_,
bool kTwoPass_, bool kTwoPass_,
int kFusedAdd_, int kFusedAdd_,
int kFusedQuant_> int kFusedQuant_>
...@@ -152,6 +158,8 @@ using traits_ = rmsnorm2d_fwd_traits_<XDataType_, ...@@ -152,6 +158,8 @@ using traits_ = rmsnorm2d_fwd_traits_<XDataType_,
Vector_N_, Vector_N_,
kPadN_, kPadN_,
kSaveInvRms_, kSaveInvRms_,
kFastFDiv_,
kWelford_,
kTwoPass_, kTwoPass_,
kFusedAdd_, kFusedAdd_,
kFusedQuant_>; kFusedQuant_>;
...@@ -185,6 +193,8 @@ float rmsnorm2d_fwd_(const S& s, A a) ...@@ -185,6 +193,8 @@ float rmsnorm2d_fwd_(const S& s, A a)
using PipelineTraits = using PipelineTraits =
ck_tile::Rmsnorm2dFwdTraits<Traits_::kPadN, ck_tile::Rmsnorm2dFwdTraits<Traits_::kPadN,
Traits_::kSaveInvRms, Traits_::kSaveInvRms,
Traits_::kFastFDiv,
Traits_::kWelford,
Traits_::kTwoPass, Traits_::kTwoPass,
static_cast<ck_tile::Rmsnorm2dFusedAddEnum>(Traits_::kFusedAdd), static_cast<ck_tile::Rmsnorm2dFusedAddEnum>(Traits_::kFusedAdd),
static_cast<ck_tile::Rmsnorm2dFusedQuantEnum>(Traits_::kFusedQuant)>; static_cast<ck_tile::Rmsnorm2dFusedQuantEnum>(Traits_::kFusedQuant)>;
...@@ -204,12 +214,14 @@ float rmsnorm2d_fwd_(const S& s, A a) ...@@ -204,12 +214,14 @@ float rmsnorm2d_fwd_(const S& s, A a)
using TwoPassPipeline = ck_tile::Rmsnorm2dFwdPipelineTwoPass<PipelineProblem>; using TwoPassPipeline = ck_tile::Rmsnorm2dFwdPipelineTwoPass<PipelineProblem>;
using Pipeline = std::conditional_t<Traits_::kTwoPass, TwoPassPipeline, OnePassPipeline>; using Pipeline = std::conditional_t<Traits_::kTwoPass, TwoPassPipeline, OnePassPipeline>;
using Default2DEpilogueProblem = ck_tile::Default2DEpilogueProblem<ComputeDataType, YDataType, false, Traits_::kPadN, false>; static constexpr bool UseRawStore = sizeof(YDataType) == 4;
using Default2DEpilogueProblem = ck_tile::Default2DEpilogueProblem<ComputeDataType, YDataType, false, Traits_::kPadN, UseRawStore>;
using Default2DEpilogue = ck_tile::Default2DEpilogue<Default2DEpilogueProblem>; using Default2DEpilogue = ck_tile::Default2DEpilogue<Default2DEpilogueProblem>;
static constexpr bool UseSmoothInputScale = Traits_::kFusedQuant == 1; static constexpr bool UseSmoothInputScale = Traits_::kFusedQuant == 1;
using DynamicQuantEpilogueProblem = ck_tile::DynamicQuantEpilogueProblem<ComputeDataType, SmoothScaleDataType, YScaleDataType, YDataType, typename Traits_::Shape, using DynamicQuantEpilogueProblem = ck_tile::DynamicQuantEpilogueProblem<ComputeDataType, SmoothScaleDataType, YScaleDataType, YDataType, typename Traits_::Shape,
ck_tile::DynamicQuantEpilogueTraits<false, Traits_::kPadN, UseSmoothInputScale, false, true/*max3*/>>; ck_tile::DynamicQuantEpilogueTraits<false, Traits_::kPadN, UseSmoothInputScale, UseRawStore, true/*max3*/>>;
using DynamicQuantEpilogue = ck_tile::DynamicQuantEpilogue<DynamicQuantEpilogueProblem>; using DynamicQuantEpilogue = ck_tile::DynamicQuantEpilogue<DynamicQuantEpilogueProblem>;
...@@ -262,24 +274,21 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, ...@@ -262,24 +274,21 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
#include "rmsnorm2d_fwd_api_common.hpp" #include "rmsnorm2d_fwd_api_common.hpp"
// clang-format off // clang-format off
// rm rn tm tn vn pd rms 2p // rm rn tm tn vn pd rms welford 2p
{F_instance_def} {F_instance_def}
// clang-format on // clang-format on
""" """
API_PER_DTYPE = """ API_PER_DTYPE=""" {F_if}(t.prec_i == \"{F_i_type}\" && t.prec_o == \"{F_o_type}\"){{
{F_if}(t.prec_i == \"{F_i_type}\" && t.prec_o == \"{F_o_type}\"){{
{F_per_n_case} {F_per_n_case}
}} }}
""" """
API_PER_N_CASE = """ API_PER_N_CASE=""" {F_if} {F_N_COND} {{
{F_if} {F_N_COND} {{
{F_inner_dispatch} {F_inner_dispatch}
}} }}
""" """
API_INNER_CASE = """ API_INNER_CASE=""" {F_if} {F_VEC_COND}
{F_if} {F_VEC_COND}
r={F_instance_func}(s, a); r={F_instance_func}(s, a);
""" """
...@@ -362,6 +371,8 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, ...@@ -362,6 +371,8 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
F_Vector_N : int F_Vector_N : int
F_kPadN : bool F_kPadN : bool
F_kSaveInvRms : bool F_kSaveInvRms : bool
F_kFastFDiv_ : bool
F_kWelford_ : bool
F_kTwoPass : bool F_kTwoPass : bool
F_kFusedAdd : int F_kFusedAdd : int
F_kFusedQuant : int F_kFusedQuant : int
...@@ -369,7 +380,7 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, ...@@ -369,7 +380,7 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
@property @property
def trait_name(self) ->str: def trait_name(self) ->str:
t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_SmoothScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}' t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_SmoothScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}'
t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveInvRms):5}' t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveInvRms):5}, {BOOL_MAP(self.F_kFastFDiv_):5}, {BOOL_MAP(self.F_kWelford_):5}'
t_ += f', {BOOL_MAP(self.F_kTwoPass):5}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}' t_ += f', {BOOL_MAP(self.F_kTwoPass):5}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}'
return t_ return t_
...@@ -422,11 +433,10 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, ...@@ -422,11 +433,10 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
def name_common_header(self) -> str: def name_common_header(self) -> str:
return 'rmsnorm2d_fwd_api_common' return 'rmsnorm2d_fwd_api_common'
@property def content_api(self, args) -> str:
def content_api(self) -> str:
# 1 sort based on dtype # 1 sort based on dtype
t_dtype_dict = dict() t_dtype_dict = dict()
blobs = self.get_blobs() blobs = self.get_blobs(args)
for blob in blobs: for blob in blobs:
if blob.F_DataTypePair not in t_dtype_dict: if blob.F_DataTypePair not in t_dtype_dict:
t_dtype_dict[blob.F_DataTypePair] = {} t_dtype_dict[blob.F_DataTypePair] = {}
...@@ -462,8 +472,8 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, ...@@ -462,8 +472,8 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
inner_str += self.API_INNER_CASE.format(F_if = get_if_str(idx_in_n, len_in_n, False), inner_str += self.API_INNER_CASE.format(F_if = get_if_str(idx_in_n, len_in_n, False),
F_VEC_COND = _cond, F_instance_func=ins.call_name) F_VEC_COND = _cond, F_instance_func=ins.call_name)
#inner_str = inner_str + vec_str #inner_str = inner_str + vec_str
n_cnd = f'(a.n <= {n_})' if (i_n < len(blob_per_t) - 1) else '' n_cnd = f'(a.n <= {n_})' if isinstance(n_, int) else ''
n_str += self.API_PER_N_CASE.format(F_if = get_if_str(i_n, len(blob_per_t)), F_N_COND=n_cnd, F_inner_dispatch=inner_str) n_str += self.API_PER_N_CASE.format(F_if = get_if_str(i_n, len(blob_per_t), not isinstance(n_, int)), F_N_COND=n_cnd, F_inner_dispatch=inner_str)
prec_i, prec_o = dtype_.split(',') prec_i, prec_o = dtype_.split(',')
d_str += self.API_PER_DTYPE.format(F_if = get_if_str(i_d, len(t_dtype_dict), False), F_i_type=prec_i, F_o_type=prec_o, F_per_n_case=n_str) d_str += self.API_PER_DTYPE.format(F_if = get_if_str(i_d, len(t_dtype_dict), False), F_i_type=prec_i, F_o_type=prec_o, F_per_n_case=n_str)
...@@ -474,7 +484,7 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, ...@@ -474,7 +484,7 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
def content_common_header(self) -> str: def content_common_header(self) -> str:
return self.API_COMMON_HEADER.format(F_traits_define=self.API_TRAITS_DEFINE) return self.API_COMMON_HEADER.format(F_traits_define=self.API_TRAITS_DEFINE)
def get_blobs(self): def get_blobs(self, args):
h_traits = rmsnorm_fwd_codegen.h_traits h_traits = rmsnorm_fwd_codegen.h_traits
h_instance = rmsnorm_fwd_codegen.h_instance h_instance = rmsnorm_fwd_codegen.h_instance
...@@ -485,60 +495,61 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, ...@@ -485,60 +495,61 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
dtype_list = [('fp16,fp16'), ('bf16,bf16'), dtype_list = [('fp16,fp16'), ('bf16,bf16'),
('fp16,int8'), ('bf16,int8'), ('fp16,int8'), ('bf16,int8'),
('fp16,fp8'), ('bf16,fp8')] # NOTE: only fused-dynamic-quant use int8 out ('fp16,fp8'), ('bf16,fp8')] # NOTE: only fused-dynamic-quant use int8 out
types_8bit = ('int8', 'fp8')
types_16bit = ('int16', 'fp16', 'bf16')
#fused_add_list = [0, 1, 2] #fused_add_list = [0, 1, 2]
#fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused (smooth) dynamic quant
fused_add_list = [0, 1] fused_add_list = [0, 1]
fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused (smooth) dynamic quant fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused (smooth) dynamic quant
# rm rn tm tn vn pd mv 2p add sweep # rm rn tm tn vn pd mv fdiv welford 2p add sweep
h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 8, 8, 8, True, False, False, 0, 0), h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 8, 8, 8, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 4, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 4, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 1, True, False, False, 0, 0)], h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 1, True, False, True, True, False, 0, 0)],
'128' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 8, True, False, False, 0, 0), '128' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 8, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 2, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 2, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 1, True, False, False, 0, 0)], h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 1, True, False, True, True, False, 0, 0)],
'256' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 4, True, False, False, 0, 0), '256' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 4, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 2, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 2, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 1, True, False, False, 0, 0)], h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 1, True, False, True, True, False, 0, 0)],
'512' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 8, True, False, False, 0, 0), '512' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 8, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 4, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 4, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 2, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 2, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 8, 4, 64, 1, True, False, False, 0, 0)], h_traits('x', 'y', 'xs', 'ys', 1, 8, 4, 64, 1, True, False, True, True, False, 0, 0)],
'768' : [ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 4, True, False, False, 0, 0), '768' : [ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 4, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 6, 4, 64, 2, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 6, 4, 64, 2, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 12, 4, 64, 1, True, False, False, 0, 0)], h_traits('x', 'y', 'xs', 'ys', 1, 12, 4, 64, 1, True, False, True, True, False, 0, 0)],
'1024' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 2, 128, 8, True, False, False, 0, 0), '1024' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 2, 128, 8, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 2, 128, 4, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 2, 2, 128, 4, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 2, 128, 2, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 4, 2, 128, 2, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 1, True, False, False, 0, 0)], h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 1, True, False, True, True, False, 0, 0)],
'1536' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 8, True, False, False, 0, 0), '1536' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 8, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 2, 128, 4, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 3, 2, 128, 4, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 2, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 2, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 1, True, False, False, 0, 0)], h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 1, True, False, True, True, False, 0, 0)],
'2048' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 256, 8, True, False, False, 0, 0), '2048' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 256, 8, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 4, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 4, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 2, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 2, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 8, 1, 256, 1, True, False, False, 0, 0)], h_traits('x', 'y', 'xs', 'ys', 1, 8, 1, 256, 1, True, False, True, True, False, 0, 0)],
'3072' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 128, 8, True, False, False, 0, 0), '3072' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 128, 8, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 4, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 4, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 2, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 2, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 1, True, False, False, 0, 0)], h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 1, True, False, True, True, False, 0, 0)],
'4096' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, False, 0, 0), '4096' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, False, 0, 0)], h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, False, 0, 0)],
'6144' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 8, True, False, False, 0, 0), '6144' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 8, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 512, 4, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 512, 4, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 2, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 2, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1,1024, 1, True, False, False, 0, 0)], h_traits('x', 'y', 'xs', 'ys', 1, 6, 1,1024, 1, True, False, True, True, False, 0, 0)],
'8192' :[ h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 8, True, False, False, 0, 0), '8192' :[ h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 8, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, False, 0, 0)], h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, True, True, False, 0, 0)],
'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, 0, 0), 'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, True, True, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, True, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, True, True, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, 0, 0)]} h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, True, 0, 0)]}
total_blob = list() total_blob = list()
for hs_key in h_trait_dict: for hs_key in h_trait_dict:
hs = h_trait_dict[hs_key] hs = h_trait_dict[hs_key]
...@@ -559,16 +570,27 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, ...@@ -559,16 +570,27 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
h_.F_YScaleDataType = scale_y h_.F_YScaleDataType = scale_y
h_.F_kFusedAdd = fused_add h_.F_kFusedAdd = fused_add
h_.F_kFusedQuant = fused_quant h_.F_kFusedQuant = fused_quant
# disable welford update for 8bit and 16 bit smallN
if not h_.F_kTwoPass:
#disable 16 bit when set args disable_16b_welford
if args.disable_16b_welford and prec_i in types_16bit:
h_.F_kWelford_ = False
#disable 8bit by default
elif prec_i in types_8bit or prec_o in types_8bit:
h_.F_kWelford_ = False
#disable 16bit small N
elif prec_i in types_16bit and hs_key == '64':
h_.F_kWelford_ = False
current_hs.append(h_) # + "\n" current_hs.append(h_) # + "\n"
#f.write(str(f.parent / GEN_DIR / (blobs.api_common_header_ #f.write(str(f.parent / GEN_DIR / (blobs.api_common_header_
current_n_str = 'big' if hs_key == 'big' else current_n current_n_str = 'big' if hs_key == 'big' else current_n
total_blob.append(h_instance(dtype, current_n_str, fused_add, fused_quant, current_hs)) total_blob.append(h_instance(dtype, current_n_str, fused_add, fused_quant, current_hs))
return total_blob return total_blob
def list_blobs(self) -> None: def list_blobs(self, args) -> None:
w_p = Path(self.working_path) w_p = Path(self.working_path)
list_p = w_p / 'rmsnorm2d_fwd_blobs.txt' list_p = w_p / 'rmsnorm2d_fwd_blobs.txt'
blobs = self.get_blobs() blobs = self.get_blobs(args)
with list_p.open('w') as list_f: with list_p.open('w') as list_f:
# api related file # api related file
list_f.write(str(w_p / (self.name_api + ".cpp")) + "\n") list_f.write(str(w_p / (self.name_api + ".cpp")) + "\n")
...@@ -577,11 +599,12 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, ...@@ -577,11 +599,12 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
for b in blobs: for b in blobs:
list_f.write(str(w_p / (b.name + ".cpp")) + "\n") list_f.write(str(w_p / (b.name + ".cpp")) + "\n")
def gen_blobs(self) -> None: def gen_blobs(self, args) -> None:
w_p = Path(self.working_path) w_p = Path(self.working_path)
(w_p / (self.name_api + ".cpp")).write_text(self.content_api) w_str = self.content_api(args)
(w_p / (self.name_api + ".cpp")).write_text(w_str)
(w_p / (self.name_common_header + ".hpp")).write_text(self.content_common_header) (w_p / (self.name_common_header + ".hpp")).write_text(self.content_common_header)
blobs = self.get_blobs() blobs = self.get_blobs(args)
for b in blobs: for b in blobs:
(w_p / (b.name + ".cpp")).write_text(b.content) (w_p / (b.name + ".cpp")).write_text(b.content)
...@@ -590,14 +613,14 @@ def list_blobs(args): ...@@ -590,14 +613,14 @@ def list_blobs(args):
api_list = args.api.split(',') api_list = args.api.split(',')
for api in api_list: for api in api_list:
if api == 'fwd': if api == 'fwd':
rmsnorm_fwd_codegen(args.working_path, args.filter).list_blobs() rmsnorm_fwd_codegen(args.working_path, args.filter).list_blobs(args)
def gen_blobs(args): def gen_blobs(args):
api_list = args.api.split(',') api_list = args.api.split(',')
for api in api_list: for api in api_list:
if api == 'fwd': if api == 'fwd':
rmsnorm_fwd_codegen(args.working_path, args.filter).gen_blobs() rmsnorm_fwd_codegen(args.working_path, args.filter).gen_blobs(args)
if __name__ == "__main__": if __name__ == "__main__":
...@@ -666,6 +689,13 @@ if __name__ == "__main__": ...@@ -666,6 +689,13 @@ if __name__ == "__main__":
help="codegen receipt." help="codegen receipt."
) )
parser.add_argument(
"--disable_16b_welford",
default=False,
required=False,
help="enable/disable welford for 16bit datatype n > 64"
)
args = parser.parse_args() args = parser.parse_args()
# print(f'{args.list_blobs}-{args.gen_blobs}') # print(f'{args.list_blobs}-{args.gen_blobs}')
......
...@@ -200,6 +200,12 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -200,6 +200,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
float ave_time = rmsnorm2d_fwd( float ave_time = rmsnorm2d_fwd(
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat}); traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
if(ave_time < 0)
{
std::cout << " not supported!" << std::endl << std::flush;
return false;
}
std::size_t num_byte = std::size_t num_byte =
sizeof(XDataType) * m * n + sizeof(GammaDataType) * n + sizeof(YDataType) * m * n; sizeof(XDataType) * m * n + sizeof(GammaDataType) * n + sizeof(YDataType) * m * n;
num_byte += SaveRms ? sizeof(InvRmsDataType) * m * n : 0; num_byte += SaveRms ? sizeof(InvRmsDataType) * m * n : 0;
......
...@@ -120,6 +120,13 @@ struct Rmsnorm2dFwdPipelineOnePass ...@@ -120,6 +120,13 @@ struct Rmsnorm2dFwdPipelineOnePass
block_norm_reduce_sync(square_mean, cur_count); block_norm_reduce_sync(square_mean, cur_count);
block_norm_reduce_cross_warp_sync(square_mean, cur_count, smem); block_norm_reduce_cross_warp_sync(square_mean, cur_count, smem);
if constexpr(!kWelford)
{
sweep_tile(square_mean, [&](auto idx) {
square_mean(idx) = square_mean(idx) / type_convert<ComputeDataType>(row_size);
});
}
// compute inv-rms // compute inv-rms
auto inv_rms = tile_elementwise_in( auto inv_rms = tile_elementwise_in(
[&](const auto& v_) { [&](const auto& v_) {
......
...@@ -70,6 +70,8 @@ struct Rmsnorm2dFwdPipelineTwoPass ...@@ -70,6 +70,8 @@ struct Rmsnorm2dFwdPipelineTwoPass
void* smem, void* smem,
Epilogue) const Epilogue) const
{ {
static_assert(kWelford == true, "2 pass only supports welford merge");
auto x_window = auto x_window =
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>()); make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto gamma_window = make_tile_window( auto gamma_window = make_tile_window(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment