Unverified Commit 4bc61041 authored by feli's avatar feli Committed by GitHub
Browse files

Ck tile/layernorm: implement naive reduce, opt performance (#1784)



* add no welford

* enable output raw

* raw of int8

* fix build

* fix smoke test err

* [ck_tile]layernorm: fix welford ok, set int8 and bf16 small N as default and others open by generate

* [cktile]layernorm, fix err commit files and remove uselss

* fix quant 8192 err & change norm_reduce class and file name

---------
Co-authored-by: default avatarcoderfeli <coderfeli@163.com>
Co-authored-by: default avatarcarlushuang <carlus.huang@amd.com>
parent 17e8efb5
......@@ -58,6 +58,7 @@ template <typename XDataType_,
bool kPadN_,
bool kSaveMeanInvStd_,
bool kFastFDiv_,
bool kWelford_,
bool kTwoPass_,
ck_tile::index_t kFusedAdd_ = 0,
ck_tile::index_t kFusedQuant_ = 0>
......@@ -120,6 +121,7 @@ struct layernorm2d_fwd_traits_
static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
static constexpr bool kFastFDiv = kFastFDiv_;
static constexpr bool kWelford = kWelford_;
static constexpr bool kTwoPass = kTwoPass_;
static constexpr ck_tile::index_t kFusedAdd = kFusedAdd_;
static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_;
......@@ -137,6 +139,7 @@ template <typename XDataType_,
bool kPadN_,
bool kSaveMeanInvStd_,
bool kFastFDiv_,
bool kWelford_,
bool kTwoPass_,
int kFusedAdd_,
int kFusedQuant_>
......@@ -152,6 +155,7 @@ using traits_ = layernorm2d_fwd_traits_<XDataType_,
kPadN_,
kSaveMeanInvStd_,
kFastFDiv_,
kWelford_,
kTwoPass_,
kFusedAdd_,
kFusedQuant_>;
......@@ -184,6 +188,7 @@ float layernorm2d_fwd_(const S& s, A a)
using PipelineTraits = ck_tile::Layernorm2dFwdTraits<Traits_::kPadN,
Traits_::kSaveMeanInvStd,
Traits_::kFastFDiv,
Traits_::kWelford,
Traits_::kTwoPass,
static_cast<ck_tile::Layernorm2dFusedAddEnum>(Traits_::kFusedAdd),
static_cast<ck_tile::Layernorm2dFusedQuantEnum>(Traits_::kFusedQuant)>;
......@@ -204,12 +209,13 @@ float layernorm2d_fwd_(const S& s, A a)
using TwoPassPipeline = ck_tile::Layernorm2dFwdPipelineTwoPass<PipelineProblem>;
using Pipeline = std::conditional_t<Traits_::kTwoPass, TwoPassPipeline, OnePassPipeline>;
using Default2DEpilogueProblem = ck_tile::Default2DEpilogueProblem<ComputeDataType, YDataType, false, Traits_::kPadN, false>;
using Default2DEpilogueProblem = ck_tile::Default2DEpilogueProblem<ComputeDataType, YDataType, false, Traits_::kPadN, true>;
using Default2DEpilogue = ck_tile::Default2DEpilogue<Default2DEpilogueProblem>;
static constexpr bool UseSmoothInputScale = Traits_::kFusedQuant == 1;
static constexpr bool UseRawStore = sizeof(YDataType) == 4;
using DynamicQuantEpilogueProblem = ck_tile::DynamicQuantEpilogueProblem<ComputeDataType, XScaleDataType, YScaleDataType, YDataType, typename Traits_::Shape,
ck_tile::DynamicQuantEpilogueTraits<false, Traits_::kPadN, UseSmoothInputScale, false, true/*max3*/>>;
ck_tile::DynamicQuantEpilogueTraits<false, Traits_::kPadN, UseSmoothInputScale, UseRawStore, true/*max3*/>>;
using DynamicQuantEpilogue = ck_tile::DynamicQuantEpilogue<DynamicQuantEpilogueProblem>;
......@@ -274,7 +280,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
#include "layernorm2d_fwd_api_common.hpp"
// clang-format off
// prec_i prec_o prec_sy rm rn tm tn vn pd mv rpcf 2p add sweep
// prec_i prec_o prec_sy rm rn tm tn vn pd mv rpcf welford 2p add sweep
{F_instance_def}
// clang-format on
......@@ -362,6 +368,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
F_kPadN : bool
F_kSaveMeanInvStd_ : bool
F_kFastFDiv_ : bool
F_kWelford_ : bool
F_kTwoPass_ : bool
F_kFusedAdd : int
F_kFusedQuant : int
......@@ -369,7 +376,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
@property
def trait_name(self) ->str:
t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_XScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}'
t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveMeanInvStd_):5}, {BOOL_MAP(self.F_kFastFDiv_):5}'
t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveMeanInvStd_):5}, {BOOL_MAP(self.F_kFastFDiv_):5}, {BOOL_MAP(self.F_kWelford_):5}'
t_ += f', {BOOL_MAP(self.F_kTwoPass_):5}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}'
return t_
......@@ -422,11 +429,10 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
def name_common_header(self) -> str:
return 'layernorm2d_fwd_api_common'
@property
def content_api(self) -> str:
def content_api(self, args) -> str:
# 1 sort based on dtype
t_dtype_dict = dict()
blobs = self.get_blobs()
blobs = self.get_blobs(args)
for blob in blobs:
if blob.F_DataTypePair not in t_dtype_dict:
t_dtype_dict[blob.F_DataTypePair] = {}
......@@ -462,8 +468,8 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
inner_str += self.API_INNER_CASE.format(F_if = get_if_str(idx_in_n, len_in_n, False),
F_VEC_COND = _cond, F_instance_func=ins.call_name)
#inner_str = inner_str + vec_str
n_cnd = f'(a.n <= {n_})' if (i_n < len(blob_per_t) - 1) else ''
n_str += self.API_PER_N_CASE.format(F_if = get_if_str(i_n, len(blob_per_t)), F_N_COND=n_cnd, F_inner_dispatch=inner_str)
n_cnd = f'(a.n <= {n_})' if isinstance(n_, int) else ''
n_str += self.API_PER_N_CASE.format(F_if = get_if_str(i_n, len(blob_per_t), not isinstance(n_, int)), F_N_COND=n_cnd, F_inner_dispatch=inner_str)
prec_i, prec_o = dtype_.split(',')
d_str += self.API_PER_DTYPE.format(F_if = get_if_str(i_d, len(t_dtype_dict), False), F_i_type=prec_i, F_o_type=prec_o, F_per_n_case=n_str)
......@@ -474,7 +480,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
def content_common_header(self) -> str:
return self.API_COMMON_HEADER.format(F_traits_define=self.API_TRAITS_DEFINE)
def get_blobs(self):
def get_blobs(self, args):
h_traits = layernorm_fwd_codegen.h_traits
h_instance = layernorm_fwd_codegen.h_instance
......@@ -484,60 +490,61 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
scale_list = [('fp32,fp32')]
dtype_list = [('fp16,fp16'), ('bf16,bf16'),
('fp16,int8'), ('bf16,int8')] # NOTE: only fused-dynamic-quant use int8 out
types_8bit = ('int8', 'fp8')
types_16bit = ('int16', 'fp16', 'bf16')
#fused_add_list = [0, 1, 2]
#fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused dynamic quant
fused_add_list = [0, 1]
fused_sweep_list = [0, 1] # NOTE: only single pass can use fused dynamic quant
# rm rn tm tn vn pd mv fdiv 2p add sweep
h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 8, 8, 8, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 1, True, False, True, False, 0, 0)],
'128' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 8, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 1, True, False, True, False, 0, 0)],
'256' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 1, True, False, True, False, 0, 0)],
'512' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 8, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 8, 4, 64, 1, True, False, True, False, 0, 0)],
'768' : [ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 6, 4, 64, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 12, 4, 64, 1, True, False, True, False, 0, 0)],
'1024' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 2, 128, 8, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 2, 128, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 2, 128, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 1, True, False, True, False, 0, 0)],
'1536' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 8, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 2, 128, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 1, True, False, True, False, 0, 0)],
'2048' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 256, 8, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 8, 1, 256, 1, True, False, True, False, 0, 0)],
'3072' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 128, 8, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 1, True, False, True, False, 0, 0)],
'4096' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, False, 0, 0)],
'6144' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 8, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 512, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1,1024, 1, True, False, True, False, 0, 0)],
'8192' :[ h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 8, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, True, False, 0, 0)],
'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, True, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, True, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, 0, 0)]}
# rm rn tm tn vn pd mv fdiv welford 2p add sweep
h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 8, 8, 8, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 4, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 1, True, False, True, True, False, 0, 0)],
'128' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 8, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 2, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 1, True, False, True, True, False, 0, 0)],
'256' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 4, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 2, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 1, True, False, True, True, False, 0, 0)],
'512' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 8, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 4, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 2, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 8, 4, 64, 1, True, False, True, True, False, 0, 0)],
'768' : [ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 4, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 6, 4, 64, 2, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 12, 4, 64, 1, True, False, True, True, False, 0, 0)],
'1024' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 2, 128, 8, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 2, 128, 4, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 2, 128, 2, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 1, True, False, True, True, False, 0, 0)],
'1536' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 8, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 2, 128, 4, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 2, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 1, True, False, True, True, False, 0, 0)],
'2048' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 256, 8, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 4, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 2, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 8, 1, 256, 1, True, False, True, True, False, 0, 0)],
'3072' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 128, 8, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 4, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 2, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 1, True, False, True, True, False, 0, 0)],
'4096' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, False, 0, 0)],
'6144' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 8, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 512, 4, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 2, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1,1024, 1, True, False, True, True, False, 0, 0)],
'8192' :[ h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 8, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, True, True, False, 0, 0)],
'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, True, True, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, True, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, True, True, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, True, 0, 0)]}
total_blob = list()
for hs_key in h_trait_dict:
hs = h_trait_dict[hs_key]
......@@ -558,16 +565,27 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
h_.F_YScaleDataType = scale_x
h_.F_kFusedAdd = fused_add
h_.F_kFusedQuant = fused_quant
# disable welford update for 8bit and 16 bit smallN
if not h_.F_kTwoPass_:
#disable 16 bit when set args disable_16b_welford
if args.disable_16b_welford and prec_i in types_16bit:
h_.F_kWelford_ = False
#disable 8bit by default
elif prec_i in types_8bit or prec_o in types_8bit:
h_.F_kWelford_ = False
#disable 16bit small N
elif prec_i in types_16bit and hs_key == '64':
h_.F_kWelford_ = False
current_hs.append(h_) # + "\n"
#f.write(str(f.parent / GEN_DIR / (blobs.api_common_header_
current_n_str = 'big' if hs_key == 'big' else current_n
total_blob.append(h_instance(dtype, current_n_str, fused_add, fused_quant, current_hs))
return total_blob
def list_blobs(self) -> None:
def list_blobs(self, args) -> None:
w_p = Path(self.working_path)
list_p = w_p / 'layernorm2d_fwd_blobs.txt'
blobs = self.get_blobs()
blobs = self.get_blobs(args)
with list_p.open('w') as list_f:
# api related file
list_f.write(str(w_p / (self.name_api + ".cpp")) + "\n")
......@@ -576,11 +594,12 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
for b in blobs:
list_f.write(str(w_p / (b.name + ".cpp")) + "\n")
def gen_blobs(self) -> None:
def gen_blobs(self, args) -> None:
w_p = Path(self.working_path)
(w_p / (self.name_api + ".cpp")).write_text(self.content_api)
w_str = self.content_api(args)
(w_p / (self.name_api + ".cpp")).write_text(w_str)
(w_p / (self.name_common_header + ".hpp")).write_text(self.content_common_header)
blobs = self.get_blobs()
blobs = self.get_blobs(args)
for b in blobs:
(w_p / (b.name + ".cpp")).write_text(b.content)
......@@ -588,14 +607,14 @@ def list_blobs(args):
api_list = args.api.split(',')
for api in api_list:
if api == 'fwd':
layernorm_fwd_codegen(args.working_path, args.filter).list_blobs()
layernorm_fwd_codegen(args.working_path, args.filter).list_blobs(args)
def gen_blobs(args):
api_list = args.api.split(',')
for api in api_list:
if api == 'fwd':
layernorm_fwd_codegen(args.working_path, args.filter).gen_blobs()
layernorm_fwd_codegen(args.working_path, args.filter).gen_blobs(args)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
......@@ -663,6 +682,13 @@ if __name__ == "__main__":
help="codegen receipt."
)
parser.add_argument(
"--disable_16b_welford",
default=False,
required=False,
help="enable/disable welford for 16bit datatype n > 64"
)
args = parser.parse_args()
# print(f'{args.list_blobs}-{args.gen_blobs}')
......
......@@ -27,7 +27,8 @@ $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=2734
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=3182
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=9 -n=4096
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=8192
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=9120
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134
done
done
......
......@@ -4,8 +4,8 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/welford/block/block_welford_problem.hpp"
#include "ck_tile/ops/welford/block/block_welford.hpp"
#include "ck_tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp"
#include "ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp"
namespace ck_tile {
......@@ -43,36 +43,38 @@ struct Layernorm2dFwdPipelineDefaultPolicy
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelford()
CK_TILE_HOST_DEVICE static constexpr auto GetBlockNormReduce()
{
using P_ = BlockWelfordProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape,
Problem::Traits::kFastFDiv>;
return BlockWelford<P_>{};
using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape,
Problem::Traits::kFastFDiv,
Problem::Traits::kWelford>;
return BlockNormReduce<P_>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelfordSync()
CK_TILE_HOST_DEVICE static constexpr auto GetBlockNormReduceSync()
{
using P_ = BlockWelfordProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape,
Problem::Traits::kFastFDiv>;
using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape,
Problem::Traits::kFastFDiv,
Problem::Traits::kWelford>;
return BlockWelfordSync<P_>{};
return BlockNormReduceSync<P_>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelfordCrossWarpSync()
CK_TILE_HOST_DEVICE static constexpr auto GetBlockNormReduceCrossWarpSync()
{
using P_ = BlockWelfordProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape,
Problem::Traits::kFastFDiv>;
using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape,
Problem::Traits::kFastFDiv,
Problem::Traits::kWelford>;
return BlockWelfordCrossWarpSync<P_>{};
return BlockNormReduceCrossWarpSync<P_>{};
}
template <typename Problem>
......@@ -80,19 +82,20 @@ struct Layernorm2dFwdPipelineDefaultPolicy
{
if constexpr(Problem::kNeedCrossWarpSync)
{
using P_ = BlockWelfordProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape,
Problem::Traits::kFastFDiv>;
using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape,
Problem::Traits::kFastFDiv,
Problem::Traits::kWelford>;
using block_welford = BlockWelford<P_>;
using block_welford = BlockNormReduce<P_>;
using x_block_tile =
decltype(make_static_distributed_tensor<typename Problem::ComputeDataType>(
MakeXBlockTileDistribution<Problem>()));
using mean_var_block_tile =
decltype(block_welford::template MakeMeanVarBlockTile<x_block_tile>());
return GetBlockWelfordCrossWarpSync<Problem>()
return GetBlockNormReduceCrossWarpSync<Problem>()
.template GetSmemSize<mean_var_block_tile>();
}
else
......
......@@ -37,6 +37,7 @@ struct Layernorm2dFwdPipelineOnePass
static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM
static constexpr bool kPadN = Problem::Traits::kPadN;
static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv;
static constexpr bool kWelford = Problem::Traits::kWelford;
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
......@@ -95,11 +96,16 @@ struct Layernorm2dFwdPipelineOnePass
int cur_count = 0;
int max_count =
block_tile_welford_calculate_max_count<typename Problem::BlockShape>(row_size);
auto block_welford = Policy::template GetBlockWelford<Problem>();
auto block_welford_sync = Policy::template GetBlockWelfordSync<Problem>();
auto block_welford_cross_warp_sync =
Policy::template GetBlockWelfordCrossWarpSync<Problem>();
auto block_norm_reduce = Policy::template GetBlockNormReduce<Problem>();
auto block_norm_reduce_sync = Policy::template GetBlockNormReduceSync<Problem>();
auto block_norm_reduce_cross_warp_sync =
Policy::template GetBlockNormReduceCrossWarpSync<Problem>();
using XTensorType = decltype(cast_tile<ComputeDataType>(x));
auto mean = block_norm_reduce.template MakeMeanVarBlockTile<XTensorType>();
auto var = block_norm_reduce.template MakeMeanVarBlockTile<XTensorType>();
clear_tile(mean);
clear_tile(var);
// load gamma/beta (TODO: support no gamma/beta?)
const auto gamma = load_tile(gamma_window);
const auto beta = load_tile(beta_window);
......@@ -117,12 +123,21 @@ struct Layernorm2dFwdPipelineOnePass
store_tile(y_residual_window, cast_tile<YResidualDataType>(acc));
}
// compute welford each-thread->cross-lane->cross-warp
auto [mean, var] = block_welford(acc, cur_count, max_count);
block_welford_sync(mean, var, cur_count);
block_welford_cross_warp_sync(mean, var, cur_count, smem);
block_tile_welford_post_scale_var(var, cur_count, constant<kFastFDiv>{});
// compute reduce each-thread->cross-lane->cross-warp
block_norm_reduce(acc, mean, var, cur_count, max_count);
block_norm_reduce_sync(mean, var, cur_count);
block_norm_reduce_cross_warp_sync(mean, var, cur_count, smem);
if(kWelford)
{
block_tile_welford_post_scale_var(var, cur_count, constant<kFastFDiv>{});
}
else
{
sweep_tile(mean, [&](auto idx) {
mean(idx) = mean(idx) / type_convert<MeanDataType>(row_size);
var(idx) = var(idx) / type_convert<MeanDataType>(row_size) - mean(idx) * mean(idx);
});
}
// compute inv-std
auto inv_std = tile_elementwise_in(
[&](const auto& v_) {
......@@ -153,8 +168,7 @@ struct Layernorm2dFwdPipelineOnePass
const auto beta_ = type_convert<ComputeDataType>(beta[j_idx]);
auto ln_ = (acc[idx] - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
ln(idx) = ln_;
ln(idx) = ln_;
});
if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT ||
......
......@@ -36,6 +36,7 @@ struct Layernorm2dFwdPipelineTwoPass
static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM
static constexpr bool kPadN = Problem::Traits::kPadN;
static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv;
static constexpr bool kWelford = Problem::Traits::kWelford;
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
......@@ -77,6 +78,7 @@ struct Layernorm2dFwdPipelineTwoPass
void* smem,
Epilogue) const
{
static_assert(kWelford == true, "2 pass only supports welford merge");
auto x_window =
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto gamma_window = make_tile_window(
......@@ -102,14 +104,14 @@ struct Layernorm2dFwdPipelineTwoPass
int max_count =
(num_n_tile_iteration - 1) * count_per_iter +
block_tile_welford_calculate_max_count<typename Problem::BlockShape>(last_iter_n);
auto block_welford = Policy::template GetBlockWelford<Problem>();
auto block_welford_sync = Policy::template GetBlockWelfordSync<Problem>();
auto block_welford_cross_warp_sync =
Policy::template GetBlockWelfordCrossWarpSync<Problem>();
auto block_norm_reduce = Policy::template GetBlockNormReduce<Problem>();
auto block_norm_reduce_sync = Policy::template GetBlockNormReduceSync<Problem>();
auto block_norm_reduce_cross_warp_sync =
Policy::template GetBlockNormReduceCrossWarpSync<Problem>();
using XTensorType = decltype(cast_tile<ComputeDataType>(load_tile(x_window)));
auto mean = block_welford.template MakeMeanVarBlockTile<XTensorType>();
auto var = block_welford.template MakeMeanVarBlockTile<XTensorType>();
auto mean = block_norm_reduce.template MakeMeanVarBlockTile<XTensorType>();
auto var = block_norm_reduce.template MakeMeanVarBlockTile<XTensorType>();
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{
......@@ -133,11 +135,11 @@ struct Layernorm2dFwdPipelineTwoPass
move_tile_window(y_residual_window, {0, Block_N});
}
}
block_welford(acc, mean, var, cur_count, max_count);
block_norm_reduce(acc, mean, var, cur_count, max_count);
}
block_welford_sync(mean, var, cur_count);
block_welford_cross_warp_sync(mean, var, cur_count, smem);
block_norm_reduce_sync(mean, var, cur_count);
block_norm_reduce_cross_warp_sync(mean, var, cur_count, smem);
block_tile_welford_post_scale_var(var, cur_count, constant<kFastFDiv>{});
// compute inv-std
......
......@@ -40,6 +40,7 @@ template<> struct Layernorm2dFusedQuantEnumName<Layernorm2dFusedQuantEnum::SMOOT
template <bool kPadN_,
bool kSaveMeanInvStd_,
bool kFastFDiv_,
bool kWelford_,
bool kTwoPass_,
Layernorm2dFusedAddEnum kFusedAdd_,
Layernorm2dFusedQuantEnum kFusedQuant_>
......@@ -48,6 +49,7 @@ struct Layernorm2dFwdTraits
static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
static constexpr bool kFastFDiv = kFastFDiv_;
static constexpr bool kWelford = kWelford_;
static constexpr bool kTwoPass = kTwoPass_;
static constexpr Layernorm2dFusedAddEnum kFusedAdd = kFusedAdd_;
static constexpr Layernorm2dFusedQuantEnum kFusedQuant = kFusedQuant_;
......
......@@ -3,8 +3,8 @@
#pragma once
#include "ck_tile/ops/welford/block/block_welford.hpp"
#include "ck_tile/ops/welford/block/block_welford_problem.hpp"
#include "ck_tile/ops/welford/thread/thread_welford.hpp"
#include "ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp"
#include "ck_tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp"
#include "ck_tile/ops/norm_reduce/thread/thread_welford.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
......@@ -4,22 +4,23 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/welford/thread/thread_welford.hpp"
#include "ck_tile/ops/norm_reduce/thread/thread_welford.hpp"
namespace ck_tile {
template <typename Problem_, typename Policy_ = void>
struct BlockWelford
struct BlockNormReduce
{
using Problem = remove_cvref_t<Problem_>;
using XDataType = typename Problem::XDataType;
using ComputeDataType = typename Problem::ComputeDataType;
static constexpr bool kFastFDiv = Problem::kFastFDiv;
static constexpr bool kWelford = Problem::kWelford;
CK_TILE_DEVICE constexpr BlockWelford() {}
CK_TILE_DEVICE constexpr BlockNormReduce() {}
// [CAUSION] - max_count_ is to deal with the padding problem
// max_count_ is depend on caller, eg: naive and splitN welford will have different
// max_count_ is depend on caller, eg: naive and splitN norm_reduce will have different
// calculation of max_count_
// -> use block_welford_calculate_max_count to compute
template <typename XDistributedTensor_,
......@@ -40,18 +41,24 @@ struct BlockWelford
if(cur_count_ < max_count_)
{
++cur_count_;
sweep_tile_span(spans[I0], [&](auto dstr_idx_i0) {
constexpr auto in_dstr_idx = make_tuple(dstr_idx_i0, dstr_idx_i1);
constexpr auto out_dstr_idx = make_tuple(dstr_idx_i0);
auto x = ck_tile::type_convert<ComputeDataType>(x_tensor[in_dstr_idx]);
welford_update(mean_tensor(out_dstr_idx),
var_tensor(out_dstr_idx),
x,
cur_count_,
constant<kFastFDiv>{});
if(kWelford)
{
welford_update(mean_tensor(out_dstr_idx),
var_tensor(out_dstr_idx),
x,
cur_count_,
constant<kFastFDiv>{});
}
else
{
mean_tensor(out_dstr_idx) += x;
var_tensor(out_dstr_idx) += x * x;
}
});
}
});
......@@ -91,10 +98,11 @@ struct BlockWelford
};
template <typename Problem_, typename Policy_ = void>
struct BlockWelfordSync
struct BlockNormReduceSync
{
using Problem = remove_cvref_t<Problem_>;
static constexpr bool kFastFDiv = Problem::kFastFDiv;
static constexpr bool kWelford = Problem::kWelford;
template <typename MeanDistributedTensor_, typename VarDistributedTensor_>
CK_TILE_DEVICE void
......@@ -152,36 +160,48 @@ struct BlockWelfordSync
(number<lid_over_rid_derivative << istage.value>{}.value);
// pull data from remote lane
const auto v_remote_mean = warp_shuffle(v_local_mean, src_lane);
const auto v_remote_var = warp_shuffle(v_local_var, src_lane);
const auto v_remote_count = warp_shuffle(v_local_count, src_lane);
// welford merge
welford_merge(v_local_mean,
v_local_var,
v_local_count,
v_remote_mean,
v_remote_var,
v_remote_count,
constant<kFastFDiv>{});
const auto v_remote_mean = warp_shuffle(v_local_mean, src_lane);
const auto v_remote_var = warp_shuffle(v_local_var, src_lane);
if(kWelford)
{
const auto v_remote_count = warp_shuffle(v_local_count, src_lane);
// norm_reduce merge
welford_merge(v_local_mean,
v_local_var,
v_local_count,
v_remote_mean,
v_remote_var,
v_remote_count,
constant<kFastFDiv>{});
}
else
{
v_local_mean += v_remote_mean;
v_local_var += v_remote_var;
}
});
}
});
mean_tensor.get_thread_buffer()(i) = v_local_mean;
var_tensor.get_thread_buffer()(i) = v_local_var;
count = v_local_count;
if(kWelford)
{
count = v_local_count;
}
});
}
};
template <typename Problem_, typename Policy_ = void>
struct BlockWelfordCrossWarpSync
struct BlockNormReduceCrossWarpSync
{
using Problem = remove_cvref_t<Problem_>;
using BlockShape = typename Problem::BlockShape;
static constexpr bool kFastFDiv = Problem::kFastFDiv;
static constexpr bool kWelford = Problem::kWelford;
using smem_dtype = std::conditional_t<kWelford, fp32x4_t, fp32x2_t>;
template <typename MeanDistributedTensor_>
CK_TILE_DEVICE static constexpr index_t GetReduceWarps()
......@@ -252,7 +272,7 @@ struct BlockWelfordCrossWarpSync
static_assert(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size());
// Note: we always pack everything into fp32x4
fp32x4_t* smem_ptr = reinterpret_cast<fp32x4_t*>(smem);
smem_dtype* smem_ptr = reinterpret_cast<smem_dtype*>(smem);
const index_t lane_id = get_lane_id();
const index_t warp_id = get_warp_id();
constexpr auto num_reduce_warps = GetReduceWarps<MeanDistributedTensor_>();
......@@ -267,11 +287,13 @@ struct BlockWelfordCrossWarpSync
if(lane_id == 0)
{
static_for<0, thread_buf_size, 1>{}([&](auto i) {
fp32x4_t local_scratch_;
smem_dtype local_scratch_;
local_scratch_[0] = bit_cast<float>(mean_tensor.get_thread_buffer()[i]);
local_scratch_[1] = bit_cast<float>(var_tensor.get_thread_buffer()[i]);
local_scratch_[2] = bit_cast<float>(count);
if(kWelford)
{
local_scratch_[2] = bit_cast<float>(count);
}
smem_ptr[smem_offset + i * num_warps] = local_scratch_;
});
}
......@@ -280,7 +302,7 @@ struct BlockWelfordCrossWarpSync
// load from smem. here we let everythread to do compute :)
index_t local_warp_id = warp_id / num_reduce_warps;
index_t local_smem_os = local_warp_id * num_reduce_warps;
fp32x4_t all_scratch[thread_buf_size * num_reduce_warps];
smem_dtype all_scratch[thread_buf_size * num_reduce_warps];
static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
static_for<0, num_reduce_warps, 1>{}([&](auto i_1) {
all_scratch[i_0 * num_reduce_warps + i_1] =
......@@ -293,32 +315,40 @@ struct BlockWelfordCrossWarpSync
static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
// TODO: use descriptor for this
auto v_local = all_scratch[i_0 * num_reduce_warps];
auto v_local_mean = bit_cast<DataType>(v_local[0]);
auto v_local_var = bit_cast<DataType>(v_local[1]);
auto v_local_count = bit_cast<int>(v_local[2]);
auto v_local = all_scratch[i_0 * num_reduce_warps];
auto v_local_mean = bit_cast<DataType>(v_local[0]);
auto v_local_var = bit_cast<DataType>(v_local[1]);
int v_local_count = kWelford ? bit_cast<int>(v_local[2]) : 0;
// further reduce mean/var
static_for<0, num_reduce_warps - 1, 1>{}([&](auto i_1_n1) {
constexpr auto i_1 = number<i_1_n1 + 1>{};
const fp32x4_t v_remote = all_scratch[i_0 * num_reduce_warps + i_1];
const smem_dtype v_remote = all_scratch[i_0 * num_reduce_warps + i_1];
const auto v_remote_mean = bit_cast<DataType>(v_remote[0]);
const auto v_remote_var = bit_cast<DataType>(v_remote[1]);
const auto v_remote_count = bit_cast<int>(v_remote[2]);
welford_merge(v_local_mean,
v_local_var,
v_local_count,
v_remote_mean,
v_remote_var,
v_remote_count,
constant<kFastFDiv>{});
if(kWelford)
{
const auto v_remote_count = bit_cast<int>(v_remote[2]);
welford_merge(v_local_mean,
v_local_var,
v_local_count,
v_remote_mean,
v_remote_var,
v_remote_count,
constant<kFastFDiv>{});
}
else
{
v_local_mean += v_remote_mean;
v_local_var += v_remote_var;
}
});
mean_tensor.get_thread_buffer()(i_0) = v_local_mean;
var_tensor.get_thread_buffer()(i_0) = v_local_var;
count = v_local_count;
if(kWelford)
count = v_local_count;
});
}
};
......
......@@ -7,13 +7,18 @@
namespace ck_tile {
template <typename XDataType_, typename ComputeDataType_, typename BlockShape_, bool kFastFDiv_>
struct BlockWelfordProblem
template <typename XDataType_,
typename ComputeDataType_,
typename BlockShape_,
bool kFastFDiv_,
bool kWelford_>
struct BlockNormReduceProblem
{
using XDataType = remove_cvref_t<XDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>;
static constexpr bool kFastFDiv = kFastFDiv_;
static constexpr bool kWelford = kWelford_;
};
} // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment