Commit 69114f25 authored by letaoqin
Browse files

output sacc

parent bb7c4112
......@@ -280,7 +280,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
block_m);
// output_matrix_2d(a_host, tokens, hidden_size);
// std::cout << sorted_token_ids_host << std::endl;
std::cout << sorted_token_ids_host << std::endl;
// std::cout << num_sorted_tiles_host << std::endl;
output_matrix_3d(g_host, experts, shared_intermediate_size_0, hidden_size);
std::cout << sorted_expert_ids_host << std::endl;
......
......@@ -156,14 +156,14 @@ struct FusedMoeGemmPipeline_General
// load g to register
auto g_dram_block = load_tile(g_global_to_dram_window);
#if 1
#if 0
{
constexpr auto a_spans = decltype(g_dram_block)::get_distributed_spans();
int counter = 0;
sweep_tile_span(a_spans[number<0>{}], [&](auto idxn) {
sweep_tile_span(a_spans[number<1>{}], [&](auto idxk) {
constexpr auto i_j_idx = make_tuple(idxn, idxk);
if(threadIdx.x == 1 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
{
counter = counter + 1;
index_t idn_0 = idxn.impl_.at(0);
......@@ -208,6 +208,34 @@ struct FusedMoeGemmPipeline_General
block_sync_lds();
gemm_0(s_acc, a_lds_win, g_dram_block);
}
#if 1
// Debug dump of s_acc, placed right after the gemm_0(s_acc, ...) call.
// Only thread 0 of block (0, 0, 0) prints; set the guard to 0 to disable.
{
    constexpr auto acc_spans = decltype(s_acc)::get_distributed_spans();
    int counter = 0; // number of elements printed so far by the printing thread
    sweep_tile_span(acc_spans[number<0>{}], [&](auto idxm) {
        sweep_tile_span(acc_spans[number<1>{}], [&](auto idxn) {
            // Address the element with the span-0 index first and the span-1
            // index second. Using idxn for both positions would ignore idxm
            // and print the wrong element.
            constexpr auto i_j_idx = make_tuple(idxm, idxn);
            if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
            {
                counter = counter + 1;
                index_t idm_0 = idxm.impl_.at(0);
                index_t idn_0 = idxn.impl_.at(0);
                index_t idn_1 = idxn.impl_.at(1);
                // Labels match the arguments passed: idm_0, idn_0, idn_1.
                printf("in s_acc idm_0 is %d, idn_0 is %d, idn_1 is %d, counter is %d, value is: "
                       "%f \n",
                       idm_0,
                       idn_0,
                       idn_1,
                       counter,
                       ck_tile::type_convert<float>(s_acc(i_j_idx)));
            }
        });
    });
}
#endif
// move sacc to LDS
auto bridge_lds_view = make_tensor_view<address_space_enum::lds>(
smem_0, Policy::template MakeBridgeLdsBlockDesc<Problem>());
......
......@@ -173,26 +173,26 @@ struct FusedMoeGemmPipelineGeneralPolicy
// Builds and returns the static tile distribution created from
// g_block_dstr_encode via make_static_tile_distribution().
//
// Two ways of producing g_block_dstr_encode appear in this listing:
//   * a hard-coded tile_distribution_encoding literal, and
//   * an encoding derived from Problem::BlockShape, embedded with the warp
//     gemm's BWarpDstrEncoding through
//     detail::make_embed_tile_distribution_encoding().
// NOTE(review): this text looks like a diff rendering with the pre- and
// post-change lines interleaved, so g_block_dstr_encode is defined twice below.
// Only one active definition can compile -- confirm which variant is current
// before editing.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_G()
{
// Commented-out copy of the BlockShape-derived construction.
// using WG = decltype(GetWarpGemm0<Problem>());
// using S_ = typename Problem::BlockShape;
// static_assert(S_::WarpPerBlock_N0==4);
// constexpr auto g_outer_dstr_enc = tile_distribution_encoding<
// sequence<S_::WarpPerBlock_M0>,
// tuple<sequence<S_::Repeat_N0, S_::WarpPerBlock_N0>, sequence<S_::Repeat_K0>>,
// tuple<sequence<0, 1>>,
// tuple<sequence<0, 1>>,
// sequence<1, 2>,
// sequence<0, 0>>{};
// constexpr auto g_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
// g_outer_dstr_enc, typename WG::BWarpDstrEncoding{});
// Hard-coded encoding literal; its constants do not depend on Problem.
constexpr auto g_block_dstr_encode = tile_distribution_encoding<
sequence<1>,
tuple<sequence<1, 4, 32>, sequence<4, 2, 4>>,
tuple<sequence<0, 1>, sequence<2, 1>>,
tuple<sequence<0, 1>, sequence<0, 0>>,
sequence<1, 2, 2>,
sequence<0, 0, 2>>{};
// BlockShape-derived construction: WG is the type returned by GetWarpGemm0 and
// S_ supplies the per-block warp and repeat counts.
using WG = decltype(GetWarpGemm0<Problem>());
using S_ = typename Problem::BlockShape;
// This path is only accepted when the block shape has WarpPerBlock_N0 == 4.
static_assert(S_::WarpPerBlock_N0==4);
constexpr auto g_outer_dstr_enc = tile_distribution_encoding<
sequence<S_::WarpPerBlock_M0>,
tuple<sequence<S_::Repeat_N0, S_::WarpPerBlock_N0>, sequence<S_::Repeat_K0>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
// Embed the warp gemm's B-operand encoding into the outer encoding.
constexpr auto g_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
g_outer_dstr_enc, typename WG::BWarpDstrEncoding{});
// Commented-out copy of a hard-coded literal; its fourth argument
// (sequence<1, 2>) differs from the active literal above (sequence<0, 0>).
// constexpr auto g_block_dstr_encode = tile_distribution_encoding<
// sequence<1>,
// tuple<sequence<1, 4, 32>, sequence<4, 2, 4>>,
// tuple<sequence<0, 1>, sequence<2, 1>>,
// tuple<sequence<0, 1>, sequence<1, 2>>,
// sequence<1, 2, 2>,
// sequence<0, 0, 2>>{};
return make_static_tile_distribution(g_block_dstr_encode);
}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment