Unverified commit ec959387 authored by rocking, committed by GitHub

Merge branch 'develop' into ck_tile/fmha_receipt_aiter

parents c1e2fef7 0e5e29c4
@@ -7,22 +7,20 @@ export CK_REPEAT=1
COMMON_ARGS='-v=2 -warmup=0 -repeat=1'
-run_fp16_tests() {
-    for batch in 1 2; do
-        for m in 128 1024; do
-            for n in 128 2048; do
-                for k in 32 64; do
-                    $EXE -b=$batch -m=$m -n=$n -k=$k -stride_a=0 -stride_b=0 -stride_c=0 -e=1e-5 -prec=fp16 $COMMON_ARGS
-                    if [ $? -eq 0 ]; then
-                        echo "Success: Test with batch=$batch, m=$m, n=$n, k=$k executed successfully."
-                    else
-                        echo "Error: Test with batch=$batch, m=$m, n=$n, k=$k failed to execute properly."
-                        # Optionally, exit or break if you need to halt further execution
-                        # exit 1
-                    fi
-                done
-            done
-        done
-    done
+run_tests() {
+    for m in 128 1024; do
+        for n in 128 2048; do
+            for k in 64 128; do
+                $EXE -m=$m -n=$n -k=$k -stride_a=0 -stride_b=0 -stride_c=0 -prec=$1 $COMMON_ARGS
+                if [ $? -eq 0 ]; then
+                    echo "Success: Test with m=$m, n=$n, k=$k executed successfully."
+                else
+                    echo "Error: Test with m=$m, n=$n, k=$k failed to execute properly."
+                    # Optionally, exit or break if you need to halt further execution
+                    # exit 1
+                fi
+            done
+        done
+    done
@@ -30,6 +28,9 @@ run_fp16_tests() {
set -x
-run_fp16_tests
+run_tests "fp16"
+run_tests "bf16"
+run_tests "fp8"
+run_tests "bf8"
set +x
@@ -7,22 +7,20 @@ export CK_REPEAT=1
COMMON_ARGS='-v=2 -warmup=0 -repeat=1'
-run_fp16_tests() {
-    for batch in 1 2; do
-        for m in 128 1024; do
-            for n in 128 2048; do
-                for k in 32 64; do
-                    $EXE -b=$batch -m=$m -n=$n -k=$k -stride_a=0 -stride_b=0 -stride_c=0 -e=1e-5 -prec=fp16 $COMMON_ARGS
-                    if [ $? -eq 0 ]; then
-                        echo "Success: Test with batch=$batch, m=$m, n=$n, k=$k executed successfully."
-                    else
-                        echo "Error: Test with batch=$batch, m=$m, n=$n, k=$k failed to execute properly."
-                        # Optionally, exit or break if you need to halt further execution
-                        # exit 1
-                    fi
-                done
-            done
-        done
-    done
+run_tests() {
+    for m in 512 1024; do
+        for n in 512 2048; do
+            for k in 512 1024; do
+                $EXE -m=$m -n=$n -k=$k -stride_a=0 -stride_b=0 -stride_c=0 -prec=$1 $COMMON_ARGS
+                if [ $? -eq 0 ]; then
+                    echo "Success: Test with batch=$batch, m=$m, n=$n, k=$k executed successfully."
+                else
+                    echo "Error: Test with batch=$batch, m=$m, n=$n, k=$k failed to execute properly."
+                    # Optionally, exit or break if you need to halt further execution
+                    # exit 1
+                fi
+            done
+        done
+    done
@@ -30,6 +28,9 @@ run_fp16_tests() {
set -x
-run_fp16_tests
+run_tests "fp16"
+run_tests "bf16"
+run_tests "fp8"
+run_tests "bf8"
set +x
// SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
@@ -12,7 +12,13 @@
#include "ck_tile/host.hpp"
#include "gemm_basic.hpp"
-template <typename ALayout, typename BLayout, typename CLayout>
+template <typename ADataType,
+          typename BDataType,
+          typename AccDataType,
+          typename CDataType,
+          typename ALayout,
+          typename BLayout,
+          typename CLayout>
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
@@ -29,10 +35,28 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8;
-#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
+constexpr bool DoubleSmemBuffer = false;
+#endif
+#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
// Compute friendly for Intrawave scheduler
constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = 64;
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;
constexpr bool DoubleSmemBuffer = false;
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
// Compute friendly for Intrawave scheduler
// Using the ping pong reader in the lds level
constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = 32;
constexpr ck_tile::index_t M_Warp = 2;
@@ -42,13 +66,19 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;
+constexpr bool DoubleSmemBuffer = true;
#endif
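// Illustrative arithmetic (assuming 64-lane waves, not part of the original example): with the
// compute-friendly configs above a 256x256 block tile is covered by the 2x2x1 grid of 32x32x16
// warp tiles, i.e. each warp repeats 256/(2*32) = 4 times in M and in N, and the work-group
// uses 2*2*1*64 = 256 threads.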
constexpr bool kPadM = false;
constexpr bool kPadN = false;
constexpr bool kPadK = false;
-constexpr int kBlockPerCu = 1;
+constexpr bool TransposeC = false;
+constexpr int kBlockPerCu = 1;
+constexpr ck_tile::index_t TileParitionerGroupNum = 8;
+constexpr ck_tile::index_t TileParitionerM01 = 4;
// ===============================================
@@ -56,13 +86,18 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
-using TilePartitioner = ck_tile::GemmTile2DPartitioner<GemmShape>;
+using TilePartitioner = ck_tile::
+    GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
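// Note (interpretation of the parameters above, not documentation): the spatially-local
// partitioner is understood to walk output tiles in groups of TileParitionerGroupNum with an
// M01-style swizzle of TileParitionerM01, so neighbouring work-groups tend to reuse the same
// A/B tiles and improve cache locality.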
using GemmEpilogue = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>;
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
kPadN,
kPadK,
DoubleSmemBuffer,
ALayout,
BLayout,
CLayout,
TransposeC>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
@@ -85,14 +120,27 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
BDataType,
AccDataType,
GemmShape,
-Traits,
+GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
-using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
-auto kargs = Kernel::MakeKernelArgs(args);
+using GemmEpilogue = ck_tile::CShuffleEpilogue<
+    ck_tile::CShuffleEpilogueProblem<AccDataType,
+                                     CDataType,
+                                     CLayout,
+                                     GemmPipelineProblem::kBlockSize,
+                                     TilePartitioner::MPerBlock,
+                                     TilePartitioner::NPerBlock,
+                                     M_Warp,
+                                     N_Warp,
+                                     M_Warp_Tile,
+                                     N_Warp_Tile,
+                                     K_Warp_Tile,
+                                     UniversalGemmProblem::TransposeC>>;
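// Note (general description, details are inferred): the CShuffle epilogue stages the per-warp
// accumulators through LDS and re-tiles them before the global store so that CDataType writes
// come out coalesced for the requested CLayout; the parameters above give it the block/warp
// tiling it has to shuffle between.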
+using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
+auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
constexpr dim3 blocks = Kernel::BlockSize();
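// Note (interpretation, not from the example): GridSize presumably launches one work-group per
// M_Tile x N_Tile tile of the output, replicated k_batch times for split-K, while BlockSize
// reflects the M_Warp x N_Warp x K_Warp warp layout selected above.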
@@ -117,6 +165,21 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
if(has_hot_loop)
{
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
if(tail_num == ck_tile::TailNumber::Full)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
else
{
std::ostringstream err;
err << "For compute pipeline tail number should always be Full, but have \"" << tail_num
<< "\" which is not supported! PrefetchStages: " << BaseGemmPipeline::PrefetchStages
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
// Tail pipeline One to Seven
if(tail_num == ck_tile::TailNumber::One)
{
@@ -177,6 +240,18 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{});
}
}
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
if(tail_num == ck_tile::TailNumber::Three)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
}
else
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
}
#endif
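// Summary of the tail handling above: the memory pipeline peels explicit One..Seven tails, the
// compute V3 pipeline accepts only TailNumber::Full (anything else throws), and the V4
// ping-pong pipeline dispatches a Three tail and otherwise falls back to Two.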
}
else
{
@@ -201,4 +276,115 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
#include "run_gemm_example.inc"
int run_gemm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
if(a_layout == "R" && b_layout == "R")
{
if(data_type == "fp16")
{
return run_gemm_example_with_layouts<ck_tile::half_t>(argc, argv, Row{}, Row{}, Row{});
}
else if(data_type == "bf16")
{
return run_gemm_example_with_layouts<ck_tile::bf16_t>(argc, argv, Row{}, Row{}, Row{});
}
else if(data_type == "fp8")
{
return run_gemm_example_with_layouts<ck_tile::fp8_t>(argc, argv, Row{}, Row{}, Row{});
}
else if(data_type == "bf8")
{
return run_gemm_example_with_layouts<ck_tile::bf8_t>(argc, argv, Row{}, Row{}, Row{});
}
else
{
throw std::runtime_error("Unsupported data_type!");
}
}
else if(a_layout == "R" && b_layout == "C")
{
if(data_type == "fp16")
{
return run_gemm_example_with_layouts<ck_tile::half_t>(argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "bf16")
{
return run_gemm_example_with_layouts<ck_tile::bf16_t>(argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "fp8")
{
return run_gemm_example_with_layouts<ck_tile::fp8_t>(argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "bf8")
{
return run_gemm_example_with_layouts<ck_tile::bf8_t>(argc, argv, Row{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported data_type!");
}
}
else if(a_layout == "C" && b_layout == "C")
{
if(data_type == "fp16")
{
return run_gemm_example_with_layouts<ck_tile::half_t>(argc, argv, Col{}, Col{}, Row{});
}
else if(data_type == "bf16")
{
return run_gemm_example_with_layouts<ck_tile::bf16_t>(argc, argv, Col{}, Col{}, Row{});
}
else if(data_type == "fp8")
{
return run_gemm_example_with_layouts<ck_tile::fp8_t>(argc, argv, Col{}, Col{}, Row{});
}
else if(data_type == "bf8")
{
return run_gemm_example_with_layouts<ck_tile::bf8_t>(argc, argv, Col{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported data_type!");
}
}
else if(a_layout == "C" && b_layout == "R")
{
if(data_type == "fp16")
{
return run_gemm_example_with_layouts<ck_tile::half_t>(argc, argv, Col{}, Row{}, Row{});
}
else if(data_type == "bf16")
{
return run_gemm_example_with_layouts<ck_tile::bf16_t>(argc, argv, Col{}, Row{}, Row{});
}
else if(data_type == "fp8")
{
return run_gemm_example_with_layouts<ck_tile::fp8_t>(argc, argv, Col{}, Row{}, Row{});
}
else if(data_type == "bf8")
{
return run_gemm_example_with_layouts<ck_tile::bf8_t>(argc, argv, Col{}, Row{}, Row{});
}
else
{
throw std::runtime_error("Unsupported data_type!");
}
}
else
{
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
}
}
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
@@ -33,7 +33,7 @@ target_sources(${TILE_RMSNORM2D_FWD} PRIVATE ${RMSNORM2D_FWD_GEN_BLOBS})
set(TILE_RMSNORM2D_FWD_COMPILE_OPTIONS)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
-list(APPEND TILE_RMSNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
+list(APPEND TILE_RMSNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal --offload-compress)
target_compile_options(${TILE_RMSNORM2D_FWD} PRIVATE ${TILE_RMSNORM2D_FWD_COMPILE_OPTIONS})
......
@@ -37,7 +37,8 @@ FUSED_FUSED_SWEEP_STR_MAP = [
DATA_TYPE_MAP = {'fp32' : 'float',
                 'fp16' : 'ck_tile::fp16_t',
                 'bf16' : 'ck_tile::bf16_t',
-                 'int8' : 'ck_tile::int8_t'}
+                 'int8' : 'ck_tile::int8_t',
+                 'fp8' : 'ck_tile::fp8_t'}
def BOOL_MAP(b_) -> str:
    if b_:
@@ -477,12 +478,13 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
    h_traits = rmsnorm_fwd_codegen.h_traits
    h_instance = rmsnorm_fwd_codegen.h_instance
-    dynamic_quant_out_dtype = ['int8']
+    dynamic_quant_out_dtype = ['int8', 'fp8']
    # some predefined support range
    # (prec_i,prec_o) for simplicity this string will be used as key for dict
    scale_list = [('fp32,fp32')]
    dtype_list = [('fp16,fp16'), ('bf16,bf16'),
-                  ('fp16,int8'), ('bf16,int8')] # NOTE: only fused-dynamic-quant use int8 out
+                  ('fp16,int8'), ('bf16,int8'),
+                  ('fp16,fp8'), ('bf16,fp8')] # NOTE: only fused-dynamic-quant use int8 out
    #fused_add_list = [0, 1, 2]
    #fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused (smooth) dynamic quant
    fused_add_list = [0, 1]
......
@@ -105,9 +105,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
prec_sy = "fp32";
}
-if((fused_quant == 1 || fused_quant == 2) && prec_o != "int8")
+if((fused_quant == 1 || fused_quant == 2) && prec_o != "int8" && prec_o != "fp8")
{
-std::cout << "if fused_quant is 1, only support \"-prec_o=int8\" case" << std::endl;
+std::cout
+    << "if fused_quant is 1 or 2, only support \"-prec_o=int8\" or \"-prec_o=fp8\" cases."
+    << std::endl;
return false;
}
@@ -248,7 +250,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
absmax = a > absmax ? a : absmax;
}
// printf("cpu:absmax:%f\n", absmax);
-ComputeDataType y_scale = absmax / static_cast<ComputeDataType>(127.0);
+constexpr ComputeDataType kMaxY =
+    std::is_same<YDataType, ck_tile::fp8_t>::value ? 240.0
+    : std::is_same<YDataType, ck_tile::int8_t>::value ? 127.0
+                                                      : 0.0;
+ComputeDataType y_scale = absmax / kMaxY;
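// Note (assumption about the constants above): 240.0 is the largest finite value of the fp8
// e4m3 variant used here and 127.0 is the int8 maximum, so y_scale maps the per-row absmax
// onto the full representable range of the quantized output type.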
y_scale_host_ref(m_) = ck_tile::type_convert<YScaleDataType>(y_scale);
for(int n_ = 0; n_ < N_; n_++)
{
@@ -400,6 +406,16 @@ int main(int argc, char* argv[])
{
return run<ck_tile::bf16_t, ck_tile::int8_t, float, float, true>(arg_parser) ? 0 : -2;
}
else if(prec_i == "fp16" && prec_o == "fp8" && prec_sm == "fp32" && prec_sy == "fp32" &&
!save_rms)
{
return run<ck_tile::half_t, ck_tile::fp8_t, float, float, false>(arg_parser) ? 0 : -2;
}
else if(prec_i == "bf16" && prec_o == "fp8" && prec_sm == "fp32" && prec_sy == "fp32" &&
!save_rms)
{
return run<ck_tile::bf16_t, ck_tile::fp8_t, float, float, false>(arg_parser) ? 0 : -2;
}
return -3;
}
#!/bin/sh
EXE="$(find . -name tile_rmsnorm2d_fwd -type f | head -n 1)"
-for fquant in "" "-fquant=1 -prec_o=int8" "-fquant=2 -prec_o=int8"; do
+for fquant in "" "-fquant=1 -prec_o=int8" "-fquant=2 -prec_o=int8" "-fquant=1 -prec_o=fp8" "-fquant=2 -prec_o=fp8"; do
for pr_i in "fp16" "bf16" ; do
for fadd in "0" "1"; do
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=99 -n=13
@@ -27,7 +27,7 @@ $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=2734
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=3182
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=9 -n=4096
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=8192
-#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547
+$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134
done
done
......
@@ -26,6 +26,10 @@ auto create_args(int argc, char* argv[])
.insert("k", "4", "topk")
.insert("unit", "32", "unit_size")
.insert("moe_buf_size", "0", "moe_buf_size")
.insert("local_eid",
"-1",
"a list of experts enabled as local expert. e.g. \"0,1,4,5\"\n"
"please make sure eid is in ascending order!")
.insert("seed", "-1", "seed to be used, -1 means random every time") .insert("seed", "-1", "seed to be used, -1 means random every time")
.insert("kname", "0", "when set to 1 it will print kernel name") .insert("kname", "0", "when set to 1 it will print kernel name")
.insert("warmup", "5", "number of iterations before benchmark the kernel") .insert("warmup", "5", "number of iterations before benchmark the kernel")
...@@ -74,6 +78,7 @@ bool test_moe_sorting(ck_tile::ArgParser args) ...@@ -74,6 +78,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
int kname = args.get_int("kname"); int kname = args.get_int("kname");
int warmup = args.get_int("warmup"); int warmup = args.get_int("warmup");
int repeat = args.get_int("repeat"); int repeat = args.get_int("repeat");
int max_output_ids = int max_output_ids =
ck_tile::integer_least_multiple(topk * tokens + num_experts * unit_size - topk, unit_size); ck_tile::integer_least_multiple(topk * tokens + num_experts * unit_size - topk, unit_size);
...@@ -90,6 +95,30 @@ bool test_moe_sorting(ck_tile::ArgParser args) ...@@ -90,6 +95,30 @@ bool test_moe_sorting(ck_tile::ArgParser args)
return false; return false;
} }
bool local_expert_masking = args.get_str("local_eid") != "-1";
auto local_expert_masking_host = [&]() {
if(local_expert_masking)
{
auto local_eid = args.get_int_vec("local_eid");
// std::vector<int> v_ {num_experts, 0};
ck_tile::HostTensor<IndexType> v_{{num_experts}};
v_.SetZero();
for(auto eid : local_eid)
{
if(eid >= num_experts)
{
throw std::runtime_error(
"local_eid larger than number of expert, please check");
}
v_.mData[eid] = 1;
}
return v_;
}
else
// return std::vector<int>{};
return ck_tile::HostTensor<IndexType>{{1}};
}();
// tokens already considered batch size
ck_tile::HostTensor<IndexType> topk_ids_host({tokens, topk}, {topk, 1});
ck_tile::HostTensor<WeightType> weights_host({tokens, topk}, {topk, 1});
@@ -111,6 +140,8 @@ bool test_moe_sorting(ck_tile::ArgParser args)
sorted_expert_ids_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem sorted_id_cnt_dev(sorted_id_cnt_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem moe_buf_dev(moe_buf_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem local_expert_masking_dev(
local_expert_masking_host.get_element_space_size_in_bytes());
topk_ids_dev.ToDevice(topk_ids_host.data());
weights_dev.ToDevice(weights_host.data());
@@ -118,11 +149,15 @@ bool test_moe_sorting(ck_tile::ArgParser args)
{
moe_buf_dev.ToDevice(moe_buf_host.data());
}
if(local_expert_masking)
local_expert_masking_dev.ToDevice(local_expert_masking_host.data());
-moe_sorting_trait trait{index_prec, weight_prec};
+moe_sorting_trait trait{index_prec, weight_prec, local_expert_masking};
moe_sorting_args karg{topk_ids_dev.GetDeviceBuffer(),
weights_dev.GetDeviceBuffer(),
local_expert_masking ? local_expert_masking_dev.GetDeviceBuffer()
: nullptr,
sorted_ids_dev.GetDeviceBuffer(),
sorted_weights_dev.GetDeviceBuffer(),
sorted_expert_ids_dev.GetDeviceBuffer(),
@@ -140,15 +175,22 @@ bool test_moe_sorting(ck_tile::ArgParser args)
warmup,
repeat};
auto ms = moe_sorting(trait, karg, sc);
-printf("[%s|%s]tokens:%d, num_experts:%d, topk:%d, ms:%f , ",
+printf("[%s|%s]tokens:%d, num_experts:%d, topk:%d, ",
index_prec.c_str(),
weight_prec.c_str(),
tokens,
num_experts,
-topk,
-ms);
+topk);
if(local_expert_masking)
{
printf("local_eid:%s, ", args.get_str("local_eid").c_str());
}
if(ms < 0)
printf("not supported\n");
+else
+    printf("ms:%f, ", ms);
fflush(stdout);
if(ms < 0)
{
@@ -174,12 +216,14 @@ bool test_moe_sorting(ck_tile::ArgParser args)
int32_t ref_total_tokens_post_pad = 0;
ck_tile::reference_moe_sorting<WeightType, IndexType>(topk_ids_host,
weights_host,
local_expert_masking_host,
sorted_ids_ref,
sorted_weights_ref,
sorted_expert_ids_ref,
ref_total_tokens_post_pad,
num_experts,
-unit_size);
+unit_size,
+local_expert_masking);
rtn &= ck_tile::check_err(
sorted_ids_host, sorted_ids_ref, std::string("OUT Error: Incorrect ids!"), 1e-6, 1e-6);
rtn &= ck_tile::check_err(sorted_weights_host,
@@ -199,9 +243,16 @@ bool test_moe_sorting(ck_tile::ArgParser args)
moe_buf_host, moe_buf_ref, std::string("OUT Error: Incorrect zero buf!"), 0, 0);
}
rtn &= ref_total_tokens_post_pad == sorted_id_cnt_host.mData[0];
printf("total_tokens_post_pad:%d(%d), ",
ref_total_tokens_post_pad,
sorted_id_cnt_host.mData[0]);
}
-printf("valid:%s\n", rtn ? "y" : "n");
+printf("valid:%s", rtn ? "y" : "n");
+fflush(stdout);
+if(!rtn)
+    printf(", (%d)", seed);
+printf("\n");
fflush(stdout);
return rtn;
}
......
@@ -3,6 +3,12 @@
#include "moe_sorting_api.hpp"
#ifndef MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_USE_EX_KERNEL 1
#endif
#if !MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr ck_tile::index_t expert_tile = expert_tile_; \
@@ -17,6 +23,67 @@
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time;
#else
#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_, local_expert_masking_) \
constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \
constexpr bool sub_token_onshot = sub_token_onshot_; \
constexpr bool local_expert_masking = local_expert_masking_; \
using ms_problem = ck_tile::MoeSortingProblemEx<index_t, \
ms_weight_type, \
sub_token_tile, \
sub_token_onshot, \
local_expert_masking>; \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
const auto lds_bytes = kernel::GetSmemSize(a); \
float ave_time = ck_tile::launch_kernel( \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time;
#define MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_) \
if(row_ % 8 == 0) \
{ \
MOE_SORTING_DISPATCH_(8, sub_token_onshot_, local_expert_masking_); \
} \
else if(row_ % 4 == 0) \
{ \
MOE_SORTING_DISPATCH_(4, sub_token_onshot_, local_expert_masking_); \
} \
else if(row_ % 2 == 0) \
{ \
MOE_SORTING_DISPATCH_(2, sub_token_onshot_, local_expert_masking_); \
} \
else \
{ \
MOE_SORTING_DISPATCH_(1, sub_token_onshot_, local_expert_masking_); \
}
#define MOE_SORTING_DISPATCH_SUBTO_(row_, local_expert_masking_) \
if(is_sub_token_onshot) \
{ \
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, true, local_expert_masking_) \
} \
else \
{ \
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, false, local_expert_masking_) \
}
#define MOE_SORTING_DISPATCH_EMASK_(row_) \
if(is_local_expert_masking) \
{ \
MOE_SORTING_DISPATCH_SUBTO_(row_, true) \
} \
else \
{ \
MOE_SORTING_DISPATCH_SUBTO_(row_, false) \
}
#endif
#if !MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_DISPATCH(unroll_num_) \
if(a.num_experts <= 8) \
{ \
@@ -38,11 +105,13 @@
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \
}
#endif
float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s)
{
if(t.weight_type == "fp32" && t.index_type == "int32")
{
#if !MOE_SORTING_USE_EX_KERNEL
if(a.num_experts > 127)
{
printf("lds size exceed, only support experts <127 \n");
@@ -83,6 +152,19 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
MOE_SORTING_DISPATCH(4);
}
}
#else
using index_t = ck_tile::index_t;
using ms_weight_type = float;
auto [r_, c_] = ck_tile::moe_sorting_get_smem_row_col(a.tokens, a.num_experts);
auto sub_token_ = r_ - 2;
r_ = (r_ - 2) / 8;
bool is_sub_token_onshot = a.tokens <= sub_token_;
bool is_local_expert_masking = t.local_expert_masking;
(void)c_;
MOE_SORTING_DISPATCH_EMASK_(r_);
// MOE_SORTING_DISPATCH_ETILE(0, 0);
#endif
}
return -1;
}
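Read without the macro indirection, the sub-token dispatch above simply picks the largest of {8, 4, 2, 1} that divides the derived row count and forwards it as a compile-time tile size. A minimal non-macro sketch of that selection (the function name is illustrative, not part of the API):

// Illustrative only: the divisibility-based selection done by MOE_SORTING_DISPATCH_SUB_TOKEN_,
// written as a plain function instead of nested macros.
inline int pick_sub_token_tile(int rows)
{
    if(rows % 8 == 0) return 8; // widest row tile
    if(rows % 4 == 0) return 4;
    if(rows % 2 == 0) return 2;
    return 1;                   // scalar fallback
}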
@@ -10,7 +10,8 @@
struct moe_sorting_trait
{
std::string index_type;
std::string weight_type; // currently always float
+bool local_expert_masking; // if mask experts as local expert
};
struct moe_sorting_args : public ck_tile::MoeSortingHostArgs
......
@@ -17,4 +17,12 @@ $EXE -t=71 -e=11 -k=11
$EXE -t=1 -e=1 -k=1
$EXE -t=99 -e=2 -k=1
$EXE -t=333 -e=99 -k=13
$EXE -t=11 -e=256 -k=5
$EXE -t=64 -e=455 -k=8
$EXE -t=777 -e=802 -k=99
$EXE -t=4097 -e=906 -k=51
$EXE -t=128 -e=32 -k=5 -moe_buf_size=262144
$EXE -t=13 -e=64 -k=3 -local_eid=4,5,6,7,8,9,10,11
$EXE -t=99 -e=33 -k=9 -local_eid=6,10,11,15,19
$EXE -t=80 -e=99 -k=10 -local_eid=0,8,12,33
$EXE -t=11 -e=256 -k=5 -local_eid=99,110,129
@@ -42,7 +42,7 @@ summary of the key design of this fused-moe operator:
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
-// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
+// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated)
// * this could be larger than actual, since actual tokens are on GPU
//
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
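// e.g. (illustrative numbers, not from the example above) with input_tokens = 5, topk = 3,
// num_experts = 6 and M_a = 4 the updated bound gives 3*5 + 6*4 - 3 = 36 padded slots, versus
// 3*5 + 6*3 = 33 from the previous formula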
......
@@ -8,14 +8,15 @@
struct fused_moe_args
{
const void* a_ptr;              // [m, k], input token
const void* a_scale_ptr;        // [m, 1], token scale
const void* g_ptr;              // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w])
const void* d_ptr;              // [e, n, k], pre-shuffle([e, nr, kr, w])
const void* g_scale_ptr;        // [e, 1, n], gate(up) scale
const void* d_scale_ptr;        // [e, 1, k], down scale
const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
+const void* local_expert_mask_ptr; // [e], local_expert_mask_ptr for EP
void* o_ptr;                    // [m, k], output token (no need to do zeroing)
const void* topk_ids_ptr;    // [tokens, topk]
const void* topk_weight_ptr; // [tokens, topk]
@@ -48,6 +49,8 @@ struct fused_moe_traits
int activation;  // 0:gelu, 1:silu
int gate_only;   // 0:g1u0, 1:g1u1
int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
+bool local_expert_masking; // if mask experts as local expert
};
float fused_moe(fused_moe_traits, fused_moe_args, const ck_tile::stream_config&);
@@ -10,7 +10,8 @@
struct fused_moesorting_trait
{
std::string index_type;
std::string weight_type; // currently always float
+bool local_expert_masking; // if mask experts as local expert
};
struct fused_moesorting_args : public ck_tile::MoeSortingHostArgs
......
@@ -17,10 +17,11 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf
return 1;
}();
-auto t0 = fused_moesorting_trait{"int32", "fp32"};
+auto t0 = fused_moesorting_trait{"int32", "fp32", t.local_expert_masking};
auto a0 = fused_moesorting_args{
a.topk_ids_ptr,          // const void* p_topk_ids;
a.topk_weight_ptr,       // const void* p_weights;
+a.local_expert_mask_ptr, // const void* p_local_expert_mask;
a.sorted_token_ids_ptr,  // void* p_sorted_token_ids;
a.sorted_weight_ptr,     // void* p_sorted_weights;
a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids;
......
@@ -3,6 +3,12 @@
#include "fused_moesorting.hpp"
#ifndef MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_USE_EX_KERNEL 1
#endif
#if !MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr ck_tile::index_t expert_tile = expert_tile_; \
@@ -17,6 +23,67 @@
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time;
#else
#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_, local_expert_masking_) \
constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \
constexpr bool sub_token_onshot = sub_token_onshot_; \
constexpr bool local_expert_masking = local_expert_masking_; \
using ms_problem = ck_tile::MoeSortingProblemEx<index_t, \
ms_weight_type, \
sub_token_tile, \
sub_token_onshot, \
local_expert_masking>; \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
const auto lds_bytes = kernel::GetSmemSize(a); \
float ave_time = ck_tile::launch_kernel( \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time;
#define MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_) \
if(row_ % 8 == 0) \
{ \
MOE_SORTING_DISPATCH_(8, sub_token_onshot_, local_expert_masking_); \
} \
else if(row_ % 4 == 0) \
{ \
MOE_SORTING_DISPATCH_(4, sub_token_onshot_, local_expert_masking_); \
} \
else if(row_ % 2 == 0) \
{ \
MOE_SORTING_DISPATCH_(2, sub_token_onshot_, local_expert_masking_); \
} \
else \
{ \
MOE_SORTING_DISPATCH_(1, sub_token_onshot_, local_expert_masking_); \
}
#define MOE_SORTING_DISPATCH_SUBTO_(row_, local_expert_masking_) \
if(is_sub_token_onshot) \
{ \
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, true, local_expert_masking_) \
} \
else \
{ \
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, false, local_expert_masking_) \
}
#define MOE_SORTING_DISPATCH_EMASK_(row_) \
if(is_local_expert_masking) \
{ \
MOE_SORTING_DISPATCH_SUBTO_(row_, true) \
} \
else \
{ \
MOE_SORTING_DISPATCH_SUBTO_(row_, false) \
}
#endif
#if !MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_DISPATCH(unroll_num_) \
if(a.num_experts <= 8) \
{ \
@@ -38,11 +105,13 @@
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \
}
#endif
float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s)
{
if(t.weight_type == "fp32" && t.index_type == "int32")
{
#if !MOE_SORTING_USE_EX_KERNEL
if(a.num_experts > 127)
{
printf("lds size exceed, only support experts <127 \n");
@@ -83,6 +152,19 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
MOE_SORTING_DISPATCH(4);
}
}
#else
using index_t = ck_tile::index_t;
using ms_weight_type = float;
auto [r_, c_] = ck_tile::moe_sorting_get_smem_row_col(a.tokens, a.num_experts);
auto sub_token_ = r_ - 2;
r_ = (r_ - 2) / 8;
bool is_sub_token_onshot = a.tokens <= sub_token_;
bool is_local_expert_masking = t.local_expert_masking;
(void)c_;
MOE_SORTING_DISPATCH_EMASK_(r_);
// MOE_SORTING_DISPATCH_ETILE(0, 0);
#endif
}
return -1;
}
@@ -140,28 +140,29 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::index_t activation = arg_parser.get_int("act");
if(stride < 0)
stride = hidden_size;
std::string prec_i = arg_parser.get_str("prec_i");
std::string prec_w = arg_parser.get_str("prec_w");
std::string prec_o = arg_parser.get_str("prec_o");
std::string prec_st = arg_parser.get_str("prec_st");
std::string prec_sw = arg_parser.get_str("prec_sw");
std::string prec_sq = arg_parser.get_str("prec_sq");
std::string prec_kw = arg_parser.get_str("prec_kw");
prec_st = (prec_st == "auto") ? "fp32" : prec_st;
prec_sw = (prec_sw == "auto") ? "fp32" : prec_sw;
prec_sq = (prec_sq == "auto") ? "fp32" : prec_sq;
prec_kw = (prec_kw == "auto") ? "fp32" : prec_kw;
int kname = arg_parser.get_int("kname");
int do_validation = arg_parser.get_int("v");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
int fused_quant = arg_parser.get_int("fquant");
int gate_only = arg_parser.get_int("gate_only");
int api = arg_parser.get_int("api");
int balance = arg_parser.get_int("balance");
int tp = arg_parser.get_int("tp");
int init = arg_parser.get_int("init");
uint32_t seed = arg_parser.get_uint32("seed");
bool local_expert_masking = false; // TODO...
// w0 (Gate+Up or Gate only, N size)
ck_tile::index_t shared_intermediate_size_0 = intermediate_size * (gate_only ? 1 : 2) / tp;
@@ -230,6 +231,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<YSmoothScaleDataType> sy_host({shared_intermediate_size_1}); // smooth-quant
ck_tile::HostTensor<IndexDataType> topk_ids_host({tokens, topk});          // to be sort
ck_tile::HostTensor<TopkWeightDataType> topk_weight_host({tokens, topk});  // to be sort
ck_tile::HostTensor<IndexDataType> local_expert_mask_host({experts});
int max_num_tokens_padded = topk * tokens + experts * block_m - topk;
ck_tile::HostTensor<IndexDataType> sorted_token_ids_host({max_num_tokens_padded});
@@ -355,6 +357,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::DeviceMem sg_buf(sg_host);
ck_tile::DeviceMem sd_buf(sd_host);
ck_tile::DeviceMem sy_buf(sy_host);
ck_tile::DeviceMem local_expert_mask_buf(local_expert_mask_host);
ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem topk_ids_buf(topk_ids_host);
@@ -378,7 +381,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
block_m,
activation,
gate_only,
-fused_quant};
+fused_quant,
local_expert_masking};
fused_moe_args args{a_buf.GetDeviceBuffer(),
fused_quant != 0 ? sa_buf.GetDeviceBuffer() : nullptr,
@@ -387,6 +391,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
fused_quant != 0 ? sg_buf.GetDeviceBuffer() : nullptr,
fused_quant != 0 ? sd_buf.GetDeviceBuffer() : nullptr,
fused_quant == 1 ? sy_buf.GetDeviceBuffer() : nullptr,
local_expert_masking ? local_expert_mask_buf.GetDeviceBuffer()
: nullptr,
o_buf.GetDeviceBuffer(),
topk_ids_buf.GetDeviceBuffer(),
topk_weight_buf.GetDeviceBuffer(),
@@ -442,12 +448,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>(
topk_ids_host,
topk_weight_host,
local_expert_mask_host,
sorted_token_ids_host,
sorted_weight_host,
sorted_expert_ids_host,
num_sorted_tiles_host.mData[0],
experts,
-block_m);
+block_m,
local_expert_masking);
if(activation == 0)
{
CPU_FUSED_MOE(ck_tile::element_wise::Gelu);
@@ -472,12 +480,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>(
topk_ids_host,
topk_weight_host,
local_expert_mask_host,
sorted_token_ids_host,
sorted_weight_host,
sorted_expert_ids_host,
num_sorted_tiles_host.mData[0],
experts,
-block_m);
+block_m,
local_expert_masking);
// done, preparing GPU buffer
ck_tile::DeviceMem a_buf(a_host);
......
@@ -19,12 +19,9 @@ template <typename ALayout, typename BLayout, typename CLayout>
float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stream_config& s)
{
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadM = false;
constexpr bool kPadN = false;
constexpr bool kPadK = false;
-constexpr bool kTilePermute = false;
-// The rank and permutation will also be generate out by the CodeGen part.
-constexpr ck_tile::index_t kOutputRank = 2;
constexpr int kBlockPerCu = 1;
@@ -41,40 +38,31 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8;
-// Whether doing the CShuffle (transpose before the global memory), depending on the output
-// layout.
-constexpr bool CShuffleEpilogue =
-    std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>;
using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
-using TilePartitioner = ck_tile::GemmTile2DPartitioner<CodegenGemmShape>;
+using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
-using GemmEpilogue = std::conditional_t<
-    CShuffleEpilogue,
-    ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<AccDataType,
-                                                               CDataType,
-                                                               kPadM,
-                                                               kPadN,
-                                                               kTilePermute,
-                                                               kOutputRank,
-                                                               1,
-                                                               0,
-                                                               TilePartitioner::MPerBlock,
-                                                               TilePartitioner::NPerBlock>>,
-    ck_tile::Default2DEpilogue<
-        ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>>;
using CodegenGemmTraits =
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using CodegenPipelineProblem = ck_tile::
GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
-using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy;
-using CodegenGemmPipeline =
-    ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem, CodegenGemmPolicy>;
+using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
+using GemmEpilogue = ck_tile::CShuffleEpilogue<
+    ck_tile::CShuffleEpilogueProblem<AccDataType,
+                                     CDataType,
+                                     CLayout,
+                                     CodegenPipelineProblem::kBlockSize,
+                                     TilePartitioner::MPerBlock,
+                                     TilePartitioner::NPerBlock,
+                                     M_Warp,
+                                     N_Warp,
+                                     M_Warp_Tile,
+                                     N_Warp_Tile,
+                                     K_Warp_Tile,
+                                     CodegenPipelineProblem::TransposeC>>;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
@@ -91,8 +79,11 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args:" std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" << "shape: " << CodegenGemmShape::GetName() << '\n'
<< "problem: " << CodegenPipelineProblem::GetName() << '\n'
<< "pipeline: " << CodegenGemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl; << std::endl;
} }
......
@@ -39,7 +39,7 @@ auto create_args(int argc, char* argv[])
.insert("stride_b", "0", "Tensor B stride")
.insert("stride_c", "0", "Tensor C stride")
.insert("a_layout", "R", "A tensor data layout - Row by default")
-.insert("b_layout", "R", "B tensor data layout - Row by default")
+.insert("b_layout", "C", "B tensor data layout - Row by default")
.insert("c_layout", "R", "C tensor data layout - Row by default")
.insert("batch_stride_a", "32768", "Batch A stride")
.insert("batch_stride_b", "16384", "Batch B stride")
......
@@ -212,7 +212,7 @@ int run_batched_gemm_example_with_layouts(int argc,
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
-std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl;
+std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
}
else if(arg_parser.get_int("v") == 2)
{
@@ -301,11 +301,11 @@ int run_batched_gemm_example(int argc, char* argv[])
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
-if(a_layout == "R" && b_layout == "R")
-{
-    return run_batched_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{});
-}
-else if(a_layout == "R" && b_layout == "C")
+// if(a_layout == "R" && b_layout == "R")
+// {
+//     return run_batched_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{});
+// }
+if(a_layout == "R" && b_layout == "C")
{
return run_batched_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
}
......