Commit abd2755a authored by ThomasNing

Merge branch 'develop' into moe_cross_reduce

parents b74918bc 888317e6
......@@ -46,9 +46,7 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipelineProbl
using fmha_pipeline_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipeline<
fmha_pipeline_problem_{F_idx}>;
using fmha_kernel_{F_idx} =
ck_tile::FmhaFwdAppendKVKernel<ck_tile::FmhaFwdAppendKVTilePartitioner<{F_bs}, {F_bsk}, {F_bd}, {F_bdv}>,
fmha_pipeline_{F_idx}>;
using fmha_kernel_{F_idx} = ck_tile::FmhaFwdAppendKVKernel<fmha_pipeline_{F_idx}>;
using trait_{F_idx} = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout},
{F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>;
......@@ -355,4 +353,4 @@ def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_im
_, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl)
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_APPENDKV_API_FILENAME) + "\n")
\ No newline at end of file
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_APPENDKV_API_FILENAME) + "\n")
......@@ -96,9 +96,7 @@ using fmha_epilogue =
{F_spad}, {F_dvpad}>>;
using fmha_kernel =
ck_tile::FmhaFwdSplitKVKernel<ck_tile::FmhaFwdSplitKVTilePartitioner<fmha_shape>,
fmha_pipeline,
fmha_epilogue>;
ck_tile::FmhaFwdSplitKVKernel<fmha_pipeline, fmha_epilogue>;
static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{
......@@ -176,11 +174,7 @@ using fmha_epilogue =
false, false>>;
using fmha_kernel =
ck_tile::FmhaFwdSplitKVCombineKernel<
ck_tile::FmhaFwdSplitKVCombineTilePartitioner<
fmha_pipeline_problem::kM0, fmha_pipeline_problem::kN1>,
fmha_pipeline,
fmha_epilogue>;
ck_tile::FmhaFwdSplitKVCombineKernel<fmha_pipeline, fmha_epilogue>;
static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{
......@@ -261,7 +255,7 @@ FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F
static_assert({F_bn1} % 32 == 0);
if (t.has_lse) {{
if constexpr (std::is_same_v<{F_dtype}, ck_tile::fp8_t>) {{
if constexpr (std::is_same_v<{F_dtype}, FmhaFwdFp8>) {{
return -1;
}} else {{
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/32, true, {F_squant}, {F_spad}, {F_dvpad}>;
......@@ -614,7 +608,7 @@ def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[d
}
elif dtype == 'fp8' or dtype == 'bf8':
return {
'64' : FmhaFwdSplitKVCombineTileSize(32, -1),
'64' : FmhaFwdSplitKVCombineTileSize(32, -1),
'128' : FmhaFwdSplitKVCombineTileSize(32, -1),
'256' : FmhaFwdSplitKVCombineTileSize(32, -1),
}
......
......@@ -1131,15 +1131,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
{
// NOTE: use gpu to do validation
ck_tile::naive_attention_fwd_traits naive_t;
naive_t.q_type = data_type;
naive_t.k_type = data_type;
naive_t.v_type = data_type;
naive_t.o_type = data_type;
naive_t.q_layout = i_perm == 1 ? "bhsd" : "bshd";
naive_t.k_layout = i_perm == 1 ? "bhsd" : "bshd";
naive_t.v_layout = i_perm == 1 ? "bhsd" : "bshd";
naive_t.o_layout = o_perm == 1 ? "bhsd" : "bshd";
naive_t.variation = 0; // TODO?
naive_t.q_type = data_type;
naive_t.k_type = data_type;
naive_t.v_type = data_type;
naive_t.o_type = data_type;
naive_t.q_layout = i_perm == 1 ? "bhsd" : "bshd";
naive_t.k_layout = i_perm == 1 ? "bhsd" : "bshd";
naive_t.v_layout = i_perm == 1 ? "bhsd" : "bshd";
naive_t.o_layout = o_perm == 1 ? "bhsd" : "bshd";
naive_t.variation = 0; // TODO?
naive_t.quant_algo = 0;
ck_tile::DeviceMem o_naive_buf(o_host.get_element_space_size_in_bytes());
......
......@@ -400,8 +400,18 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
}
}();
dim3 grids = FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v);
return ck_tile::make_tuple(kargs, grids);
if constexpr(FmhaKernel::kIsGroupMode)
{
dim3 grids = FmhaKernel::GridSize(
args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.seqlen_k_ptr != nullptr);
return ck_tile::make_tuple(kargs, grids);
}
else
{
dim3 grids =
FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, false);
return ck_tile::make_tuple(kargs, grids);
}
}
template <typename Kernel>
......
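A minimal, self-contained sketch of the compile-time branch introduced above, assuming only that the kernel type exposes a kIsGroupMode constant and that group-mode grids take an extra flag for a per-batch seqlen_k pointer; the sizing formula here is hypothetical, not CK's.

#include <cstdio>

struct GroupModeKernel { static constexpr bool kIsGroupMode = true;  };
struct BatchModeKernel { static constexpr bool kIsGroupMode = false; };

template <typename Kernel>
int grid_size(int batch, int nhead_q, bool has_seqlen_k_ptr)
{
    if constexpr(Kernel::kIsGroupMode)
        return batch * nhead_q * (has_seqlen_k_ptr ? 2 : 1); // hypothetical sizing
    else
        return batch * nhead_q; // batch mode ignores the flag (always passes false)
}

int main()
{
    std::printf("%d %d\n",
                grid_size<GroupModeKernel>(4, 8, true),
                grid_size<BatchModeKernel>(4, 8, true));
    return 0;
}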
......@@ -58,6 +58,7 @@ template <typename XDataType_,
bool kPadN_,
bool kSaveMeanInvStd_,
bool kFastFDiv_,
bool kWelford_,
bool kTwoPass_,
ck_tile::index_t kFusedAdd_ = 0,
ck_tile::index_t kFusedQuant_ = 0>
......@@ -120,6 +121,7 @@ struct layernorm2d_fwd_traits_
static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
static constexpr bool kFastFDiv = kFastFDiv_;
static constexpr bool kWelford = kWelford_;
static constexpr bool kTwoPass = kTwoPass_;
static constexpr ck_tile::index_t kFusedAdd = kFusedAdd_;
static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_;
......@@ -137,6 +139,7 @@ template <typename XDataType_,
bool kPadN_,
bool kSaveMeanInvStd_,
bool kFastFDiv_,
bool kWelford_,
bool kTwoPass_,
int kFusedAdd_,
int kFusedQuant_>
......@@ -152,6 +155,7 @@ using traits_ = layernorm2d_fwd_traits_<XDataType_,
kPadN_,
kSaveMeanInvStd_,
kFastFDiv_,
kWelford_,
kTwoPass_,
kFusedAdd_,
kFusedQuant_>;
......@@ -184,6 +188,7 @@ float layernorm2d_fwd_(const S& s, A a)
using PipelineTraits = ck_tile::Layernorm2dFwdTraits<Traits_::kPadN,
Traits_::kSaveMeanInvStd,
Traits_::kFastFDiv,
Traits_::kWelford,
Traits_::kTwoPass,
static_cast<ck_tile::Layernorm2dFusedAddEnum>(Traits_::kFusedAdd),
static_cast<ck_tile::Layernorm2dFusedQuantEnum>(Traits_::kFusedQuant)>;
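For context on the new kWelford trait threaded through these templates: a minimal scalar sketch of Welford's running mean/variance update, the numerically stable one-pass reduction this flag presumably toggles. CK's pipeline is blocked and vectorized; this is illustration only.

#include <cstdio>
#include <vector>

int main()
{
    std::vector<float> row{1.f, 2.f, 3.f, 4.f};
    float mean  = 0.f;
    float m2    = 0.f; // running sum of squared deviations from the current mean
    int   count = 0;
    for(float v : row)
    {
        ++count;
        const float delta = v - mean;
        mean += delta / count;
        m2   += delta * (v - mean);
    }
    std::printf("mean=%f var=%f\n", mean, m2 / count); // mean=2.5 var=1.25
    return 0;
}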
......@@ -204,12 +209,13 @@ float layernorm2d_fwd_(const S& s, A a)
using TwoPassPipeline = ck_tile::Layernorm2dFwdPipelineTwoPass<PipelineProblem>;
using Pipeline = std::conditional_t<Traits_::kTwoPass, TwoPassPipeline, OnePassPipeline>;
using Default2DEpilogueProblem = ck_tile::Default2DEpilogueProblem<ComputeDataType, YDataType, false, Traits_::kPadN, false>;
using Default2DEpilogueProblem = ck_tile::Default2DEpilogueProblem<ComputeDataType, YDataType, false, Traits_::kPadN, true>;
using Default2DEpilogue = ck_tile::Default2DEpilogue<Default2DEpilogueProblem>;
static constexpr bool UseSmoothInputScale = Traits_::kFusedQuant == 1;
static constexpr bool UseRawStore = sizeof(YDataType) == 4;
using DynamicQuantEpilogueProblem = ck_tile::DynamicQuantEpilogueProblem<ComputeDataType, XScaleDataType, YScaleDataType, YDataType, typename Traits_::Shape,
ck_tile::DynamicQuantEpilogueTraits<false, Traits_::kPadN, UseSmoothInputScale, false, true/*max3*/>>;
ck_tile::DynamicQuantEpilogueTraits<false, Traits_::kPadN, UseSmoothInputScale, UseRawStore, true/*max3*/>>;
using DynamicQuantEpilogue = ck_tile::DynamicQuantEpilogue<DynamicQuantEpilogueProblem>;
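A small illustration of the UseRawStore condition above, assuming only what the line states: a raw (non-converting) store is enabled exactly when the output element is 4 bytes wide.

#include <cstdio>

template <typename YDataType>
constexpr bool use_raw_store = sizeof(YDataType) == 4;

int main()
{
    std::printf("4-byte out: %d, 2-byte out: %d\n",
                use_raw_store<float> ? 1 : 0,  // 4-byte output -> raw store
                use_raw_store<short> ? 1 : 0); // 2-byte output -> converting store
    return 0;
}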
......@@ -274,7 +280,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
#include "layernorm2d_fwd_api_common.hpp"
// clang-format off
// prec_i prec_o prec_sy rm rn tm tn vn pd mv rpcf 2p add sweep
// prec_i prec_o prec_sy rm rn tm tn vn pd mv rpcf welford 2p add sweep
{F_instance_def}
// clang-format on
......@@ -362,6 +368,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
F_kPadN : bool
F_kSaveMeanInvStd_ : bool
F_kFastFDiv_ : bool
F_kWelford_ : bool
F_kTwoPass_ : bool
F_kFusedAdd : int
F_kFusedQuant : int
......@@ -369,7 +376,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
@property
def trait_name(self) ->str:
t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_XScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}'
t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveMeanInvStd_):5}, {BOOL_MAP(self.F_kFastFDiv_):5}'
t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveMeanInvStd_):5}, {BOOL_MAP(self.F_kFastFDiv_):5}, {BOOL_MAP(self.F_kWelford_):5}'
t_ += f', {BOOL_MAP(self.F_kTwoPass_):5}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}'
return t_
......@@ -422,11 +429,10 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
def name_common_header(self) -> str:
return 'layernorm2d_fwd_api_common'
@property
def content_api(self) -> str:
def content_api(self, args) -> str:
# 1 sort based on dtype
t_dtype_dict = dict()
blobs = self.get_blobs()
blobs = self.get_blobs(args)
for blob in blobs:
if blob.F_DataTypePair not in t_dtype_dict:
t_dtype_dict[blob.F_DataTypePair] = {}
......@@ -462,8 +468,8 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
inner_str += self.API_INNER_CASE.format(F_if = get_if_str(idx_in_n, len_in_n, False),
F_VEC_COND = _cond, F_instance_func=ins.call_name)
#inner_str = inner_str + vec_str
n_cnd = f'(a.n <= {n_})' if (i_n < len(blob_per_t) - 1) else ''
n_str += self.API_PER_N_CASE.format(F_if = get_if_str(i_n, len(blob_per_t)), F_N_COND=n_cnd, F_inner_dispatch=inner_str)
n_cnd = f'(a.n <= {n_})' if isinstance(n_, int) else ''
n_str += self.API_PER_N_CASE.format(F_if = get_if_str(i_n, len(blob_per_t), not isinstance(n_, int)), F_N_COND=n_cnd, F_inner_dispatch=inner_str)
prec_i, prec_o = dtype_.split(',')
d_str += self.API_PER_DTYPE.format(F_if = get_if_str(i_d, len(t_dtype_dict), False), F_i_type=prec_i, F_o_type=prec_o, F_per_n_case=n_str)
......@@ -474,7 +480,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
def content_common_header(self) -> str:
return self.API_COMMON_HEADER.format(F_traits_define=self.API_TRAITS_DEFINE)
def get_blobs(self):
def get_blobs(self, args):
h_traits = layernorm_fwd_codegen.h_traits
h_instance = layernorm_fwd_codegen.h_instance
......@@ -484,60 +490,61 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
scale_list = [('fp32,fp32')]
dtype_list = [('fp16,fp16'), ('bf16,bf16'),
('fp16,int8'), ('bf16,int8')] # NOTE: only fused-dynamic-quant use int8 out
types_8bit = ('int8', 'fp8')
types_16bit = ('int16', 'fp16', 'bf16')
#fused_add_list = [0, 1, 2]
#fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused dynamic quant
fused_add_list = [0, 1]
fused_sweep_list = [0, 1] # NOTE: only single pass can use fused dynamic quant
# rm rn tm tn vn pd mv fdiv 2p add sweep
h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 8, 8, 8, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 1, True, False, True, False, 0, 0)],
'128' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 8, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 1, True, False, True, False, 0, 0)],
'256' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 1, True, False, True, False, 0, 0)],
'512' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 8, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 8, 4, 64, 1, True, False, True, False, 0, 0)],
'768' : [ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 6, 4, 64, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 12, 4, 64, 1, True, False, True, False, 0, 0)],
'1024' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 2, 128, 8, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 2, 128, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 2, 128, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 1, True, False, True, False, 0, 0)],
'1536' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 8, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 2, 128, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 1, True, False, True, False, 0, 0)],
'2048' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 256, 8, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 8, 1, 256, 1, True, False, True, False, 0, 0)],
'3072' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 128, 8, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 1, True, False, True, False, 0, 0)],
'4096' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, False, 0, 0)],
'6144' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 8, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 512, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1,1024, 1, True, False, True, False, 0, 0)],
'8192' :[ h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 8, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, True, False, 0, 0)],
'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, True, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, True, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, 0, 0)]}
# rm rn tm tn vn pd mv fdiv welford 2p add sweep
h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 8, 8, 8, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 4, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 1, True, False, True, True, False, 0, 0)],
'128' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 8, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 2, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 1, True, False, True, True, False, 0, 0)],
'256' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 4, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 2, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 1, True, False, True, True, False, 0, 0)],
'512' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 8, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 4, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 2, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 8, 4, 64, 1, True, False, True, True, False, 0, 0)],
'768' : [ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 4, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 6, 4, 64, 2, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 12, 4, 64, 1, True, False, True, True, False, 0, 0)],
'1024' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 2, 128, 8, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 2, 128, 4, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 2, 128, 2, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 1, True, False, True, True, False, 0, 0)],
'1536' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 8, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 2, 128, 4, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 2, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 1, True, False, True, True, False, 0, 0)],
'2048' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 256, 8, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 4, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 2, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 8, 1, 256, 1, True, False, True, True, False, 0, 0)],
'3072' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 128, 8, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 4, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 2, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 1, True, False, True, True, False, 0, 0)],
'4096' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, False, 0, 0)],
'6144' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 8, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 512, 4, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 2, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1,1024, 1, True, False, True, True, False, 0, 0)],
'8192' :[ h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 8, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, True, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, True, True, False, 0, 0)],
'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, True, True, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, True, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, True, True, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, True, 0, 0)]}
total_blob = list()
for hs_key in h_trait_dict:
hs = h_trait_dict[hs_key]
......@@ -558,16 +565,27 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
h_.F_YScaleDataType = scale_x
h_.F_kFusedAdd = fused_add
h_.F_kFusedQuant = fused_quant
# disable the Welford update for 8-bit types and for 16-bit small-N cases
if not h_.F_kTwoPass_:
#disable 16 bit when set args disable_16b_welford
if args.disable_16b_welford and prec_i in types_16bit:
h_.F_kWelford_ = False
#disable 8bit by default
elif prec_i in types_8bit or prec_o in types_8bit:
h_.F_kWelford_ = False
#disable 16bit small N
elif prec_i in types_16bit and hs_key == '64':
h_.F_kWelford_ = False
current_hs.append(h_) # + "\n"
#f.write(str(f.parent / GEN_DIR / (blobs.api_common_header_
current_n_str = 'big' if hs_key == 'big' else current_n
total_blob.append(h_instance(dtype, current_n_str, fused_add, fused_quant, current_hs))
return total_blob
def list_blobs(self) -> None:
def list_blobs(self, args) -> None:
w_p = Path(self.working_path)
list_p = w_p / 'layernorm2d_fwd_blobs.txt'
blobs = self.get_blobs()
blobs = self.get_blobs(args)
with list_p.open('w') as list_f:
# api related file
list_f.write(str(w_p / (self.name_api + ".cpp")) + "\n")
......@@ -576,11 +594,12 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
for b in blobs:
list_f.write(str(w_p / (b.name + ".cpp")) + "\n")
def gen_blobs(self) -> None:
def gen_blobs(self, args) -> None:
w_p = Path(self.working_path)
(w_p / (self.name_api + ".cpp")).write_text(self.content_api)
w_str = self.content_api(args)
(w_p / (self.name_api + ".cpp")).write_text(w_str)
(w_p / (self.name_common_header + ".hpp")).write_text(self.content_common_header)
blobs = self.get_blobs()
blobs = self.get_blobs(args)
for b in blobs:
(w_p / (b.name + ".cpp")).write_text(b.content)
......@@ -588,14 +607,14 @@ def list_blobs(args):
api_list = args.api.split(',')
for api in api_list:
if api == 'fwd':
layernorm_fwd_codegen(args.working_path, args.filter).list_blobs()
layernorm_fwd_codegen(args.working_path, args.filter).list_blobs(args)
def gen_blobs(args):
api_list = args.api.split(',')
for api in api_list:
if api == 'fwd':
layernorm_fwd_codegen(args.working_path, args.filter).gen_blobs()
layernorm_fwd_codegen(args.working_path, args.filter).gen_blobs(args)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
......@@ -663,6 +682,13 @@ if __name__ == "__main__":
help="codegen receipt."
)
parser.add_argument(
"--disable_16b_welford",
default=False,
required=False,
help="enable/disable welford for 16bit datatype n > 64"
)
args = parser.parse_args()
# print(f'{args.list_blobs}-{args.gen_blobs}')
......
......@@ -27,7 +27,8 @@ $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=2734
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=3182
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=9 -n=4096
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=8192
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=9120
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134
done
done
......
......@@ -54,8 +54,7 @@ using CDataType = Types::CDataType;
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("b", "1", "batch size")
.insert("m", "3840", "m dimension")
arg_parser.insert("m", "3840", "m dimension")
.insert("n", "4096", "n dimension")
.insert("k", "2048", "k dimension")
.insert("a_layout", "R", "A tensor data layout - Row by default")
......@@ -68,7 +67,8 @@ auto create_args(int argc, char* argv[])
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer");
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("split_k", "1", "splitK value");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
......
......@@ -64,9 +64,9 @@ int run_gemm_example_with_layouts(int argc,
ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
ck_tile::index_t batch_size = arg_parser.get_int("b");
int n_warmup = arg_parser.get_int("warmup");
int n_repeat = arg_parser.get_int("repeat");
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
int n_warmup = arg_parser.get_int("warmup");
int n_repeat = arg_parser.get_int("repeat");
using namespace ck_tile::literals;
......@@ -133,7 +133,7 @@ int run_gemm_example_with_layouts(int argc,
stride_A,
stride_B,
stride_C,
batch_size,
kbatch,
n_warmup,
n_repeat);
......
......@@ -22,7 +22,7 @@
#endif
template <typename ALayout, typename BLayout, typename CLayout>
float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
// Memory friendly for Interwave scheduler
......@@ -78,7 +78,9 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
#endif
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K);
const ck_tile::index_t k_grain = args.k_batch * K_Tile;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
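Worked example of the split-K rounding introduced above, with illustrative values (K_Tile and k_batch here are placeholders, not the instance's real tile config): each split's share of K is rounded up to a whole number of K tiles before the loop count is taken.

#include <cstdio>

int main()
{
    const int K = 1000, k_batch = 2, K_Tile = 32;
    const int k_grain = k_batch * K_Tile;                     // 64
    const int K_split = (K + k_grain - 1) / k_grain * K_Tile; // ceil(1000/64) * 32 = 512
    std::printf("K per split, rounded to a K_Tile multiple: %d\n", K_split);
    return 0;
}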
......@@ -106,17 +108,9 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
has_hot_loop_v,
tail_number_v>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(args.p_a,
args.p_b,
args.p_c,
args.M,
args.N,
args.K,
args.stride_A,
args.stride_B,
args.stride_C);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch);
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
......
......@@ -3,18 +3,42 @@
#include "moe_sorting_api.hpp"
#define MOE_SORTING_DISPATCH(unroll_num_) \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
using ms_problem = ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num>; \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
const auto lds_bytes = kernel::GetSmemSize(a); \
float ave_time = ck_tile::launch_kernel( \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
#define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr ck_tile::index_t expert_tile = expert_tile_; \
using ms_problem = \
ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num, expert_tile>; \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
const auto lds_bytes = kernel::GetSmemSize(a); \
float ave_time = ck_tile::launch_kernel( \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time;
#define MOE_SORTING_DISPATCH(unroll_num_) \
if(a.num_experts <= 8) \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 8) \
} \
else if(a.num_experts <= 16) \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 16) \
} \
else if(a.num_experts <= 32) \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 32) \
} \
else if(a.num_experts <= 64) \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 64) \
} \
else \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \
}
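The reworked macro above effectively buckets num_experts into the smallest supported expert tile; a standalone sketch of that selection (8/16/32/64, with 0 meaning the generic, untiled fallback):

#include <cstdio>

int pick_expert_tile(int num_experts)
{
    if(num_experts <= 8)  return 8;
    if(num_experts <= 16) return 16;
    if(num_experts <= 32) return 32;
    if(num_experts <= 64) return 64;
    return 0; // fall back to the untiled variant
}

int main()
{
    std::printf("%d %d %d\n", pick_expert_tile(5), pick_expert_tile(20), pick_expert_tile(99));
    // prints: 8 32 0
    return 0;
}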
float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s)
{
if(t.weight_type == "fp32" && t.index_type == "int32")
......@@ -49,21 +73,12 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
case(6): {
MOE_SORTING_DISPATCH(6);
}
case(7): {
MOE_SORTING_DISPATCH(7);
}
case(8): {
MOE_SORTING_DISPATCH(8);
}
case(9): {
MOE_SORTING_DISPATCH(9);
}
case(10): {
MOE_SORTING_DISPATCH(10);
}
case(11): {
MOE_SORTING_DISPATCH(11);
}
default: {
MOE_SORTING_DISPATCH(4);
}
......
......@@ -16,4 +16,5 @@ $EXE -t=127 -e=99 -k=19
$EXE -t=71 -e=11 -k=11
$EXE -t=1 -e=1 -k=1
$EXE -t=99 -e=2 -k=1
$EXE -t=333 -e=99 -k=13
\ No newline at end of file
$EXE -t=333 -e=99 -k=13
$EXE -t=128 -e=32 -k=5 -moe_buf_size=262144
......@@ -3,18 +3,42 @@
#include "fused_moesorting.hpp"
#define MOE_SORTING_DISPATCH(unroll_num_) \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
using ms_problem = ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num>; \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
const auto lds_bytes = kernel::GetSmemSize(a); \
float ave_time = ck_tile::launch_kernel( \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
#define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr ck_tile::index_t expert_tile = expert_tile_; \
using ms_problem = \
ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num, expert_tile>; \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
const auto lds_bytes = kernel::GetSmemSize(a); \
float ave_time = ck_tile::launch_kernel( \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time;
#define MOE_SORTING_DISPATCH(unroll_num_) \
if(a.num_experts <= 8) \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 8) \
} \
else if(a.num_experts <= 16) \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 16) \
} \
else if(a.num_experts <= 32) \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 32) \
} \
else if(a.num_experts <= 64) \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 64) \
} \
else \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \
}
float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s)
{
if(t.weight_type == "fp32" && t.index_type == "int32")
......@@ -49,21 +73,12 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
case(6): {
MOE_SORTING_DISPATCH(6);
}
case(7): {
MOE_SORTING_DISPATCH(7);
}
case(8): {
MOE_SORTING_DISPATCH(8);
}
case(9): {
MOE_SORTING_DISPATCH(9);
}
case(10): {
MOE_SORTING_DISPATCH(10);
}
case(11): {
MOE_SORTING_DISPATCH(11);
}
default: {
MOE_SORTING_DISPATCH(4);
}
......
......@@ -70,20 +70,25 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
using CodegenGemmTraits =
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using CodegenPipelineProblem = ck_tile::
GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy;
using CodegenGemmPipeline =
ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem, CodegenGemmPolicy>;
// TODO: add a codegen path to test different pipeline policies in GEMM.
// Currently only BlockGemmASmemBSmemCRegV1DefaultPolicy is used.
using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.batch_count);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count);
constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args:"
......
......@@ -49,7 +49,8 @@ auto create_args(int argc, char* argv[])
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer");
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("split_k", "1", "splitK value");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
......
......@@ -17,6 +17,7 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::index_t batch_stride_B,
ck_tile::index_t batch_stride_C,
ck_tile::index_t batch_count,
ck_tile::index_t kbatch,
int n_warmup,
int n_repeat)
{
......@@ -24,6 +25,7 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
args.k_batch = kbatch;
args.M = M;
args.N = N;
args.K = K;
......@@ -79,6 +81,7 @@ int run_batched_gemm_example_with_layouts(int argc,
ck_tile::index_t batch_stride_B = arg_parser.get_int("batch_stride_b");
ck_tile::index_t batch_stride_C = arg_parser.get_int("batch_stride_c");
ck_tile::index_t batch_count = arg_parser.get_int("batch_count");
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
int n_warmup = arg_parser.get_int("warmup");
int n_repeat = arg_parser.get_int("repeat");
......@@ -159,6 +162,7 @@ int run_batched_gemm_example_with_layouts(int argc,
batch_stride_B,
batch_stride_C,
batch_count,
kbatch,
n_warmup,
n_repeat);
......
......@@ -115,8 +115,8 @@
#cmakedefine CK_USE_GFX94 @CK_USE_GFX94@
#endif
#ifndef DCK_USE_OCP_FP8
#cmakedefine DCK_USE_OCP_FP8 @DCK_USE_OCP_FP8@
#ifndef CK_USE_OCP_FP8
#cmakedefine CK_USE_OCP_FP8 @CK_USE_OCP_FP8@
#endif
#ifndef CK_USE_FNUZ_FP8
......
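For reference, a sketch of what the corrected #cmakedefine line produces in the configured header when the CMake variable CK_USE_OCP_FP8 is set (when unset, CMake emits a commented #undef instead); the substituted value is shown as 1 for illustration. The old DCK_ spelling left the guarded macro name mismatched with the -DCK_USE_OCP_FP8 compile definition.

/* generated when CK_USE_OCP_FP8 is set; the value comes from the CMake variable */
#ifndef CK_USE_OCP_FP8
#define CK_USE_OCP_FP8 1
#endif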
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -44,10 +44,19 @@ std::ostream& LogRangeAsType(std::ostream& os, Range&& range, std::string delim)
else
os << delim;
if constexpr(std::is_same_v<T, ck::f8_t> || std::is_same_v<T, ck::bf8_t>)
using RangeType = ck::remove_cvref_t<decltype(v)>;
if constexpr(std::is_same_v<RangeType, ck::f8_t> || std::is_same_v<RangeType, ck::bf8_t> ||
std::is_same_v<RangeType, ck::bhalf_t>)
{
os << ck::type_convert<float>(v);
}
else if constexpr(std::is_same_v<RangeType, ck::pk_i4_t>)
{
const auto packed_floats = ck::type_convert<ck::float2_t>(v);
const ck::vector_type<float, 2> vector_of_floats{packed_floats};
os << vector_of_floats.template AsType<float>()[ck::Number<0>{}] << delim
<< vector_of_floats.template AsType<float>()[ck::Number<1>{}];
}
else
{
os << static_cast<T>(v);
......@@ -266,18 +275,18 @@ struct Tensor
using Data = std::vector<T>;
template <typename X>
Tensor(std::initializer_list<X> lens) : mDesc(lens), mData(mDesc.GetElementSpaceSize())
Tensor(std::initializer_list<X> lens) : mDesc(lens), mData(GetElementSpaceSize())
{
}
template <typename X, typename Y>
Tensor(std::initializer_list<X> lens, std::initializer_list<Y> strides)
: mDesc(lens, strides), mData(mDesc.GetElementSpaceSize())
: mDesc(lens, strides), mData(GetElementSpaceSize())
{
}
template <typename Lengths>
Tensor(const Lengths& lens) : mDesc(lens), mData(mDesc.GetElementSpaceSize())
Tensor(const Lengths& lens) : mDesc(lens), mData(GetElementSpaceSize())
{
}
......@@ -287,7 +296,7 @@ struct Tensor
{
}
Tensor(const Descriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpaceSize()) {}
Tensor(const Descriptor& desc) : mDesc(desc), mData(GetElementSpaceSize()) {}
template <typename OutT>
Tensor<OutT> CopyAsType() const
......@@ -322,7 +331,17 @@ struct Tensor
std::size_t GetElementSize() const { return mDesc.GetElementSize(); }
std::size_t GetElementSpaceSize() const { return mDesc.GetElementSpaceSize(); }
std::size_t GetElementSpaceSize() const
{
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
{
return (mDesc.GetElementSpaceSize() + 1) / 2;
}
else
{
return mDesc.GetElementSpaceSize();
}
}
std::size_t GetElementSpaceSizeInBytes() const { return sizeof(T) * GetElementSpaceSize(); }
......@@ -469,29 +488,64 @@ struct Tensor
template <typename... Is>
std::size_t GetOffsetFromMultiIndex(Is... is) const
{
return mDesc.GetOffsetFromMultiIndex(is...);
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
{
return mDesc.GetOffsetFromMultiIndex(is...) / 2;
}
else
{
return mDesc.GetOffsetFromMultiIndex(is...);
}
}
template <typename... Is>
T& operator()(Is... is)
{
return mData[mDesc.GetOffsetFromMultiIndex(is...)];
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
{
return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2];
}
else
{
return mData[mDesc.GetOffsetFromMultiIndex(is...)];
}
}
template <typename... Is>
const T& operator()(Is... is) const
{
return mData[mDesc.GetOffsetFromMultiIndex(is...)];
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
{
return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2];
}
else
{
return mData[mDesc.GetOffsetFromMultiIndex(is...)];
}
}
T& operator()(std::vector<std::size_t> idx)
{
return mData[mDesc.GetOffsetFromMultiIndex(idx)];
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
{
return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2];
}
else
{
return mData[mDesc.GetOffsetFromMultiIndex(idx)];
}
}
const T& operator()(std::vector<std::size_t> idx) const
{
return mData[mDesc.GetOffsetFromMultiIndex(idx)];
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
{
return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2];
}
else
{
return mData[mDesc.GetOffsetFromMultiIndex(idx)];
}
}
typename Data::iterator begin() { return mData.begin(); }
......
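Illustration of the packed-int4 indexing added above, under the assumption the change encodes: two 4-bit elements share one byte, so the element-space size and every linear offset are halved (rounding the size up for odd counts).

#include <cstdio>

int main()
{
    const int logical_elements = 7;                           // odd on purpose
    const int storage_bytes    = (logical_elements + 1) / 2;  // -> 4 bytes
    const int logical_offset   = 5;
    const int storage_offset   = logical_offset / 2;          // elements 4 and 5 share byte 2
    std::printf("bytes=%d offset=%d\n", storage_bytes, storage_offset);
    return 0;
}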
......@@ -81,6 +81,20 @@ struct GeneratorTensor_1<int8_t>
}
};
template <>
struct GeneratorTensor_1<ck::pk_i4_t>
{
int8_t value = 1;
template <typename... Is>
ck::pk_i4_t operator()(Is...)
{
int t = value + 8;
ck::pk_i4_t r = ((t << 4) + t) & 0xff;
return r;
}
};
template <typename T>
struct GeneratorTensor_2
{
......@@ -121,6 +135,22 @@ struct GeneratorTensor_2<int8_t>
}
};
template <>
struct GeneratorTensor_2<ck::pk_i4_t>
{
int min_value = 0;
int max_value = 1;
template <typename... Is>
ck::pk_i4_t operator()(Is...)
{
int hi = std::rand() % (max_value - min_value) + min_value + 8;
int lo = std::rand() % (max_value - min_value) + min_value + 8;
ck::pk_i4_t r = ((hi << 4) + lo) & 0xff;
return r;
}
};
#if defined CK_ENABLE_FP8
template <>
struct GeneratorTensor_2<ck::f8_t>
......
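Worked example of the nibble packing used by GeneratorTensor_2<ck::pk_i4_t> above: each random value is biased by +8 so a signed int4 lands in [0, 15], and two such values are packed into one byte, high nibble first.

#include <cstdio>

int main()
{
    const int hi = -3 + 8; // 5
    const int lo =  7 + 8; // 15
    const unsigned char packed = static_cast<unsigned char>(((hi << 4) + lo) & 0xff);
    std::printf("0x%02x\n", packed); // 0x5f
    return 0;
}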
......@@ -167,7 +167,7 @@ struct StaticTensorTupleOfVectorBuffer
// Idx is for S, not X. Idx should be aligned with X
template <typename X,
typename Idx,
typename enable_if<has_same_scalar_type<S, X>::value &&
typename enable_if<(has_same_scalar_type<S, X>::value || !is_native_type<S>()) &&
is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
bool>::type = false>
__host__ __device__ constexpr X GetAsType(Idx) const
......@@ -201,7 +201,7 @@ struct StaticTensorTupleOfVectorBuffer
// Idx is for S, not X. Idx should be aligned with X
template <typename X,
typename Idx,
typename enable_if<has_same_scalar_type<S, X>::value &&
typename enable_if<(has_same_scalar_type<S, X>::value || !is_native_type<S>()) &&
is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
bool>::type = false>
__host__ __device__ constexpr void SetAsType(Idx, X x)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp"
namespace ck {
enum struct BlockGemmPipelineVersion
{
v1, // Naive
v2, // Mem
v3, // Comp
v4, // Comp, double lds buffer
v5, // Comp, double global prefetch register buffer
};
template <BlockGemmPipelineVersion BlkGemmPipelineVer,
BlockGemmPipelineScheduler BlkGemmPipeSche,
index_t BlockSize,
typename ADataType,
typename BDataType,
typename ComputeDataType,
typename AccDataType,
typename ATileDesc,
typename BTileDesc,
typename AMmaTileDesc,
typename BMmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPack>
constexpr auto BlockGemmPipeline_Selector()
{
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
return BlockwiseGemmXdlops_pipeline_v1_b_scale<BlkGemmPipeSche,
BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
{
return BlockwiseGemmXdlops_pipeline_v2_b_scale<BlkGemmPipeSche,
BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
return BlockwiseGemmXdlops_pipeline_v3_b_scale<BlkGemmPipeSche,
BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
return BlockwiseGemmXdlops_pipeline_v4_b_scale<BlkGemmPipeSche,
BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v5)
{
return BlockwiseGemmXdlops_pipeline_v5<BlkGemmPipeSche,
BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
}
else
{
std::cerr << "BlockGemmPipeline configuration is not available" << std::endl;
}
}
} // namespace ck
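The new header above follows CK's selector pattern: a constexpr function returns a different pipeline type per BlockGemmPipelineVersion value, and callers recover the concrete type with decltype. A minimal, self-contained sketch of that pattern with placeholder types (not the actual CK pipelines):

#include <type_traits>

enum struct Version { v1, v2 };

struct PipelineV1 {};
struct PipelineV2 {};

template <Version Ver>
constexpr auto select_pipeline()
{
    if constexpr(Ver == Version::v1) { return PipelineV1{}; }
    else                             { return PipelineV2{}; }
}

using Selected = decltype(select_pipeline<Version::v2>());
static_assert(std::is_same_v<Selected, PipelineV2>, "selector picks the v2 pipeline");

int main() { return 0; }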