Commit f23a2e2a authored by Jakub Piasecki

resolved conflicts

parents f3eb5a18 c0adab48
#!/bin/sh
EXE="$(find . -name tile_example_gemm_universal -type f | head -n 1)"
-VALID=0
VALID=1
-for b_matrix_layout in "R" "C"; do
for b_matrix_layout in "C"; do
-for m in "64" "512" "1024" "2048"; do
for m in "512" "1024" "2048" "4096"; do
for n in "512" "1024" "2048"; do
-for k in "64" "512" "1024" "2048"; do
for k in "512" "1024" "2048"; do
-$EXE -prec=fp16 -b=1 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID
$EXE -prec=fp16 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID
done
done
done
......
#!/bin/sh
EXE="$(find . -name tile_example_gemm_universal -type f | head -n 1)"
VALID=1
for b_matrix_layout in "C"; do
for m in "512" "1024" "2048" "4096"; do
for n in "512" "1024" "2048"; do
for k in "512" "1024" "2048"; do
$EXE -prec=bf16 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID
done
done
done
done
\ No newline at end of file
#!/bin/sh
EXE="$(find . -name tile_example_gemm_universal -type f | head -n 1)"
VALID=1
for b_matrix_layout in "C"; do
for m in "512" "1024" "2048" "4096"; do
for n in "512" "1024" "2048"; do
for k in "512" "1024" "2048"; do
$EXE -prec=bf8 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID
done
done
done
done
\ No newline at end of file
#!/bin/sh
EXE="$(find . -name tile_example_gemm_universal -type f | head -n 1)"
VALID=1
for b_matrix_layout in "C"; do
for m in "512" "1024" "2048" "4096"; do
for n in "512" "1024" "2048"; do
for k in "512" "1024" "2048"; do
$EXE -prec=fp8 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID
done
done
done
done
\ No newline at end of file
@@ -7,22 +7,20 @@ export CK_REPEAT=1
COMMON_ARGS='-v=2 -warmup=0 -repeat=1'
-run_fp16_tests() {
-for batch in 1 2; do
-for m in 128 1024; do
-for n in 128 2048; do
-for k in 32 64; do
-$EXE -b=$batch -m=$m -n=$n -k=$k -stride_a=0 -stride_b=0 -stride_c=0 -e=1e-5 -prec=fp16 $COMMON_ARGS
-if [ $? -eq 0 ]; then
-echo "Success: Test with batch=$batch, m=$m, n=$n, k=$k executed successfully."
-else
-echo "Error: Test with batch=$batch, m=$m, n=$n, k=$k failed to execute properly."
-# Optionally, exit or break if you need to halt further execution
-# exit 1
-fi
-done
-done
-done
-done
run_tests() {
for m in 128 1024; do
for n in 128 2048; do
for k in 64 128; do
$EXE -m=$m -n=$n -k=$k -stride_a=0 -stride_b=0 -stride_c=0 -prec=$1 $COMMON_ARGS
if [ $? -eq 0 ]; then
echo "Success: Test with m=$m, n=$n, k=$k executed successfully."
else
echo "Error: Test with m=$m, n=$n, k=$k failed to execute properly."
# Optionally, exit or break if you need to halt further execution
# exit 1
fi
done
done
done
@@ -30,6 +28,9 @@ run_fp16_tests() {
set -x
-run_fp16_tests
run_tests "fp16"
run_tests "bf16"
run_tests "fp8"
run_tests "bf8"
set +x
@@ -7,22 +7,20 @@ export CK_REPEAT=1
COMMON_ARGS='-v=2 -warmup=0 -repeat=1'
-run_fp16_tests() {
-for batch in 1 2; do
-for m in 128 1024; do
-for n in 128 2048; do
-for k in 32 64; do
-$EXE -b=$batch -m=$m -n=$n -k=$k -stride_a=0 -stride_b=0 -stride_c=0 -e=1e-5 -prec=fp16 $COMMON_ARGS
-if [ $? -eq 0 ]; then
-echo "Success: Test with batch=$batch, m=$m, n=$n, k=$k executed successfully."
-else
-echo "Error: Test with batch=$batch, m=$m, n=$n, k=$k failed to execute properly."
-# Optionally, exit or break if you need to halt further execution
-# exit 1
-fi
-done
-done
-done
-done
run_tests() {
for m in 512 1024; do
for n in 512 2048; do
for k in 512 1024; do
$EXE -m=$m -n=$n -k=$k -stride_a=0 -stride_b=0 -stride_c=0 -prec=$1 $COMMON_ARGS
if [ $? -eq 0 ]; then
echo "Success: Test with batch=$batch, m=$m, n=$n, k=$k executed successfully."
else
echo "Error: Test with batch=$batch, m=$m, n=$n, k=$k failed to execute properly."
# Optionally, exit or break if you need to halt further execution
# exit 1
fi
done
done
done
@@ -30,6 +28,9 @@ run_fp16_tests() {
set -x
-run_fp16_tests
run_tests "fp16"
run_tests "bf16"
run_tests "fp8"
run_tests "bf8"
set +x
@@ -23,24 +23,101 @@ int run_gemm_example(int argc, char* argv[])
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
if(a_layout == "R" && b_layout == "R")
{
-return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{});
if(data_type == "fp16")
{
return run_gemm_example_with_layouts<ck_tile::half_t>(argc, argv, Row{}, Row{}, Row{});
}
else if(data_type == "bf16")
{
return run_gemm_example_with_layouts<ck_tile::bf16_t>(argc, argv, Row{}, Row{}, Row{});
}
else if(data_type == "fp8")
{
return run_gemm_example_with_layouts<ck_tile::fp8_t>(argc, argv, Row{}, Row{}, Row{});
}
else if(data_type == "bf8")
{
return run_gemm_example_with_layouts<ck_tile::bf8_t>(argc, argv, Row{}, Row{}, Row{});
}
else
{
throw std::runtime_error("Unsupported data_type!");
}
}
else if(a_layout == "R" && b_layout == "C")
{
-return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
if(data_type == "fp16")
{
return run_gemm_example_with_layouts<ck_tile::half_t>(argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "bf16")
{
return run_gemm_example_with_layouts<ck_tile::bf16_t>(argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "fp8")
{
return run_gemm_example_with_layouts<ck_tile::fp8_t>(argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "bf8")
{
return run_gemm_example_with_layouts<ck_tile::bf8_t>(argc, argv, Row{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported data_type!");
}
}
else if(a_layout == "C" && b_layout == "C")
{
-return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{});
if(data_type == "fp16")
{
return run_gemm_example_with_layouts<ck_tile::half_t>(argc, argv, Col{}, Col{}, Row{});
}
else if(data_type == "bf16")
{
return run_gemm_example_with_layouts<ck_tile::bf16_t>(argc, argv, Col{}, Col{}, Row{});
}
else if(data_type == "fp8")
{
return run_gemm_example_with_layouts<ck_tile::fp8_t>(argc, argv, Col{}, Col{}, Row{});
}
else if(data_type == "bf8")
{
return run_gemm_example_with_layouts<ck_tile::bf8_t>(argc, argv, Col{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported data_type!");
}
}
else if(a_layout == "C" && b_layout == "R")
{
-return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{});
if(data_type == "fp16")
{
return run_gemm_example_with_layouts<ck_tile::half_t>(argc, argv, Col{}, Row{}, Row{});
}
else if(data_type == "bf16")
{
return run_gemm_example_with_layouts<ck_tile::bf16_t>(argc, argv, Col{}, Row{}, Row{});
}
else if(data_type == "fp8")
{
return run_gemm_example_with_layouts<ck_tile::fp8_t>(argc, argv, Col{}, Row{}, Row{});
}
else if(data_type == "bf8")
{
return run_gemm_example_with_layouts<ck_tile::bf8_t>(argc, argv, Col{}, Row{}, Row{});
}
else
{
throw std::runtime_error("Unsupported data_type!");
}
}
else
{
......
@@ -26,6 +26,10 @@ auto create_args(int argc, char* argv[])
.insert("k", "4", "topk")
.insert("unit", "32", "unit_size")
.insert("moe_buf_size", "0", "moe_buf_size")
.insert("local_eid",
"-1",
"a list of experts enabled as local expert. e.g. \"0,1,4,5\"\n"
"please make sure eid is in ascending order!")
.insert("seed", "-1", "seed to be used, -1 means random every time") .insert("seed", "-1", "seed to be used, -1 means random every time")
.insert("kname", "0", "when set to 1 it will print kernel name") .insert("kname", "0", "when set to 1 it will print kernel name")
.insert("warmup", "5", "number of iterations before benchmark the kernel") .insert("warmup", "5", "number of iterations before benchmark the kernel")
...@@ -74,6 +78,7 @@ bool test_moe_sorting(ck_tile::ArgParser args) ...@@ -74,6 +78,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
int kname = args.get_int("kname"); int kname = args.get_int("kname");
int warmup = args.get_int("warmup"); int warmup = args.get_int("warmup");
int repeat = args.get_int("repeat"); int repeat = args.get_int("repeat");
int max_output_ids = int max_output_ids =
ck_tile::integer_least_multiple(topk * tokens + num_experts * unit_size - topk, unit_size); ck_tile::integer_least_multiple(topk * tokens + num_experts * unit_size - topk, unit_size);
...@@ -90,6 +95,30 @@ bool test_moe_sorting(ck_tile::ArgParser args) ...@@ -90,6 +95,30 @@ bool test_moe_sorting(ck_tile::ArgParser args)
return false; return false;
} }
bool local_expert_masking = args.get_str("local_eid") != "-1";
auto local_expert_masking_host = [&]() {
if(local_expert_masking)
{
auto local_eid = args.get_int_vec("local_eid");
// std::vector<int> v_ {num_experts, 0};
ck_tile::HostTensor<IndexType> v_{{num_experts}};
v_.SetZero();
for(auto eid : local_eid)
{
if(eid >= num_experts)
{
throw std::runtime_error(
"local_eid larger than number of expert, please check");
}
v_.mData[eid] = 1;
}
return v_;
}
else
// return std::vector<int>{};
return ck_tile::HostTensor<IndexType>{{1}};
}();
// tokens already considered batch size
ck_tile::HostTensor<IndexType> topk_ids_host({tokens, topk}, {topk, 1});
ck_tile::HostTensor<WeightType> weights_host({tokens, topk}, {topk, 1});
@@ -111,6 +140,8 @@ bool test_moe_sorting(ck_tile::ArgParser args)
sorted_expert_ids_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem sorted_id_cnt_dev(sorted_id_cnt_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem moe_buf_dev(moe_buf_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem local_expert_masking_dev(
local_expert_masking_host.get_element_space_size_in_bytes());
topk_ids_dev.ToDevice(topk_ids_host.data());
weights_dev.ToDevice(weights_host.data());
@@ -118,11 +149,15 @@
{
moe_buf_dev.ToDevice(moe_buf_host.data());
}
if(local_expert_masking)
local_expert_masking_dev.ToDevice(local_expert_masking_host.data());
-moe_sorting_trait trait{index_prec, weight_prec};
moe_sorting_trait trait{index_prec, weight_prec, local_expert_masking};
moe_sorting_args karg{topk_ids_dev.GetDeviceBuffer(),
weights_dev.GetDeviceBuffer(),
local_expert_masking ? local_expert_masking_dev.GetDeviceBuffer()
: nullptr,
sorted_ids_dev.GetDeviceBuffer(),
sorted_weights_dev.GetDeviceBuffer(),
sorted_expert_ids_dev.GetDeviceBuffer(),
@@ -140,15 +175,22 @@ bool test_moe_sorting(ck_tile::ArgParser args)
warmup,
repeat};
auto ms = moe_sorting(trait, karg, sc);
-printf("[%s|%s]tokens:%d, num_experts:%d, topk:%d, ms:%f , ",
printf("[%s|%s]tokens:%d, num_experts:%d, topk:%d, ",
index_prec.c_str(),
weight_prec.c_str(),
tokens,
num_experts,
-topk,
-ms);
topk);
if(local_expert_masking)
{
printf("local_eid:%s, ", args.get_str("local_eid").c_str());
}
if(ms < 0)
printf("not supported\n");
else
printf("ms:%f, ", ms);
fflush(stdout);
if(ms < 0)
{
@@ -174,12 +216,14 @@ bool test_moe_sorting(ck_tile::ArgParser args)
int32_t ref_total_tokens_post_pad = 0;
ck_tile::reference_moe_sorting<WeightType, IndexType>(topk_ids_host,
weights_host,
local_expert_masking_host,
sorted_ids_ref,
sorted_weights_ref,
sorted_expert_ids_ref,
ref_total_tokens_post_pad,
num_experts,
-unit_size);
unit_size,
local_expert_masking);
rtn &= ck_tile::check_err( rtn &= ck_tile::check_err(
sorted_ids_host, sorted_ids_ref, std::string("OUT Error: Incorrect ids!"), 1e-6, 1e-6); sorted_ids_host, sorted_ids_ref, std::string("OUT Error: Incorrect ids!"), 1e-6, 1e-6);
rtn &= ck_tile::check_err(sorted_weights_host, rtn &= ck_tile::check_err(sorted_weights_host,
...@@ -199,9 +243,16 @@ bool test_moe_sorting(ck_tile::ArgParser args) ...@@ -199,9 +243,16 @@ bool test_moe_sorting(ck_tile::ArgParser args)
moe_buf_host, moe_buf_ref, std::string("OUT Error: Incorrect zero buf!"), 0, 0); moe_buf_host, moe_buf_ref, std::string("OUT Error: Incorrect zero buf!"), 0, 0);
} }
rtn &= ref_total_tokens_post_pad == sorted_id_cnt_host.mData[0]; rtn &= ref_total_tokens_post_pad == sorted_id_cnt_host.mData[0];
printf("total_tokens_post_pad:%d(%d), ",
ref_total_tokens_post_pad,
sorted_id_cnt_host.mData[0]);
}
-printf("valid:%s\n", rtn ? "y" : "n");
printf("valid:%s", rtn ? "y" : "n");
fflush(stdout);
if(!rtn)
printf(", (%d)", seed);
printf("\n");
fflush(stdout);
return rtn;
}
......
@@ -3,6 +3,12 @@
#include "moe_sorting_api.hpp"
#ifndef MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_USE_EX_KERNEL 1
#endif
#if !MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr ck_tile::index_t expert_tile = expert_tile_; \
@@ -17,6 +23,67 @@
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time;
#else
#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_, local_expert_masking_) \
constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \
constexpr bool sub_token_onshot = sub_token_onshot_; \
constexpr bool local_expert_masking = local_expert_masking_; \
using ms_problem = ck_tile::MoeSortingProblemEx<index_t, \
ms_weight_type, \
sub_token_tile, \
sub_token_onshot, \
local_expert_masking>; \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
const auto lds_bytes = kernel::GetSmemSize(a); \
float ave_time = ck_tile::launch_kernel( \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time;
#define MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_) \
if(row_ % 8 == 0) \
{ \
MOE_SORTING_DISPATCH_(8, sub_token_onshot_, local_expert_masking_); \
} \
else if(row_ % 4 == 0) \
{ \
MOE_SORTING_DISPATCH_(4, sub_token_onshot_, local_expert_masking_); \
} \
else if(row_ % 2 == 0) \
{ \
MOE_SORTING_DISPATCH_(2, sub_token_onshot_, local_expert_masking_); \
} \
else \
{ \
MOE_SORTING_DISPATCH_(1, sub_token_onshot_, local_expert_masking_); \
}
#define MOE_SORTING_DISPATCH_SUBTO_(row_, local_expert_masking_) \
if(is_sub_token_onshot) \
{ \
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, true, local_expert_masking_) \
} \
else \
{ \
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, false, local_expert_masking_) \
}
#define MOE_SORTING_DISPATCH_EMASK_(row_) \
if(is_local_expert_masking) \
{ \
MOE_SORTING_DISPATCH_SUBTO_(row_, true) \
} \
else \
{ \
MOE_SORTING_DISPATCH_SUBTO_(row_, false) \
}
#endif
#if !MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_DISPATCH(unroll_num_) \
if(a.num_experts <= 8) \
{ \
@@ -38,11 +105,13 @@
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \
}
#endif
float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s)
{
if(t.weight_type == "fp32" && t.index_type == "int32")
{
#if !MOE_SORTING_USE_EX_KERNEL
if(a.num_experts > 127)
{
printf("lds size exceed, only support experts <127 \n");
@@ -83,6 +152,19 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
MOE_SORTING_DISPATCH(4);
}
}
#else
using index_t = ck_tile::index_t;
using ms_weight_type = float;
auto [r_, c_] = ck_tile::moe_sorting_get_smem_row_col(a.tokens, a.num_experts);
auto sub_token_ = r_ - 2;
r_ = (r_ - 2) / 8;
bool is_sub_token_onshot = a.tokens <= sub_token_;
bool is_local_expert_masking = t.local_expert_masking;
(void)c_;
MOE_SORTING_DISPATCH_EMASK_(r_);
// MOE_SORTING_DISPATCH_ETILE(0, 0);
#endif
}
return -1;
}
@@ -10,7 +10,8 @@
struct moe_sorting_trait
{
std::string index_type;
std::string weight_type;   // currently always float
bool local_expert_masking; // if mask experts as local expert
};
struct moe_sorting_args : public ck_tile::MoeSortingHostArgs
......
@@ -17,4 +17,12 @@ $EXE -t=71 -e=11 -k=11
$EXE -t=1 -e=1 -k=1
$EXE -t=99 -e=2 -k=1
$EXE -t=333 -e=99 -k=13
$EXE -t=11 -e=256 -k=5
$EXE -t=64 -e=455 -k=8
$EXE -t=777 -e=802 -k=99
$EXE -t=4097 -e=906 -k=51
$EXE -t=128 -e=32 -k=5 -moe_buf_size=262144
$EXE -t=13 -e=64 -k=3 -local_eid=4,5,6,7,8,9,10,11
$EXE -t=99 -e=33 -k=9 -local_eid=6,10,11,15,19
$EXE -t=80 -e=99 -k=10 -local_eid=0,8,12,33
$EXE -t=11 -e=256 -k=5 -local_eid=99,110,129
@@ -42,7 +42,7 @@ summary of the key design of this fused-moe operator:
// (only for reference)    exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
-// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated)
// * this could be larger than actual, since actual tokens are on GPU
//
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
......
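A quick sanity check of the updated bound, using hypothetical sizes rather than the reference figures above: with input_tokens = 7, topk = 3, num_experts = 6 and unit size M_a = 4, max_num_tokens_padded = 3 * 7 + 6 * 4 - 3 = 42, i.e. an upper bound on the length of sorted_token_ids_ptr.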
@@ -3,6 +3,12 @@
#include "fused_moesorting.hpp"
#ifndef MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_USE_EX_KERNEL 1
#endif
#if !MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr ck_tile::index_t expert_tile = expert_tile_; \
@@ -17,6 +23,24 @@
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time;
#else
#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_) \
constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \
constexpr bool sub_token_onshot = sub_token_onshot_; \
using ms_problem = \
ck_tile::MoeSortingProblemEx<index_t, ms_weight_type, sub_token_tile, sub_token_onshot>; \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
const auto lds_bytes = kernel::GetSmemSize(a); \
float ave_time = ck_tile::launch_kernel( \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time;
#endif
#if !MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_DISPATCH(unroll_num_) \
if(a.num_experts <= 8) \
{ \
@@ -38,11 +62,13 @@
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \
}
#endif
float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s)
{
if(t.weight_type == "fp32" && t.index_type == "int32")
{
#if !MOE_SORTING_USE_EX_KERNEL
if(a.num_experts > 127)
{
printf("lds size exceed, only support experts <127 \n");
@@ -83,6 +109,54 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
MOE_SORTING_DISPATCH(4);
}
}
#else
using index_t = ck_tile::index_t;
using ms_weight_type = float;
auto [r_, c_] = ck_tile::moe_sorting_get_smem_row_col(a.tokens, a.num_experts);
auto sub_token_ = r_ - 2;
r_ = (r_ - 2) / 8;
bool is_sub_token_onshot = a.tokens <= sub_token_;
(void)c_;
if(is_sub_token_onshot)
{
if(r_ % 8 == 0)
{
MOE_SORTING_DISPATCH_(8, true);
}
else if(r_ % 4 == 0)
{
MOE_SORTING_DISPATCH_(4, true);
}
else if(r_ % 2 == 0)
{
MOE_SORTING_DISPATCH_(2, true);
}
else
{
MOE_SORTING_DISPATCH_(1, true);
}
}
else
{
if(r_ % 8 == 0)
{
MOE_SORTING_DISPATCH_(8, false);
}
else if(r_ % 4 == 0)
{
MOE_SORTING_DISPATCH_(4, false);
}
else if(r_ % 2 == 0)
{
MOE_SORTING_DISPATCH_(2, false);
}
else
{
MOE_SORTING_DISPATCH_(1, false);
}
}
// MOE_SORTING_DISPATCH_ETILE(0, 0);
#endif
}
return -1;
}
@@ -19,12 +19,9 @@ template <typename ALayout, typename BLayout, typename CLayout>
float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stream_config& s)
{
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadM = false;
constexpr bool kPadN = false;
constexpr bool kPadK = false;
-constexpr bool kTilePermute = false;
-// The rank and permutation will also be generate out by the CodeGen part.
-constexpr ck_tile::index_t kOutputRank = 2;
constexpr int kBlockPerCu = 1;
@@ -41,38 +38,31 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8;
-// Whether doing the CShuffle (transpose before the global memory), depending on the output
-// layout.
-constexpr bool CShuffleEpilogue =
-std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>;
using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
-using TilePartitioner = ck_tile::GemmTile2DPartitioner<CodegenGemmShape>;
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
-using GemmEpilogue = std::conditional_t<
-CShuffleEpilogue,
-ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<AccDataType,
-CDataType,
-kPadM,
-kPadN,
-kTilePermute,
-kOutputRank,
-1,
-0,
-TilePartitioner::MPerBlock,
-TilePartitioner::NPerBlock>>,
-ck_tile::Default2DEpilogue<
-ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>>;
using CodegenGemmTraits =
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using CodegenPipelineProblem = ck_tile::
GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<AccDataType,
CDataType,
CLayout,
CodegenPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
CodegenPipelineProblem::TransposeC>>;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
@@ -89,8 +79,11 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args:" std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" << "shape: " << CodegenGemmShape::GetName() << '\n'
<< "problem: " << CodegenPipelineProblem::GetName() << '\n'
<< "pipeline: " << CodegenGemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl; << std::endl;
} }
......
@@ -212,7 +212,7 @@ int run_batched_gemm_example_with_layouts(int argc,
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
-std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl;
std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
}
else if(arg_parser.get_int("v") == 2)
{
......
// SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
@@ -20,12 +20,9 @@ namespace {
struct GroupedGemmKernelParam
{
static const bool kPadM = false;
static const bool kPadN = false;
static const bool kPadK = false;
-static const bool kTilePermute = false;
-static const ck_tile::index_t kOutputRank = 2;
static const int kBlockPerCu = 1;
static const ck_tile::index_t M_Tile = 128;
@@ -54,24 +51,6 @@ using CodegenGemmShape =
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
-template <typename CLayout>
-using GemmEpilogue = std::conditional_t<
-std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>,
-ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<AccDataType,
-CDataType,
-GroupedGemmKernelParam::kPadM,
-GroupedGemmKernelParam::kPadN,
-GroupedGemmKernelParam::kTilePermute,
-GroupedGemmKernelParam::kOutputRank,
-1,
-0,
-TilePartitioner::MPerBlock,
-TilePartitioner::NPerBlock>>,
-ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<AccDataType,
-CDataType,
-GroupedGemmKernelParam::kPadM,
-GroupedGemmKernelParam::kPadN>>>;
template <typename ALayout, typename BLayout, typename CLayout>
using CodegenGemmTraits = ck_tile::TileGemmTraits<GroupedGemmKernelParam::kPadM,
GroupedGemmKernelParam::kPadN,
@@ -92,10 +71,25 @@ template <typename ALayout, typename BLayout, typename CLayout>
using CodegenGemmPipeline =
ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem<ALayout, BLayout, CLayout>>;
template <typename ALayout, typename BLayout, typename CLayout>
using GemmEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
AccDataType,
CDataType,
CLayout,
CodegenPipelineProblem<ALayout, BLayout, CLayout>::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GroupedGemmKernelParam::M_Warp,
GroupedGemmKernelParam::N_Warp,
GroupedGemmKernelParam::M_Warp_Tile,
GroupedGemmKernelParam::N_Warp_Tile,
GroupedGemmKernelParam::K_Warp_Tile,
CodegenPipelineProblem<ALayout, BLayout, CLayout>::TransposeC>>;
template <typename ALayout, typename BLayout, typename CLayout>
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner,
CodegenGemmPipeline<ALayout, BLayout, CLayout>,
-GemmEpilogue<CLayout>>;
GemmEpilogue<ALayout, BLayout, CLayout>>;
}; // namespace
std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
@@ -124,7 +118,7 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
if(s.log_level_ > 0)
{
-std::cout << "Launching kernel with args:"
std::cout << "Launching kernel: " << GroupedGemmKernel::GetName() << " with args:"
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
......
@@ -202,7 +202,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
}
-std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl;
std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
}
return pass;
......
set(TARGET_NAME tile_example_batched_transpose)
add_executable(${TARGET_NAME} EXCLUDE_FROM_ALL batched_transpose_example.cpp batched_transpose_api.cpp)
target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/)
# NOTE: we turn off undefined-func-template to let the source compile without explicitly declaring function specializations
list(APPEND EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
# list(APPEND EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
target_compile_options(tile_example_batched_transpose PRIVATE ${EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS})
# Batched Transpose
This folder contains an example of batched transpose using the ck_tile tile-programming implementation. Currently it supports batched transpose from NCHW to NHWC or from NHWC to NCHW, so starting from NCHW you could transpose to either NHWC or NWCH (two transposes). For now the transpose reads a single data point at a time; we plan to move to a vectorized transpose soon.
## build
```
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
# Make the transpose executable
make tile_example_batched_transpose -j
```
This will result in an executable `build/bin/tile_example_batched_transpose`
## example
```
args:
-N input batch size (default:2)
-C input channel size. (default:16)
-H input height size. (default:1)
-W input width size. (default:16)
-v whether do CPU validation or not (default: 1)
-layout_in input tensor data layout - NCHW by default
-layout_out output tensor data layout - NHWC by default
-seed seed to be used, -1 means random every time (default:-1)
-k_name       set to 1 will print kernel name (default:0)
```
\ No newline at end of file
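Below is a hedged sample invocation; the flag spelling follows the argument list above, and the tensor sizes and layout strings are illustrative assumptions, not required values:

```
# transpose a 2x16x1x16 NCHW tensor to NHWC and run the CPU validation
./build/bin/tile_example_batched_transpose -N=2 -C=16 -H=1 -W=16 -layout_in=NCHW -layout_out=NHWC -v=1
```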
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "batched_transpose_example.hpp"
#include <iostream>
template <typename ts_type,
ck_tile::index_t block_x,
ck_tile::index_t block_y,
ck_tile::index_t warp_x,
ck_tile::index_t warp_y,
ck_tile::index_t thread_x,
ck_tile::index_t thread_y>
float batched_transpose_dispatch(batched_transpose_kargs& a, ck_tile::stream_config& s)
{
uint32_t dim_block_h = (a.height + block_y - 1) / block_y;
uint32_t dim_block_w = (a.width + block_x - 1) / block_x;
uint32_t dim_stride = a.height * a.width;
a.dim_stride = dim_stride;
a.dim_block_h = dim_block_h;
a.dim_block_w = dim_block_w;
using block_tile = ck_tile::sequence<block_x, block_y>;
using warp_tile = ck_tile::sequence<warp_x, warp_y>;
using thread_tile = ck_tile::sequence<thread_x, thread_y>;
using ts_problem =
ck_tile::BatchedTransposeProblem<ts_type, block_tile, warp_tile, thread_tile>;
using ts_pipeline = ck_tile::BatchedTransposePipeline<ts_problem>;
using kernel = ck_tile::BatchedTransposeKernel<ts_pipeline>;
auto kargs = kernel::MakeKargs(a);
const dim3 grids = kernel::GridSize(a);
constexpr dim3 blocks = kernel::BlockSize();
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, 1>(kernel{}, grids, blocks, 0, kargs));
return ave_time;
}
// Param Comb: type_size, block_x & y, warp_x & y, thread_x & y
#define FOREACH_TRANSPOSE_PARAM(F) \
F(fp16, ck_tile::fp16_t, 16, 16, 8, 8, 1, 1) \
F(bf16, ck_tile::bf16_t, 16, 16, 8, 8, 1, 1) \
F(fp32, ck_tile::fp32_t, 16, 16, 8, 8, 1, 1) \
F(int8, ck_tile::int8_t, 16, 16, 8, 8, 1, 1)
// Macro that defines one static function per line
#define GEN_TRANSPOSE_FN(SHORT_NAME, REAL_TYPE, BX, BY, WX, WY, TX, TY) \
static float transpose_fn_##SHORT_NAME##_##BX##_##BY##_##WX##_##WY##_##TX##_##TY( \
batched_transpose_kargs& a, ck_tile::stream_config& s) \
{ \
return batched_transpose_dispatch<REAL_TYPE, BX, BY, WX, WY, TX, TY>(a, s); \
}
FOREACH_TRANSPOSE_PARAM(GEN_TRANSPOSE_FN)
float batched_transpose(batched_transpose_trait t,
batched_transpose_kargs a,
ck_tile::stream_config s)
{
if(t.type == "fp16")
{
return transpose_fn_fp16_16_16_8_8_1_1(a, s);
}
else if(t.type == "bf16")
{
return transpose_fn_bf16_16_16_8_8_1_1(a, s);
}
else if(t.type == "fp32")
{
return transpose_fn_fp32_16_16_8_8_1_1(a, s);
}
else if(t.type == "int8")
{
return transpose_fn_int8_16_16_8_8_1_1(a, s);
}
return -1;
}