Commit a75f162b authored by coderfeli's avatar coderfeli
Browse files

add files

parent 59c05300
...@@ -33,20 +33,20 @@ struct fmoe_ // traits, ugly name, only used for internal ...@@ -33,20 +33,20 @@ struct fmoe_ // traits, ugly name, only used for internal
using YSmoothScaleDataType = ck_tile::remove_cvref_t<typename TypeConfig::YSmoothScaleDataType>; using YSmoothScaleDataType = ck_tile::remove_cvref_t<typename TypeConfig::YSmoothScaleDataType>;
using TopkWeightDataType = ck_tile::remove_cvref_t<typename TypeConfig::TopkWeightDataType>; using TopkWeightDataType = ck_tile::remove_cvref_t<typename TypeConfig::TopkWeightDataType>;
using IndexDataType = ck_tile::remove_cvref_t<typename TypeConfig::IndexDataType>; using IndexDataType = ck_tile::remove_cvref_t<typename TypeConfig::IndexDataType>;
// S<32, 1024, 128, 128>, S<1, 4, 1>, S<16, 16, 32>
static constexpr ck_tile::index_t BT_ = BlockTIle_::at(ck_tile::number<0>{}); // block token static constexpr ck_tile::index_t BT_ = BlockTIle_::at(ck_tile::number<0>{}); // block token
static constexpr ck_tile::index_t BI_ = static constexpr ck_tile::index_t BI_ =
BlockTIle_::at(ck_tile::number<1>{}); // block intermediate BlockTIle_::at(ck_tile::number<1>{}); // block intermediate
static constexpr ck_tile::index_t BH_ = BlockTIle_::at(ck_tile::number<2>{}); // block hidden static constexpr ck_tile::index_t BH_ = BlockTIle_::at(ck_tile::number<2>{}); // block hidden
static constexpr ck_tile::index_t BD_ = BlockTIle_::at(ck_tile::number<3>{}); // block down static constexpr ck_tile::index_t BD_ = BlockTIle_::at(ck_tile::number<3>{}); // block down
using BlockTile_0 = ck_tile::sequence<BT_, BI_ / (GateOnly_ ? 1 : 2), BH_>; using BlockTile_0 = ck_tile::sequence<BT_, BI_ / (GateOnly_ ? 1 : 2), BH_>; //32, 512, 128
using WarpPerBlock_0 = ck_tile::remove_cvref_t<WarpPerBlock_>; using WarpPerBlock_0 = ck_tile::remove_cvref_t<WarpPerBlock_>; // S<1, 4, 1>
using WarpTile_0 = ck_tile::remove_cvref_t<WarpTile_>; using WarpTile_0 = ck_tile::remove_cvref_t<WarpTile_>; // S<16, 16, 32>
using BlockTile_1 = ck_tile::sequence<BT_, BD_, BI_ / (GateOnly_ ? 1 : 2)>; using BlockTile_1 = ck_tile::sequence<BT_, BD_, BI_ / (GateOnly_ ? 1 : 2)>; // 32, 128, 512
using WarpPerBlock_1 = ck_tile::remove_cvref_t<WarpPerBlock_>; using WarpPerBlock_1 = ck_tile::remove_cvref_t<WarpPerBlock_>; /// S<1, 4, 1>
using WarpTile_1 = ck_tile::remove_cvref_t<WarpTile_>; using WarpTile_1 = ck_tile::remove_cvref_t<WarpTile_>; // S<16, 16, 32>
static constexpr ck_tile::index_t GateOnly = GateOnly_; static constexpr ck_tile::index_t GateOnly = GateOnly_;
static constexpr ck_tile::index_t FusedQuant = FusedQuant_; static constexpr ck_tile::index_t FusedQuant = FusedQuant_;
......
...@@ -285,7 +285,10 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -285,7 +285,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
} }
else else
{ {
topid_unique_gen<IndexDataType>(topk_ids_host.mData, tokens, topk, experts, 11913); for(int i = 0; i < static_cast<int>(topk_ids_host.mData.size()); i++) {
topk_ids_host.mData[i] = i % 4;
}
// topid_unique_gen<IndexDataType>(topk_ids_host.mData, tokens, topk, experts, 11913);
} }
// leave it here for future debug purpose // leave it here for future debug purpose
...@@ -442,6 +445,10 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -442,6 +445,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
topk, topk,
gate_only); gate_only);
sorted_token_ids_host.savetxt("sorted_token_ids_host.txt", "int");
sorted_expert_ids_host.savetxt("sorted_expert_ids_host.txt", "int");
num_sorted_tiles_host.savetxt("num_sorted_tiles_host.txt", "int");
auto o_dev = o_buf.ToHost<ODataType>(); auto o_dev = o_buf.ToHost<ODataType>();
// o_dev.savetxt("gpu-out.txt", "float"); // o_dev.savetxt("gpu-out.txt", "float");
auto [rtol, atol] = get_elimit<ADataType>(); auto [rtol, atol] = get_elimit<ADataType>();
......
...@@ -61,15 +61,16 @@ struct FusedMoeGemmShape ...@@ -61,15 +61,16 @@ struct FusedMoeGemmShape
// TODO: we don't support half warps aound to 1 warp here // TODO: we don't support half warps aound to 1 warp here
static_assert(NumWarps == reduce_on_sequence(WarpPerBlock_1{}, multiplies{}, number<1>{})); static_assert(NumWarps == reduce_on_sequence(WarpPerBlock_1{}, multiplies{}, number<1>{}));
static constexpr index_t Block_M0 = BlockTile_0::at(number<0>{}); // S<32, 512, 128>, S<1, 4, 1>, S<16, 16, 32>
static constexpr index_t Block_N0 = BlockTile_0::at(number<1>{}); static constexpr index_t Block_M0 = BlockTile_0::at(number<0>{}); //32
static constexpr index_t Block_K0 = BlockTile_0::at(number<2>{}); static constexpr index_t Block_N0 = BlockTile_0::at(number<1>{}); //512
static constexpr index_t WarpPerBlock_M0 = WarpPerBlock_0::at(number<0>{}); static constexpr index_t Block_K0 = BlockTile_0::at(number<2>{}); // 128
static constexpr index_t WarpPerBlock_N0 = WarpPerBlock_0::at(number<1>{}); static constexpr index_t WarpPerBlock_M0 = WarpPerBlock_0::at(number<0>{}); // 1
static constexpr index_t WarpPerBlock_K0 = WarpPerBlock_0::at(number<2>{}); static constexpr index_t WarpPerBlock_N0 = WarpPerBlock_0::at(number<1>{}); // 4
static constexpr index_t Warp_M0 = WarpTile_0::at(number<0>{}); static constexpr index_t WarpPerBlock_K0 = WarpPerBlock_0::at(number<2>{}); // 1
static constexpr index_t Warp_N0 = WarpTile_0::at(number<1>{}); static constexpr index_t Warp_M0 = WarpTile_0::at(number<0>{}); // 16
static constexpr index_t Warp_K0 = WarpTile_0::at(number<2>{}); static constexpr index_t Warp_N0 = WarpTile_0::at(number<1>{}); // 16
static constexpr index_t Warp_K0 = WarpTile_0::at(number<2>{}); // 32
static constexpr index_t ThreadPerBlock_M0 = Warp_M0 * WarpPerBlock_M0; static constexpr index_t ThreadPerBlock_M0 = Warp_M0 * WarpPerBlock_M0;
static constexpr index_t ThreadPerBlock_N0 = Warp_N0 * WarpPerBlock_N0; static constexpr index_t ThreadPerBlock_N0 = Warp_N0 * WarpPerBlock_N0;
...@@ -77,19 +78,19 @@ struct FusedMoeGemmShape ...@@ -77,19 +78,19 @@ struct FusedMoeGemmShape
static_assert(Block_M0 % ThreadPerBlock_M0 == 0); static_assert(Block_M0 % ThreadPerBlock_M0 == 0);
static_assert(Block_N0 % ThreadPerBlock_N0 == 0); static_assert(Block_N0 % ThreadPerBlock_N0 == 0);
static_assert(Block_K0 % ThreadPerBlock_K0 == 0); static_assert(Block_K0 % ThreadPerBlock_K0 == 0);
static constexpr index_t Repeat_M0 = Block_M0 / ThreadPerBlock_M0; static constexpr index_t Repeat_M0 = Block_M0 / ThreadPerBlock_M0; // 2
static constexpr index_t Repeat_N0 = Block_N0 / ThreadPerBlock_N0; static constexpr index_t Repeat_N0 = Block_N0 / ThreadPerBlock_N0; // 8
static constexpr index_t Repeat_K0 = Block_K0 / ThreadPerBlock_K0; static constexpr index_t Repeat_K0 = Block_K0 / ThreadPerBlock_K0; // 4
static constexpr index_t Block_M1 = BlockTile_1::at(number<0>{}); static constexpr index_t Block_M1 = BlockTile_1::at(number<0>{}); //32
static constexpr index_t Block_N1 = BlockTile_1::at(number<1>{}); static constexpr index_t Block_N1 = BlockTile_1::at(number<1>{}); //128
static constexpr index_t Block_K1 = BlockTile_1::at(number<2>{}); static constexpr index_t Block_K1 = BlockTile_1::at(number<2>{}); //512
static constexpr index_t WarpPerBlock_M1 = WarpPerBlock_1::at(number<0>{}); static constexpr index_t WarpPerBlock_M1 = WarpPerBlock_1::at(number<0>{}); // 1
static constexpr index_t WarpPerBlock_N1 = WarpPerBlock_1::at(number<1>{}); static constexpr index_t WarpPerBlock_N1 = WarpPerBlock_1::at(number<1>{}); // 4
static constexpr index_t WarpPerBlock_K1 = WarpPerBlock_1::at(number<2>{}); static constexpr index_t WarpPerBlock_K1 = WarpPerBlock_1::at(number<2>{}); // 1
static constexpr index_t Warp_M1 = WarpTile_1::at(number<0>{}); static constexpr index_t Warp_M1 = WarpTile_1::at(number<0>{}); // 16
static constexpr index_t Warp_N1 = WarpTile_1::at(number<1>{}); static constexpr index_t Warp_N1 = WarpTile_1::at(number<1>{}); // 16
static constexpr index_t Warp_K1 = WarpTile_1::at(number<2>{}); static constexpr index_t Warp_K1 = WarpTile_1::at(number<2>{}); // 32
static constexpr index_t ThreadPerBlock_M1 = Warp_M1 * WarpPerBlock_M1; static constexpr index_t ThreadPerBlock_M1 = Warp_M1 * WarpPerBlock_M1;
static constexpr index_t ThreadPerBlock_N1 = Warp_N1 * WarpPerBlock_N1; static constexpr index_t ThreadPerBlock_N1 = Warp_N1 * WarpPerBlock_N1;
...@@ -97,9 +98,9 @@ struct FusedMoeGemmShape ...@@ -97,9 +98,9 @@ struct FusedMoeGemmShape
static_assert(Block_M1 % ThreadPerBlock_M1 == 0); static_assert(Block_M1 % ThreadPerBlock_M1 == 0);
static_assert(Block_N1 % ThreadPerBlock_N1 == 0); static_assert(Block_N1 % ThreadPerBlock_N1 == 0);
static_assert(Block_K1 % ThreadPerBlock_K1 == 0); static_assert(Block_K1 % ThreadPerBlock_K1 == 0);
static constexpr index_t Repeat_M1 = Block_M1 / ThreadPerBlock_M1; static constexpr index_t Repeat_M1 = Block_M1 / ThreadPerBlock_M1; // 2
static constexpr index_t Repeat_N1 = Block_N1 / ThreadPerBlock_N1; static constexpr index_t Repeat_N1 = Block_N1 / ThreadPerBlock_N1; // 2
static constexpr index_t Repeat_K1 = Block_K1 / ThreadPerBlock_K1; static constexpr index_t Repeat_K1 = Block_K1 / ThreadPerBlock_K1; // 16
static constexpr index_t BlockSize = warpSize * NumWarps; static constexpr index_t BlockSize = warpSize * NumWarps;
...@@ -115,9 +116,9 @@ struct FusedMoeGemmShape ...@@ -115,9 +116,9 @@ struct FusedMoeGemmShape
static constexpr index_t Block_W0 = Warp_N0 * Warp_K0; static constexpr index_t Block_W0 = Warp_N0 * Warp_K0;
static constexpr index_t Block_Nr0 = Block_N0 / Warp_N0; static constexpr index_t Block_Nr0 = Block_N0 / Warp_N0;
static constexpr index_t Block_Kr0 = Block_K0 / Warp_K0; static constexpr index_t Block_Kr0 = Block_K0 / Warp_K0;
static constexpr index_t Block_W1 = Warp_N1 * Warp_K1; static constexpr index_t Block_W1 = Warp_N1 * Warp_K1; // 512
static constexpr index_t Block_Nr1 = Block_N1 / Warp_N1; static constexpr index_t Block_Nr1 = Block_N1 / Warp_N1; // 8
static constexpr index_t Block_Kr1 = Block_K1 / Warp_K1; static constexpr index_t Block_Kr1 = Block_K1 / Warp_K1; // 16
static_assert(Block_W0 == Block_W1); static_assert(Block_W0 == Block_W1);
// static_assert(Block_Nr0 == Block_Kr1); // static_assert(Block_Nr0 == Block_Kr1);
......
...@@ -199,6 +199,8 @@ struct FusedMoeGemmPipeline_FlatmmUk ...@@ -199,6 +199,8 @@ struct FusedMoeGemmPipeline_FlatmmUk
threadIdx.x % (BlockShape::Block_K0 / kAlignmentA) * kAlignmentA; threadIdx.x % (BlockShape::Block_K0 / kAlignmentA) * kAlignmentA;
}, },
number<row_ids_a.size()>{}); number<row_ids_a.size()>{});
if (row_ids_a[0] >= kargs.num_tokens)
return;
auto a_res = auto a_res =
make_wave_buffer_resource(reinterpret_cast<const ADataType*>(kargs.a_ptr), make_wave_buffer_resource(reinterpret_cast<const ADataType*>(kargs.a_ptr),
kargs.num_tokens * kargs.stride_token * sizeof(ADataType)); kargs.num_tokens * kargs.stride_token * sizeof(ADataType));
...@@ -238,7 +240,7 @@ struct FusedMoeGemmPipeline_FlatmmUk ...@@ -238,7 +240,7 @@ struct FusedMoeGemmPipeline_FlatmmUk
const auto d_view_ = make_naive_tensor_view<address_space_enum::global>( const auto d_view_ = make_naive_tensor_view<address_space_enum::global>(
d_ptr, d_ptr,
make_tuple(nr_1, kr_1, BlockShape::Block_W1), make_tuple(nr_1, kr_1, BlockShape::Block_W1), // n/16, k/32, 512
make_tuple(kr_1 * BlockShape::Block_W1, BlockShape::Block_W1, 1), make_tuple(kr_1 * BlockShape::Block_W1, BlockShape::Block_W1, 1),
number<kAlignmentD>{}, number<kAlignmentD>{},
number<1>{}); number<1>{});
...@@ -264,13 +266,13 @@ struct FusedMoeGemmPipeline_FlatmmUk ...@@ -264,13 +266,13 @@ struct FusedMoeGemmPipeline_FlatmmUk
auto d_coords = [&]() { auto d_coords = [&]() {
constexpr index_t Nr_ = 2; constexpr index_t Nr_ = 2;
constexpr index_t Nw_ = 4; constexpr index_t Nw_ = 4;
constexpr index_t Kr0_ = 4; constexpr index_t Kr0_ = BlockShape::Block_Kr1 / Kr1_; //4
constexpr index_t Kr1_ = 4; constexpr index_t Kr1_ = 4;
constexpr index_t Kl_ = 4; constexpr index_t Kl_ = 4;
constexpr index_t Nl_ = 16; constexpr index_t Nl_ = 16;
constexpr index_t Kv_ = 8; constexpr index_t Kv_ = 8;
constexpr index_t W_ = Kl_ * Nl_ * Kv_; constexpr index_t W_ = Kl_ * Nl_ * Kv_; // 512
constexpr index_t num_offsets_ = Nr_ * Kr0_; constexpr index_t num_offsets_ = Nr_ * Kr0_; // 8
index_t base_os_ = (threadIdx.x % 64) * Kv_ + (threadIdx.x / 64) * index_t base_os_ = (threadIdx.x % 64) * Kv_ + (threadIdx.x / 64) *
shared_intermediate_size_1 * shared_intermediate_size_1 *
Nl_; // Kr0_ * Kr1_ * W_; Nl_; // Kr0_ * Kr1_ * W_;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment