"vscode:/vscode.git/clone" did not exist on "4dab86fecc868bd552bae75855748b7aee7c4d9a"
Commit 08e44540 authored by coderfeli's avatar coderfeli
Browse files

debugs

parent a75f162b
......@@ -25,7 +25,8 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
else if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0)
{
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 1024, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0>;
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0>;
// using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 1024, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0>;
r = fused_moegemm_<t_>(s, a);
}
else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
......@@ -37,7 +38,8 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0)
{
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 1024, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0>;
// using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 1024, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0>;
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0>;
r = fused_moegemm_<t_>(s, a);
}
// clang-format on
......
......@@ -33,7 +33,7 @@ struct fmoe_ // traits, ugly name, only used for internal
using YSmoothScaleDataType = ck_tile::remove_cvref_t<typename TypeConfig::YSmoothScaleDataType>;
using TopkWeightDataType = ck_tile::remove_cvref_t<typename TypeConfig::TopkWeightDataType>;
using IndexDataType = ck_tile::remove_cvref_t<typename TypeConfig::IndexDataType>;
// S<32, 1024, 128, 128>, S<1, 4, 1>, S<16, 16, 32>
// S<32, 1024|512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>
static constexpr ck_tile::index_t BT_ = BlockTIle_::at(ck_tile::number<0>{}); // block token
static constexpr ck_tile::index_t BI_ =
BlockTIle_::at(ck_tile::number<1>{}); // block intermediate
......@@ -44,7 +44,7 @@ struct fmoe_ // traits, ugly name, only used for internal
using WarpPerBlock_0 = ck_tile::remove_cvref_t<WarpPerBlock_>; // S<1, 4, 1>
using WarpTile_0 = ck_tile::remove_cvref_t<WarpTile_>; // S<16, 16, 32>
using BlockTile_1 = ck_tile::sequence<BT_, BD_, BI_ / (GateOnly_ ? 1 : 2)>; // 32, 128, 512
using BlockTile_1 = ck_tile::sequence<BT_, BD_, BI_ / (GateOnly_ ? 1 : 2)>; // 32, 128, 512|256
using WarpPerBlock_1 = ck_tile::remove_cvref_t<WarpPerBlock_>; /// S<1, 4, 1>
using WarpTile_1 = ck_tile::remove_cvref_t<WarpTile_>; // S<16, 16, 32>
......
......@@ -10,8 +10,13 @@
template float fused_moegemm_<
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
template float fused_moegemm_<
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 1024, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
template float fused_moegemm_<
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
// clang-format on
......@@ -13,5 +13,8 @@ template float fused_moegemm_<
template float fused_moegemm_<
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 1024, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
template float fused_moegemm_<
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
// clang-format on
......@@ -59,6 +59,42 @@ auto shuffle_moe_weight(const ck_tile::HostTensor<T>& t, std::string mfma_dtype,
}
return t;
}
template <typename T>
auto shuffle_moe_weight_gateup(const ck_tile::HostTensor<T>& t, std::string mfma_dtype, int mfma_type = 0)
{
assert(t.get_lengths().size() == 3);
int b_ = t.get_lengths()[0];
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[2];
if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 0)
{
ck_tile::HostTensor<T> t_view({b_, n_ / 32, 32, k_ / 16, 2, 8});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5});
}
else if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 1)
{
ck_tile::HostTensor<T> t_view({b_, 2 , n_ / 512, 16 , 16, k_ / 32, 4, 8});
// ck_tile::HostTensor<T> t_view({b_, n_ / 16, 16, k_ / 32, 4, 8});
std::copy(t.begin(), t.end(), t_view.begin());
// return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5});
return ck_tile::reference_permute(t_view, {0, 2, 1, 3, 5, 6, 4, 7});
}
else if((mfma_dtype == "int8" || mfma_dtype == "fp8") && mfma_type == 0)
{
ck_tile::HostTensor<T> t_view({b_, n_ / 32, 32, k_ / 32, 2, 16});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5});
}
else if((mfma_dtype == "int8" || mfma_dtype == "fp8") && mfma_type == 1)
{
ck_tile::HostTensor<T> t_view({b_, n_ / 16, 16, k_ / 64, 4, 16});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5});
}
return t;
}
template <typename IndexType>
void topid_unique_gen(
......
......@@ -427,10 +427,10 @@ struct FlatmmSn_32x128x512_1x4x1_16x16x32_FP16_itl : public FlatmmSn_32x128x512_
[v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))),
[v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))),
[v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))),
[v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))),
[v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))),
[v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))),
[v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))),
// [v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))),
// [v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))),
// [v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))),
// [v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))),
[s_tile_os_o]"s"(tile_stride_o_bytes),
[s_tile_os_b]"s"(tile_stride_b_bytes),
......
......@@ -63,7 +63,7 @@ struct FusedMoeGemmShape
// S<32, 512, 128>, S<1, 4, 1>, S<16, 16, 32>
static constexpr index_t Block_M0 = BlockTile_0::at(number<0>{}); //32
static constexpr index_t Block_N0 = BlockTile_0::at(number<1>{}); //512
static constexpr index_t Block_N0 = BlockTile_0::at(number<1>{}); //256
static constexpr index_t Block_K0 = BlockTile_0::at(number<2>{}); // 128
static constexpr index_t WarpPerBlock_M0 = WarpPerBlock_0::at(number<0>{}); // 1
static constexpr index_t WarpPerBlock_N0 = WarpPerBlock_0::at(number<1>{}); // 4
......@@ -73,18 +73,18 @@ struct FusedMoeGemmShape
static constexpr index_t Warp_K0 = WarpTile_0::at(number<2>{}); // 32
static constexpr index_t ThreadPerBlock_M0 = Warp_M0 * WarpPerBlock_M0;
static constexpr index_t ThreadPerBlock_N0 = Warp_N0 * WarpPerBlock_N0;
static constexpr index_t ThreadPerBlock_N0 = Warp_N0 * WarpPerBlock_N0; // 64
static constexpr index_t ThreadPerBlock_K0 = Warp_K0 * WarpPerBlock_K0;
static_assert(Block_M0 % ThreadPerBlock_M0 == 0);
static_assert(Block_N0 % ThreadPerBlock_N0 == 0);
static_assert(Block_K0 % ThreadPerBlock_K0 == 0);
static constexpr index_t Repeat_M0 = Block_M0 / ThreadPerBlock_M0; // 2
static constexpr index_t Repeat_N0 = Block_N0 / ThreadPerBlock_N0; // 8
static constexpr index_t Repeat_N0 = Block_N0 / ThreadPerBlock_N0; // 4
static constexpr index_t Repeat_K0 = Block_K0 / ThreadPerBlock_K0; // 4
static constexpr index_t Block_M1 = BlockTile_1::at(number<0>{}); //32
static constexpr index_t Block_N1 = BlockTile_1::at(number<1>{}); //128
static constexpr index_t Block_K1 = BlockTile_1::at(number<2>{}); //512
static constexpr index_t Block_K1 = BlockTile_1::at(number<2>{}); //256
static constexpr index_t WarpPerBlock_M1 = WarpPerBlock_1::at(number<0>{}); // 1
static constexpr index_t WarpPerBlock_N1 = WarpPerBlock_1::at(number<1>{}); // 4
static constexpr index_t WarpPerBlock_K1 = WarpPerBlock_1::at(number<2>{}); // 1
......@@ -100,7 +100,7 @@ struct FusedMoeGemmShape
static_assert(Block_K1 % ThreadPerBlock_K1 == 0);
static constexpr index_t Repeat_M1 = Block_M1 / ThreadPerBlock_M1; // 2
static constexpr index_t Repeat_N1 = Block_N1 / ThreadPerBlock_N1; // 2
static constexpr index_t Repeat_K1 = Block_K1 / ThreadPerBlock_K1; // 16
static constexpr index_t Repeat_K1 = Block_K1 / ThreadPerBlock_K1; // 8
static constexpr index_t BlockSize = warpSize * NumWarps;
......@@ -118,7 +118,7 @@ struct FusedMoeGemmShape
static constexpr index_t Block_Kr0 = Block_K0 / Warp_K0;
static constexpr index_t Block_W1 = Warp_N1 * Warp_K1; // 512
static constexpr index_t Block_Nr1 = Block_N1 / Warp_N1; // 8
static constexpr index_t Block_Kr1 = Block_K1 / Warp_K1; // 16
static constexpr index_t Block_Kr1 = Block_K1 / Warp_K1; // 8
static_assert(Block_W0 == Block_W1);
// static_assert(Block_Nr0 == Block_Kr1);
......
......@@ -243,7 +243,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_G()
{
constexpr auto PermuteEnum = Problem::Traits::PermuteEnum;
// constexpr index_t hidden_radio_0 = Problem::Traits::IsGateOnly ? 1 : 2;
constexpr index_t hidden_radio_0 = Problem::Traits::IsGateOnly ? 1 : 2;
using S_ = typename Problem::BlockShape;
if constexpr(PermuteEnum == FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten)
{
......@@ -251,7 +251,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy
// number<S_::Repeat_N0>{}.eee();
return MakeGlobalTileDistribution_Nr_Kr_W<S_::WarpPerBlock_N0,
S_::WarpPerBlock_K0,
S_::Repeat_N0, /// hidden_radio_0,
S_::Repeat_N0 * hidden_radio_0,
S_::Repeat_K0,
get_warp_size(),
GetAlignment_G<Problem>()>();
......@@ -803,7 +803,21 @@ struct FusedMoeGemmPipelineFlatmmPolicy
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32)
{
return Flatmm_32x512x128_1x4x1_16x16x32_FP16{};
}
else if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::GDataType, ck_tile::bf16_t> &&
S_::Block_M0 == 32 && S_::Block_N0 == 256 && S_::Block_K0 == 128 &&
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32)
{
return Flatmm_32x256x128_1x4x1_16x16x32_BF16{};
}
else if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::fp16_t> &&
std::is_same_v<typename Problem::GDataType, ck_tile::fp16_t> &&
S_::Block_M0 == 32 && S_::Block_N0 == 256 && S_::Block_K0 == 128 &&
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32)
{
return Flatmm_32x256x128_1x4x1_16x16x32_FP16{};
}
}
template <typename Problem>
......@@ -851,6 +865,26 @@ struct FusedMoeGemmPipelineFlatmmPolicy
// return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16{};
return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16_itl{};
}
else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::TopkWeightDataType, float> &&
S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 256 &&
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32 &&
T_::PipeInterleave == true)
{
// return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16{};
return FlatmmSn_32x128x256_1x4x1_16x16x32_BF16_itl{};
}
else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::fp16_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::fp16_t> &&
std::is_same_v<typename Problem::TopkWeightDataType, float> &&
S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 256 &&
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32 &&
T_::PipeInterleave == true)
{
// return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16{};
return FlatmmSn_32x128x256_1x4x1_16x16x32_FP16_itl{};
}
}
};
} // namespace ck_tile
......@@ -73,7 +73,7 @@ struct FusedMoeGemmPipeline_FlatmmUk
constexpr index_t smem_0 = Policy::template GetUK_0<Problem>().GetSmemSize();
constexpr index_t smem_1 = Policy::template GetUK_1<Problem>().GetSmemSize();
constexpr index_t smem_bridge =
BlockShape::Block_M0 * BlockShape::Block_N0 * (IsGateOnly ? 1 : 2);
BlockShape::Block_M0 * BlockShape::Block_N0;
return max(smem_0, max(smem_1, smem_bridge));
}
......@@ -168,7 +168,7 @@ struct FusedMoeGemmPipeline_FlatmmUk
index_t intermediate_tile_id)
{
constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2;
ck_tile::index_t shared_intermediate_size_0 = kargs.intermediate_size / hidden_radio_0;
ck_tile::index_t shared_intermediate_size_0 = kargs.intermediate_size;
ck_tile::index_t shared_intermediate_size_1 = kargs.intermediate_size / hidden_radio_0;
index_t nr_0 = shared_intermediate_size_0 / BlockShape::Warp_N0; // divide N in W
......@@ -178,13 +178,13 @@ struct FusedMoeGemmPipeline_FlatmmUk
const IndexDataType expert_id = __builtin_amdgcn_readfirstlane(
reinterpret_cast<const IndexDataType*>(kargs.sorted_expert_ids_ptr)[sorted_tile_id]);
index_t expert_stride_0 = shared_intermediate_size_0 * kargs.hidden_size * hidden_radio_0;
index_t expert_stride_0 = shared_intermediate_size_0 * kargs.hidden_size;
index_t expert_stride_1 = shared_intermediate_size_1 * kargs.hidden_size;
// nr*kr*w
index_t interm_idx_nr0 = __builtin_amdgcn_readfirstlane(
intermediate_tile_id *
BlockShape::Block_Nr0); // intermediate_tile_id * Block_N / (N in W)
BlockShape::Block_Nr0 * hidden_radio_0); // intermediate_tile_id * Block_N / (N in W)
index_t interm_idx_kr1 = __builtin_amdgcn_readfirstlane(
intermediate_tile_id *
......@@ -218,7 +218,7 @@ struct FusedMoeGemmPipeline_FlatmmUk
auto g_window_ = make_tile_window_linear_raw(
g_view_,
make_tuple(number<BlockShape::Block_Nr0>{},
make_tuple(number<BlockShape::Block_Nr0 * hidden_radio_0>{},
number<BlockShape::Block_Kr0>{},
number<BlockShape::Block_W0>{}),
{0, 0, 0},
......@@ -324,7 +324,7 @@ struct FusedMoeGemmPipeline_FlatmmUk
kargs.hidden_size,
BlockShape::Block_K0, // tile offset for B matrix each unroll
BlockShape::Block_Kr0 *
BlockShape::Block_W0); // tile offset for B matrix each unroll
BlockShape::Block_W0); // tile offset for B matrix each unroll
// fast GeLu
if constexpr(std::is_same_v<typename Problem::GateActivation,
......@@ -351,32 +351,32 @@ struct FusedMoeGemmPipeline_FlatmmUk
block_sync_lds();
// up
if(!IsGateOnly)
{
// up ptr. add hafl expoert_stride_0 as offset.
auto u_win = gu_win_gen(shared_intermediate_size_0 * kargs.hidden_size);
auto u_res = u_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
auto u_coords =
generate_tuple([&](auto i) { return u_win.cached_coords_[i].get_offset(); },
number<decltype(u_win)::NumAccess_NonLinear>{});
// reuse UK0
auto uk_0_u = Policy::template GetUK_0<Problem>();
auto acc_0_u = uk_0_u(a_res,
a_coords,
u_res,
u_coords,
smem,
kargs.hidden_size,
BlockShape::Block_K0, // tile offset for B matrix each unroll
BlockShape::Block_Kr0 *
BlockShape::Block_W0); // tile offset for B matrix each unroll
// elementwise mul gate*up.
sweep_tile(
y_pre,
[&](auto idx0) { y_pre(idx0) = y_pre(idx0) * acc_0_u(idx0); },
sequence<1, 1>{});
block_sync_lds();
}
// if(!IsGateOnly)
// {
// // up ptr. add hafl expoert_stride_0 as offset.
// auto u_win = gu_win_gen(shared_intermediate_size_0 * kargs.hidden_size);
// auto u_res = u_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
// auto u_coords =
// generate_tuple([&](auto i) { return u_win.cached_coords_[i].get_offset(); },
// number<decltype(u_win)::NumAccess_NonLinear>{});
// // reuse UK0
// auto uk_0_u = Policy::template GetUK_0<Problem>();
// auto acc_0_u = uk_0_u(a_res,
// a_coords,
// u_res,
// u_coords,
// smem,
// kargs.hidden_size,
// BlockShape::Block_K0, // tile offset for B matrix each unroll
// BlockShape::Block_Kr0 *
// BlockShape::Block_W0); // tile offset for B matrix each unroll
// // elementwise mul gate*up.
// sweep_tile(
// y_pre,
// [&](auto idx0) { y_pre(idx0) = y_pre(idx0) * acc_0_u(idx0); },
// sequence<1, 1>{});
// block_sync_lds();
// }
store_tile(bridge_sst_win, cast_tile<YDataType>(y_pre));
block_sync_lds();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment