Commit abd2755a authored by ThomasNing's avatar ThomasNing
Browse files

Merge branch 'develop' into moe_cross_reduce

parents b74918bc 888317e6
...@@ -46,9 +46,7 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipelineProbl ...@@ -46,9 +46,7 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipelineProbl
using fmha_pipeline_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipeline< using fmha_pipeline_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipeline<
fmha_pipeline_problem_{F_idx}>; fmha_pipeline_problem_{F_idx}>;
using fmha_kernel_{F_idx} = using fmha_kernel_{F_idx} = ck_tile::FmhaFwdAppendKVKernel<fmha_pipeline_{F_idx}>;
ck_tile::FmhaFwdAppendKVKernel<ck_tile::FmhaFwdAppendKVTilePartitioner<{F_bs}, {F_bsk}, {F_bd}, {F_bdv}>,
fmha_pipeline_{F_idx}>;
using trait_{F_idx} = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, using trait_{F_idx} = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout},
{F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>; {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>;
...@@ -355,4 +353,4 @@ def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_im ...@@ -355,4 +353,4 @@ def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_im
_, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl) _, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl)
for kernel in kernels: for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_APPENDKV_API_FILENAME) + "\n") f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_APPENDKV_API_FILENAME) + "\n")
\ No newline at end of file
...@@ -96,9 +96,7 @@ using fmha_epilogue = ...@@ -96,9 +96,7 @@ using fmha_epilogue =
{F_spad}, {F_dvpad}>>; {F_spad}, {F_dvpad}>>;
using fmha_kernel = using fmha_kernel =
ck_tile::FmhaFwdSplitKVKernel<ck_tile::FmhaFwdSplitKVTilePartitioner<fmha_shape>, ck_tile::FmhaFwdSplitKVKernel<fmha_pipeline, fmha_epilogue>;
fmha_pipeline,
fmha_epilogue>;
static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{ {{
...@@ -176,11 +174,7 @@ using fmha_epilogue = ...@@ -176,11 +174,7 @@ using fmha_epilogue =
false, false>>; false, false>>;
using fmha_kernel = using fmha_kernel =
ck_tile::FmhaFwdSplitKVCombineKernel< ck_tile::FmhaFwdSplitKVCombineKernel<fmha_pipeline, fmha_epilogue>;
ck_tile::FmhaFwdSplitKVCombineTilePartitioner<
fmha_pipeline_problem::kM0, fmha_pipeline_problem::kN1>,
fmha_pipeline,
fmha_epilogue>;
static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{ {{
...@@ -261,7 +255,7 @@ FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F ...@@ -261,7 +255,7 @@ FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F
static_assert({F_bn1} % 32 == 0); static_assert({F_bn1} % 32 == 0);
if (t.has_lse) {{ if (t.has_lse) {{
if constexpr (std::is_same_v<{F_dtype}, ck_tile::fp8_t>) {{ if constexpr (std::is_same_v<{F_dtype}, FmhaFwdFp8>) {{
return -1; return -1;
}} else {{ }} else {{
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/32, true, {F_squant}, {F_spad}, {F_dvpad}>; using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/32, true, {F_squant}, {F_spad}, {F_dvpad}>;
...@@ -614,7 +608,7 @@ def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[d ...@@ -614,7 +608,7 @@ def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[d
} }
elif dtype == 'fp8' or dtype == 'bf8': elif dtype == 'fp8' or dtype == 'bf8':
return { return {
'64' : FmhaFwdSplitKVCombineTileSize(32, -1), '64' : FmhaFwdSplitKVCombineTileSize(32, -1),
'128' : FmhaFwdSplitKVCombineTileSize(32, -1), '128' : FmhaFwdSplitKVCombineTileSize(32, -1),
'256' : FmhaFwdSplitKVCombineTileSize(32, -1), '256' : FmhaFwdSplitKVCombineTileSize(32, -1),
} }
......
...@@ -1131,15 +1131,16 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -1131,15 +1131,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
{ {
// NOTE: use gpu to do validation // NOTE: use gpu to do validation
ck_tile::naive_attention_fwd_traits naive_t; ck_tile::naive_attention_fwd_traits naive_t;
naive_t.q_type = data_type; naive_t.q_type = data_type;
naive_t.k_type = data_type; naive_t.k_type = data_type;
naive_t.v_type = data_type; naive_t.v_type = data_type;
naive_t.o_type = data_type; naive_t.o_type = data_type;
naive_t.q_layout = i_perm == 1 ? "bhsd" : "bshd"; naive_t.q_layout = i_perm == 1 ? "bhsd" : "bshd";
naive_t.k_layout = i_perm == 1 ? "bhsd" : "bshd"; naive_t.k_layout = i_perm == 1 ? "bhsd" : "bshd";
naive_t.v_layout = i_perm == 1 ? "bhsd" : "bshd"; naive_t.v_layout = i_perm == 1 ? "bhsd" : "bshd";
naive_t.o_layout = o_perm == 1 ? "bhsd" : "bshd"; naive_t.o_layout = o_perm == 1 ? "bhsd" : "bshd";
naive_t.variation = 0; // TODO? naive_t.variation = 0; // TODO?
naive_t.quant_algo = 0;
ck_tile::DeviceMem o_naive_buf(o_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem o_naive_buf(o_host.get_element_space_size_in_bytes());
......
...@@ -400,8 +400,18 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) ...@@ -400,8 +400,18 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
} }
}(); }();
dim3 grids = FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v); if constexpr(FmhaKernel::kIsGroupMode)
return ck_tile::make_tuple(kargs, grids); {
dim3 grids = FmhaKernel::GridSize(
args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.seqlen_k_ptr != nullptr);
return ck_tile::make_tuple(kargs, grids);
}
else
{
dim3 grids =
FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, false);
return ck_tile::make_tuple(kargs, grids);
}
} }
template <typename Kernel> template <typename Kernel>
......
This diff is collapsed.
...@@ -27,7 +27,8 @@ $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=2734 ...@@ -27,7 +27,8 @@ $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=2734
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=3182 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=3182
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=9 -n=4096 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=9 -n=4096
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=8192 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=8192
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=9120
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134 #$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134
done done
done done
......
...@@ -54,8 +54,7 @@ using CDataType = Types::CDataType; ...@@ -54,8 +54,7 @@ using CDataType = Types::CDataType;
auto create_args(int argc, char* argv[]) auto create_args(int argc, char* argv[])
{ {
ck_tile::ArgParser arg_parser; ck_tile::ArgParser arg_parser;
arg_parser.insert("b", "1", "batch size") arg_parser.insert("m", "3840", "m dimension")
.insert("m", "3840", "m dimension")
.insert("n", "4096", "n dimension") .insert("n", "4096", "n dimension")
.insert("k", "2048", "k dimension") .insert("k", "2048", "k dimension")
.insert("a_layout", "R", "A tensor data layout - Row by default") .insert("a_layout", "R", "A tensor data layout - Row by default")
...@@ -68,7 +67,8 @@ auto create_args(int argc, char* argv[]) ...@@ -68,7 +67,8 @@ auto create_args(int argc, char* argv[])
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("warmup", "50", "number of iterations before benchmark the kernel") .insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel") .insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer"); .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("split_k", "1", "splitK value");
bool result = arg_parser.parse(argc, argv); bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser); return std::make_tuple(result, arg_parser);
......
...@@ -64,9 +64,9 @@ int run_gemm_example_with_layouts(int argc, ...@@ -64,9 +64,9 @@ int run_gemm_example_with_layouts(int argc,
ck_tile::index_t stride_B = arg_parser.get_int("stride_b"); ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
ck_tile::index_t stride_C = arg_parser.get_int("stride_c"); ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
ck_tile::index_t batch_size = arg_parser.get_int("b"); ck_tile::index_t kbatch = arg_parser.get_int("split_k");
int n_warmup = arg_parser.get_int("warmup"); int n_warmup = arg_parser.get_int("warmup");
int n_repeat = arg_parser.get_int("repeat"); int n_repeat = arg_parser.get_int("repeat");
using namespace ck_tile::literals; using namespace ck_tile::literals;
...@@ -133,7 +133,7 @@ int run_gemm_example_with_layouts(int argc, ...@@ -133,7 +133,7 @@ int run_gemm_example_with_layouts(int argc,
stride_A, stride_A,
stride_B, stride_B,
stride_C, stride_C,
batch_size, kbatch,
n_warmup, n_warmup,
n_repeat); n_repeat);
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
#endif #endif
template <typename ALayout, typename BLayout, typename CLayout> template <typename ALayout, typename BLayout, typename CLayout>
float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{ {
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) #if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
// Memory friendly for Interwave scheduler // Memory friendly for Interwave scheduler
...@@ -78,7 +78,9 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -78,7 +78,9 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
#endif #endif
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>; ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K); const ck_tile::index_t k_grain = args.k_batch * K_Tile;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
...@@ -106,17 +108,9 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -106,17 +108,9 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
has_hot_loop_v, has_hot_loop_v,
tail_number_v>>; tail_number_v>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>; using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(args.p_a, auto kargs = Kernel::MakeKernelArgs(args);
args.p_b,
args.p_c, const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
args.M,
args.N,
args.K,
args.stride_A,
args.stride_B,
args.stride_C);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch);
constexpr dim3 blocks = Kernel::BlockSize(); constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs)) if(!Kernel::IsSupportedArgument(kargs))
......
...@@ -3,18 +3,42 @@ ...@@ -3,18 +3,42 @@
#include "moe_sorting_api.hpp" #include "moe_sorting_api.hpp"
#define MOE_SORTING_DISPATCH(unroll_num_) \ #define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \
constexpr ck_tile::index_t unroll_num = unroll_num_; \ constexpr ck_tile::index_t unroll_num = unroll_num_; \
using ms_problem = ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num>; \ constexpr ck_tile::index_t expert_tile = expert_tile_; \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \ using ms_problem = \
auto kargs = kernel::MakeKargs(a); \ ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num, expert_tile>; \
const dim3 grids = kernel::GridSize(a); \ using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
const dim3 blocks = kernel::BlockSize(a); \ auto kargs = kernel::MakeKargs(a); \
const auto lds_bytes = kernel::GetSmemSize(a); \ const dim3 grids = kernel::GridSize(a); \
float ave_time = ck_tile::launch_kernel( \ const dim3 blocks = kernel::BlockSize(a); \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \ const auto lds_bytes = kernel::GetSmemSize(a); \
float ave_time = ck_tile::launch_kernel( \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time; return ave_time;
#define MOE_SORTING_DISPATCH(unroll_num_) \
if(a.num_experts <= 8) \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 8) \
} \
else if(a.num_experts <= 16) \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 16) \
} \
else if(a.num_experts <= 32) \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 32) \
} \
else if(a.num_experts <= 64) \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 64) \
} \
else \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \
}
float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s) float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s)
{ {
if(t.weight_type == "fp32" && t.index_type == "int32") if(t.weight_type == "fp32" && t.index_type == "int32")
...@@ -49,21 +73,12 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi ...@@ -49,21 +73,12 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
case(6): { case(6): {
MOE_SORTING_DISPATCH(6); MOE_SORTING_DISPATCH(6);
} }
case(7): {
MOE_SORTING_DISPATCH(7);
}
case(8): { case(8): {
MOE_SORTING_DISPATCH(8); MOE_SORTING_DISPATCH(8);
} }
case(9): {
MOE_SORTING_DISPATCH(9);
}
case(10): { case(10): {
MOE_SORTING_DISPATCH(10); MOE_SORTING_DISPATCH(10);
} }
case(11): {
MOE_SORTING_DISPATCH(11);
}
default: { default: {
MOE_SORTING_DISPATCH(4); MOE_SORTING_DISPATCH(4);
} }
......
...@@ -16,4 +16,5 @@ $EXE -t=127 -e=99 -k=19 ...@@ -16,4 +16,5 @@ $EXE -t=127 -e=99 -k=19
$EXE -t=71 -e=11 -k=11 $EXE -t=71 -e=11 -k=11
$EXE -t=1 -e=1 -k=1 $EXE -t=1 -e=1 -k=1
$EXE -t=99 -e=2 -k=1 $EXE -t=99 -e=2 -k=1
$EXE -t=333 -e=99 -k=13 $EXE -t=333 -e=99 -k=13
\ No newline at end of file $EXE -t=128 -e=32 -k=5 -moe_buf_size=262144
...@@ -3,18 +3,42 @@ ...@@ -3,18 +3,42 @@
#include "fused_moesorting.hpp" #include "fused_moesorting.hpp"
#define MOE_SORTING_DISPATCH(unroll_num_) \ #define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \
constexpr ck_tile::index_t unroll_num = unroll_num_; \ constexpr ck_tile::index_t unroll_num = unroll_num_; \
using ms_problem = ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num>; \ constexpr ck_tile::index_t expert_tile = expert_tile_; \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \ using ms_problem = \
auto kargs = kernel::MakeKargs(a); \ ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num, expert_tile>; \
const dim3 grids = kernel::GridSize(a); \ using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
const dim3 blocks = kernel::BlockSize(a); \ auto kargs = kernel::MakeKargs(a); \
const auto lds_bytes = kernel::GetSmemSize(a); \ const dim3 grids = kernel::GridSize(a); \
float ave_time = ck_tile::launch_kernel( \ const dim3 blocks = kernel::BlockSize(a); \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \ const auto lds_bytes = kernel::GetSmemSize(a); \
float ave_time = ck_tile::launch_kernel( \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time; return ave_time;
#define MOE_SORTING_DISPATCH(unroll_num_) \
if(a.num_experts <= 8) \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 8) \
} \
else if(a.num_experts <= 16) \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 16) \
} \
else if(a.num_experts <= 32) \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 32) \
} \
else if(a.num_experts <= 64) \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 64) \
} \
else \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \
}
float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s) float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s)
{ {
if(t.weight_type == "fp32" && t.index_type == "int32") if(t.weight_type == "fp32" && t.index_type == "int32")
...@@ -49,21 +73,12 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til ...@@ -49,21 +73,12 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
case(6): { case(6): {
MOE_SORTING_DISPATCH(6); MOE_SORTING_DISPATCH(6);
} }
case(7): {
MOE_SORTING_DISPATCH(7);
}
case(8): { case(8): {
MOE_SORTING_DISPATCH(8); MOE_SORTING_DISPATCH(8);
} }
case(9): {
MOE_SORTING_DISPATCH(9);
}
case(10): { case(10): {
MOE_SORTING_DISPATCH(10); MOE_SORTING_DISPATCH(10);
} }
case(11): {
MOE_SORTING_DISPATCH(11);
}
default: { default: {
MOE_SORTING_DISPATCH(4); MOE_SORTING_DISPATCH(4);
} }
......
...@@ -70,20 +70,25 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre ...@@ -70,20 +70,25 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
using CodegenGemmTraits = using CodegenGemmTraits =
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>; ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using CodegenPipelineProblem = ck_tile:: using CodegenPipelineProblem = ck_tile::
GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>; GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy;
using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>; using CodegenGemmPipeline =
ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem, CodegenGemmPolicy>;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM. // ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>; using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args); auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.batch_count); const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count);
constexpr dim3 blocks = Kernel::BlockSize(); constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0) if(s.log_level_ > 0)
{ {
std::cout << "Launching kernel with args:" std::cout << "Launching kernel with args:"
......
...@@ -49,7 +49,8 @@ auto create_args(int argc, char* argv[]) ...@@ -49,7 +49,8 @@ auto create_args(int argc, char* argv[])
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("warmup", "50", "number of iterations before benchmark the kernel") .insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel") .insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer"); .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("split_k", "1", "splitK value");
bool result = arg_parser.parse(argc, argv); bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser); return std::make_tuple(result, arg_parser);
......
...@@ -17,6 +17,7 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, ...@@ -17,6 +17,7 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::index_t batch_stride_B, ck_tile::index_t batch_stride_B,
ck_tile::index_t batch_stride_C, ck_tile::index_t batch_stride_C,
ck_tile::index_t batch_count, ck_tile::index_t batch_count,
ck_tile::index_t kbatch,
int n_warmup, int n_warmup,
int n_repeat) int n_repeat)
{ {
...@@ -24,6 +25,7 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, ...@@ -24,6 +25,7 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer(); args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer(); args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
args.k_batch = kbatch;
args.M = M; args.M = M;
args.N = N; args.N = N;
args.K = K; args.K = K;
...@@ -79,6 +81,7 @@ int run_batched_gemm_example_with_layouts(int argc, ...@@ -79,6 +81,7 @@ int run_batched_gemm_example_with_layouts(int argc,
ck_tile::index_t batch_stride_B = arg_parser.get_int("batch_stride_b"); ck_tile::index_t batch_stride_B = arg_parser.get_int("batch_stride_b");
ck_tile::index_t batch_stride_C = arg_parser.get_int("batch_stride_c"); ck_tile::index_t batch_stride_C = arg_parser.get_int("batch_stride_c");
ck_tile::index_t batch_count = arg_parser.get_int("batch_count"); ck_tile::index_t batch_count = arg_parser.get_int("batch_count");
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
int n_warmup = arg_parser.get_int("warmup"); int n_warmup = arg_parser.get_int("warmup");
int n_repeat = arg_parser.get_int("repeat"); int n_repeat = arg_parser.get_int("repeat");
...@@ -159,6 +162,7 @@ int run_batched_gemm_example_with_layouts(int argc, ...@@ -159,6 +162,7 @@ int run_batched_gemm_example_with_layouts(int argc,
batch_stride_B, batch_stride_B,
batch_stride_C, batch_stride_C,
batch_count, batch_count,
kbatch,
n_warmup, n_warmup,
n_repeat); n_repeat);
......
...@@ -115,8 +115,8 @@ ...@@ -115,8 +115,8 @@
#cmakedefine CK_USE_GFX94 @CK_USE_GFX94@ #cmakedefine CK_USE_GFX94 @CK_USE_GFX94@
#endif #endif
#ifndef DCK_USE_OCP_FP8 #ifndef CK_USE_OCP_FP8
#cmakedefine DCK_USE_OCP_FP8 @DCK_USE_OCP_FP8@ #cmakedefine CK_USE_OCP_FP8 @CK_USE_OCP_FP8@
#endif #endif
#ifndef CK_USE_FNUZ_FP8 #ifndef CK_USE_FNUZ_FP8
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -44,10 +44,19 @@ std::ostream& LogRangeAsType(std::ostream& os, Range&& range, std::string delim) ...@@ -44,10 +44,19 @@ std::ostream& LogRangeAsType(std::ostream& os, Range&& range, std::string delim)
else else
os << delim; os << delim;
if constexpr(std::is_same_v<T, ck::f8_t> || std::is_same_v<T, ck::bf8_t>) using RangeType = ck::remove_cvref_t<decltype(v)>;
if constexpr(std::is_same_v<RangeType, ck::f8_t> || std::is_same_v<RangeType, ck::bf8_t> ||
std::is_same_v<RangeType, ck::bhalf_t>)
{ {
os << ck::type_convert<float>(v); os << ck::type_convert<float>(v);
} }
else if constexpr(std::is_same_v<RangeType, ck::pk_i4_t>)
{
const auto packed_floats = ck::type_convert<ck::float2_t>(v);
const ck::vector_type<float, 2> vector_of_floats{packed_floats};
os << vector_of_floats.template AsType<float>()[ck::Number<0>{}] << delim
<< vector_of_floats.template AsType<float>()[ck::Number<1>{}];
}
else else
{ {
os << static_cast<T>(v); os << static_cast<T>(v);
...@@ -266,18 +275,18 @@ struct Tensor ...@@ -266,18 +275,18 @@ struct Tensor
using Data = std::vector<T>; using Data = std::vector<T>;
template <typename X> template <typename X>
Tensor(std::initializer_list<X> lens) : mDesc(lens), mData(mDesc.GetElementSpaceSize()) Tensor(std::initializer_list<X> lens) : mDesc(lens), mData(GetElementSpaceSize())
{ {
} }
template <typename X, typename Y> template <typename X, typename Y>
Tensor(std::initializer_list<X> lens, std::initializer_list<Y> strides) Tensor(std::initializer_list<X> lens, std::initializer_list<Y> strides)
: mDesc(lens, strides), mData(mDesc.GetElementSpaceSize()) : mDesc(lens, strides), mData(GetElementSpaceSize())
{ {
} }
template <typename Lengths> template <typename Lengths>
Tensor(const Lengths& lens) : mDesc(lens), mData(mDesc.GetElementSpaceSize()) Tensor(const Lengths& lens) : mDesc(lens), mData(GetElementSpaceSize())
{ {
} }
...@@ -287,7 +296,7 @@ struct Tensor ...@@ -287,7 +296,7 @@ struct Tensor
{ {
} }
Tensor(const Descriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpaceSize()) {} Tensor(const Descriptor& desc) : mDesc(desc), mData(GetElementSpaceSize()) {}
template <typename OutT> template <typename OutT>
Tensor<OutT> CopyAsType() const Tensor<OutT> CopyAsType() const
...@@ -322,7 +331,17 @@ struct Tensor ...@@ -322,7 +331,17 @@ struct Tensor
std::size_t GetElementSize() const { return mDesc.GetElementSize(); } std::size_t GetElementSize() const { return mDesc.GetElementSize(); }
std::size_t GetElementSpaceSize() const { return mDesc.GetElementSpaceSize(); } std::size_t GetElementSpaceSize() const
{
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
{
return (mDesc.GetElementSpaceSize() + 1) / 2;
}
else
{
return mDesc.GetElementSpaceSize();
}
}
std::size_t GetElementSpaceSizeInBytes() const { return sizeof(T) * GetElementSpaceSize(); } std::size_t GetElementSpaceSizeInBytes() const { return sizeof(T) * GetElementSpaceSize(); }
...@@ -469,29 +488,64 @@ struct Tensor ...@@ -469,29 +488,64 @@ struct Tensor
template <typename... Is> template <typename... Is>
std::size_t GetOffsetFromMultiIndex(Is... is) const std::size_t GetOffsetFromMultiIndex(Is... is) const
{ {
return mDesc.GetOffsetFromMultiIndex(is...); if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
{
return mDesc.GetOffsetFromMultiIndex(is...) / 2;
}
else
{
return mDesc.GetOffsetFromMultiIndex(is...);
}
} }
template <typename... Is> template <typename... Is>
T& operator()(Is... is) T& operator()(Is... is)
{ {
return mData[mDesc.GetOffsetFromMultiIndex(is...)]; if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
{
return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2];
}
else
{
return mData[mDesc.GetOffsetFromMultiIndex(is...)];
}
} }
template <typename... Is> template <typename... Is>
const T& operator()(Is... is) const const T& operator()(Is... is) const
{ {
return mData[mDesc.GetOffsetFromMultiIndex(is...)]; if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
{
return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2];
}
else
{
return mData[mDesc.GetOffsetFromMultiIndex(is...)];
}
} }
T& operator()(std::vector<std::size_t> idx) T& operator()(std::vector<std::size_t> idx)
{ {
return mData[mDesc.GetOffsetFromMultiIndex(idx)]; if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
{
return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2];
}
else
{
return mData[mDesc.GetOffsetFromMultiIndex(idx)];
}
} }
const T& operator()(std::vector<std::size_t> idx) const const T& operator()(std::vector<std::size_t> idx) const
{ {
return mData[mDesc.GetOffsetFromMultiIndex(idx)]; if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
{
return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2];
}
else
{
return mData[mDesc.GetOffsetFromMultiIndex(idx)];
}
} }
typename Data::iterator begin() { return mData.begin(); } typename Data::iterator begin() { return mData.begin(); }
......
...@@ -81,6 +81,20 @@ struct GeneratorTensor_1<int8_t> ...@@ -81,6 +81,20 @@ struct GeneratorTensor_1<int8_t>
} }
}; };
template <>
struct GeneratorTensor_1<ck::pk_i4_t>
{
int8_t value = 1;
template <typename... Is>
ck::pk_i4_t operator()(Is...)
{
int t = value + 8;
ck::pk_i4_t r = ((t << 4) + t) & 0xff;
return r;
}
};
template <typename T> template <typename T>
struct GeneratorTensor_2 struct GeneratorTensor_2
{ {
...@@ -121,6 +135,22 @@ struct GeneratorTensor_2<int8_t> ...@@ -121,6 +135,22 @@ struct GeneratorTensor_2<int8_t>
} }
}; };
template <>
struct GeneratorTensor_2<ck::pk_i4_t>
{
int min_value = 0;
int max_value = 1;
template <typename... Is>
ck::pk_i4_t operator()(Is...)
{
int hi = std::rand() % (max_value - min_value) + min_value + 8;
int lo = std::rand() % (max_value - min_value) + min_value + 8;
ck::pk_i4_t r = ((hi << 4) + lo) & 0xff;
return r;
}
};
#if defined CK_ENABLE_FP8 #if defined CK_ENABLE_FP8
template <> template <>
struct GeneratorTensor_2<ck::f8_t> struct GeneratorTensor_2<ck::f8_t>
......
...@@ -167,7 +167,7 @@ struct StaticTensorTupleOfVectorBuffer ...@@ -167,7 +167,7 @@ struct StaticTensorTupleOfVectorBuffer
// Idx is for S, not X. Idx should be aligned with X // Idx is for S, not X. Idx should be aligned with X
template <typename X, template <typename X,
typename Idx, typename Idx,
typename enable_if<has_same_scalar_type<S, X>::value && typename enable_if<(has_same_scalar_type<S, X>::value || !is_native_type<S>()) &&
is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_, is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
bool>::type = false> bool>::type = false>
__host__ __device__ constexpr X GetAsType(Idx) const __host__ __device__ constexpr X GetAsType(Idx) const
...@@ -201,7 +201,7 @@ struct StaticTensorTupleOfVectorBuffer ...@@ -201,7 +201,7 @@ struct StaticTensorTupleOfVectorBuffer
// Idx is for S, not X. Idx should be aligned with X // Idx is for S, not X. Idx should be aligned with X
template <typename X, template <typename X,
typename Idx, typename Idx,
typename enable_if<has_same_scalar_type<S, X>::value && typename enable_if<(has_same_scalar_type<S, X>::value || !is_native_type<S>()) &&
is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_, is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
bool>::type = false> bool>::type = false>
__host__ __device__ constexpr void SetAsType(Idx, X x) __host__ __device__ constexpr void SetAsType(Idx, X x)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp"
namespace ck {
enum struct BlockGemmPipelineVersion
{
v1, // Naive
v2, // Mem
v3, // Comp
v4, // Comp, double lds buffer
v5, // Comp, double global prefetch register buffer
};
template <BlockGemmPipelineVersion BlkGemmPipelineVer,
BlockGemmPipelineScheduler BlkGemmPipeSche,
index_t BlockSize,
typename ADataType,
typename BDataType,
typename ComputeDataType,
typename AccDataType,
typename ATileDesc,
typename BTileDesc,
typename AMmaTileDesc,
typename BMmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPack>
constexpr auto BlockGemmPipeline_Selector()
{
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
return BlockwiseGemmXdlops_pipeline_v1_b_scale<BlkGemmPipeSche,
BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
{
return BlockwiseGemmXdlops_pipeline_v2_b_scale<BlkGemmPipeSche,
BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
return BlockwiseGemmXdlops_pipeline_v3_b_scale<BlkGemmPipeSche,
BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
return BlockwiseGemmXdlops_pipeline_v4_b_scale<BlkGemmPipeSche,
BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v5)
{
return BlockwiseGemmXdlops_pipeline_v5<BlkGemmPipeSche,
BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
}
else
{
std::cerr << "BlockGemmPipeline configuration is not available" << std::endl;
}
}
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment