Commit 69114f25 authored by letaoqin's avatar letaoqin
Browse files

output sacc

parent bb7c4112
...@@ -280,7 +280,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -280,7 +280,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
block_m); block_m);
// output_matrix_2d(a_host, tokens, hidden_size); // output_matrix_2d(a_host, tokens, hidden_size);
// std::cout << sorted_token_ids_host << std::endl; std::cout << sorted_token_ids_host << std::endl;
// std::cout << num_sorted_tiles_host << std::endl; // std::cout << num_sorted_tiles_host << std::endl;
output_matrix_3d(g_host, experts, shared_intermediate_size_0, hidden_size); output_matrix_3d(g_host, experts, shared_intermediate_size_0, hidden_size);
std::cout << sorted_expert_ids_host << std::endl; std::cout << sorted_expert_ids_host << std::endl;
......
...@@ -156,14 +156,14 @@ struct FusedMoeGemmPipeline_General ...@@ -156,14 +156,14 @@ struct FusedMoeGemmPipeline_General
// load g to register // load g to register
auto g_dram_block = load_tile(g_global_to_dram_window); auto g_dram_block = load_tile(g_global_to_dram_window);
#if 1 #if 0
{ {
constexpr auto a_spans = decltype(g_dram_block)::get_distributed_spans(); constexpr auto a_spans = decltype(g_dram_block)::get_distributed_spans();
int counter = 0; int counter = 0;
sweep_tile_span(a_spans[number<0>{}], [&](auto idxn) { sweep_tile_span(a_spans[number<0>{}], [&](auto idxn) {
sweep_tile_span(a_spans[number<1>{}], [&](auto idxk) { sweep_tile_span(a_spans[number<1>{}], [&](auto idxk) {
constexpr auto i_j_idx = make_tuple(idxn, idxk); constexpr auto i_j_idx = make_tuple(idxn, idxk);
if(threadIdx.x == 1 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
{ {
counter = counter + 1; counter = counter + 1;
index_t idn_0 = idxn.impl_.at(0); index_t idn_0 = idxn.impl_.at(0);
...@@ -208,6 +208,34 @@ struct FusedMoeGemmPipeline_General ...@@ -208,6 +208,34 @@ struct FusedMoeGemmPipeline_General
block_sync_lds(); block_sync_lds();
gemm_0(s_acc, a_lds_win, g_dram_block); gemm_0(s_acc, a_lds_win, g_dram_block);
} }
#if 1
        // Debug-only dump of s_acc, the accumulator written by gemm_0 just above.
        // Walks both distributed spans of the tile and prints each element, but
        // only from thread 0 of block (0, 0, 0) to keep the output readable.
        {
            constexpr auto s_acc_spans = decltype(s_acc)::get_distributed_spans();
            int counter                = 0; // running count of elements printed so far
            sweep_tile_span(s_acc_spans[number<0>{}], [&](auto idxm) {
                sweep_tile_span(s_acc_spans[number<1>{}], [&](auto idxn) {
                    // The element index must pair the span-0 index with the span-1
                    // index. Using idxn for both positions (as before) ignored idxm
                    // and printed the wrong elements.
                    constexpr auto i_j_idx = make_tuple(idxm, idxn);
                    if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
                    {
                        counter = counter + 1;
                        index_t idm_0 = idxm.impl_.at(0);
                        index_t idn_0 = idxn.impl_.at(0);
                        // NOTE(review): assumes the span-1 index has at least two
                        // components -- confirm against the s_acc distribution.
                        index_t idn_1 = idxn.impl_.at(1);
                        // Labels now match the values actually printed (the first
                        // value is idm_0, and the tensor is s_acc rather than A).
                        printf("in s_acc idm_0 is %d , idn_0 is %d, idn_1 is %d, counter is %d, "
                               "value is: %f \n",
                               idm_0,
                               idn_0,
                               idn_1,
                               counter,
                               ck_tile::type_convert<float>(s_acc(i_j_idx)));
                    }
                });
            });
        }
#endif
// move sacc to LDS // move sacc to LDS
auto bridge_lds_view = make_tensor_view<address_space_enum::lds>( auto bridge_lds_view = make_tensor_view<address_space_enum::lds>(
smem_0, Policy::template MakeBridgeLdsBlockDesc<Problem>()); smem_0, Policy::template MakeBridgeLdsBlockDesc<Problem>());
......
...@@ -173,26 +173,26 @@ struct FusedMoeGemmPipelineGeneralPolicy ...@@ -173,26 +173,26 @@ struct FusedMoeGemmPipelineGeneralPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_G() CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_G()
{ {
// using WG = decltype(GetWarpGemm0<Problem>()); using WG = decltype(GetWarpGemm0<Problem>());
// using S_ = typename Problem::BlockShape; using S_ = typename Problem::BlockShape;
// static_assert(S_::WarpPerBlock_N0==4); static_assert(S_::WarpPerBlock_N0==4);
// constexpr auto g_outer_dstr_enc = tile_distribution_encoding< constexpr auto g_outer_dstr_enc = tile_distribution_encoding<
// sequence<S_::WarpPerBlock_M0>, sequence<S_::WarpPerBlock_M0>,
// tuple<sequence<S_::Repeat_N0, S_::WarpPerBlock_N0>, sequence<S_::Repeat_K0>>, tuple<sequence<S_::Repeat_N0, S_::WarpPerBlock_N0>, sequence<S_::Repeat_K0>>,
// tuple<sequence<0, 1>>, tuple<sequence<0, 1>>,
// tuple<sequence<0, 1>>, tuple<sequence<0, 1>>,
// sequence<1, 2>, sequence<1, 2>,
// sequence<0, 0>>{}; sequence<0, 0>>{};
// constexpr auto g_block_dstr_encode = detail::make_embed_tile_distribution_encoding( constexpr auto g_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
// g_outer_dstr_enc, typename WG::BWarpDstrEncoding{}); g_outer_dstr_enc, typename WG::BWarpDstrEncoding{});
constexpr auto g_block_dstr_encode = tile_distribution_encoding< // constexpr auto g_block_dstr_encode = tile_distribution_encoding<
sequence<1>, // sequence<1>,
tuple<sequence<1, 4, 32>, sequence<4, 2, 4>>, // tuple<sequence<1, 4, 32>, sequence<4, 2, 4>>,
tuple<sequence<0, 1>, sequence<2, 1>>, // tuple<sequence<0, 1>, sequence<2, 1>>,
tuple<sequence<0, 1>, sequence<0, 0>>, // tuple<sequence<0, 1>, sequence<1, 2>>,
sequence<1, 2, 2>, // sequence<1, 2, 2>,
sequence<0, 0, 2>>{}; // sequence<0, 0, 2>>{};
return make_static_tile_distribution(g_block_dstr_encode); return make_static_tile_distribution(g_block_dstr_encode);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment