Commit bb7c4112 authored by letaoqin's avatar letaoqin
Browse files

debugging

parent 7881eff9
......@@ -73,6 +73,23 @@ void output_matrix_2d(ck_tile::HostTensor<IndexType>& data, int m, int n)
std::cout << std::endl;
}
}
template <typename IndexType>
void output_matrix_3d(ck_tile::HostTensor<IndexType>& data, int M, int N, int J)
{
std::cout << std::endl;
for(int m = 0; m < M; m++)
{
for(int n = 0; n < N; n++)
{
std::cout << "experts: " << m << " Line: " << n << "\t";
for(int j = 0; j < J; j++)
{
std::cout << ck_tile::type_convert<float>(data(m, n, j)) << "\t";
}
std::cout << std::endl;
}
}
}
template <typename IndexType>
void topid_unique_gen(
......@@ -265,7 +282,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
// output_matrix_2d(a_host, tokens, hidden_size);
// std::cout << sorted_token_ids_host << std::endl;
// std::cout << num_sorted_tiles_host << std::endl;
// std::cout << sorted_expert_ids_host << std::endl;
output_matrix_3d(g_host, experts, shared_intermediate_size_0, hidden_size);
std::cout << sorted_expert_ids_host << std::endl;
// std::cout << topk_weight_host << std::endl;
// std::cout << sorted_weight_host << std::endl;
......
......@@ -301,11 +301,10 @@ struct FusedMoeGemmGlKernel
// TODO: gtile using NSub to have less register pressure
const auto g_window = [&]() {
const GDataType* g_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) +
static_cast<long_index_t>(expert_id) * expert_stride_0 +
idx_n0 * kargs.hidden_size;
static_cast<long_index_t>(expert_id) * expert_stride_0;
const auto g_view_ = make_naive_tensor_view<address_space_enum::global>(
g_ptr,
make_tuple(BlockShape::Block_N0, kargs.hidden_size),
make_tuple(kargs.intermediate_size, kargs.hidden_size),
make_tuple(kargs.hidden_size, 1),
number<Pipeline::kAlignmentG>{},
number<1>{});
......@@ -313,7 +312,7 @@ struct FusedMoeGemmGlKernel
const auto g_window_ = make_tile_window(
g_view_,
make_tuple(number<BlockShape::Block_N0>{}, number<BlockShape::Block_K0>{}),
{0, 0});
{idx_n0, 0});
return g_window_;
}();
......
......@@ -98,6 +98,9 @@ struct FusedMoeGemmPipeline_General
index_t hidden_size,
index_t intermediate_size)
{
ignore = d_window_;
ignore = hidden_size;
ignore = intermediate_size;
CK_TILE_LDS_ADDR ADataType* smem_0 = reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem);
auto a_lds_view = make_tensor_view<address_space_enum::lds>(
smem_0, Policy::template MakeLdsBlockDesc_A<Problem>());
......@@ -126,10 +129,60 @@ struct FusedMoeGemmPipeline_General
// save tokens to lds
auto a_dram_block = load_tile(a_global_to_dram_window);
store_tile(a_lds_win, a_dram_block);
#if 0
{
// check a matrix gather right or not
constexpr auto a_spans = decltype(a_dram_block)::get_distributed_spans();
int counter = 0;
sweep_tile_span(a_spans[number<0>{}], [&](auto idxm) {
sweep_tile_span(a_spans[number<1>{}], [&](auto idxk) {
constexpr auto i_j_idx = make_tuple(idxm, idxk);
if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
{
counter = counter + 1;
index_t idm_0 = idxm.impl_.at(0);
index_t idk_0 = idxk.impl_.at(0);
printf("in A idm is %d , idk_ is %d , counter is %d, value is: %f \n",
idm_0,
idk_0,
counter,
ck_tile::type_convert<float>(a_dram_block(i_j_idx)));
}
});
});
}
#endif
// load g to register
auto g_dram_block = load_tile(g_global_to_dram_window);
#if 1
{
constexpr auto a_spans = decltype(g_dram_block)::get_distributed_spans();
int counter = 0;
sweep_tile_span(a_spans[number<0>{}], [&](auto idxn) {
sweep_tile_span(a_spans[number<1>{}], [&](auto idxk) {
constexpr auto i_j_idx = make_tuple(idxn, idxk);
if(threadIdx.x == 1 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
{
counter = counter + 1;
index_t idn_0 = idxn.impl_.at(0);
index_t idk_0 = idxk.impl_.at(0);
index_t idk_1 = idxk.impl_.at(1);
printf("in A idn is %d , idk_0 is %d idk_1 is %d, counter is %d, value is: "
"%f \n",
idn_0,
idk_0,
idk_1,
counter,
ck_tile::type_convert<float>(g_dram_block(i_j_idx)));
}
});
});
}
#endif
clear_tile(s_acc); // initialize C
constexpr index_t kK0 = BlockShape::Block_K0;
const index_t k0_loops = ck_tile::integer_divide_ceil(intermediate_size, kK0);
......@@ -196,6 +249,10 @@ struct FusedMoeGemmPipeline_General
block_sync_lds();
move_tile_window(d_global_to_dram_window, {kN1, 0});
d = load_tile(d_global_to_dram_window);
// move out window and save data
auto o = cast_tile<ODataType>(o_acc);
store_tile(o_window_, o);
move_tile_window(o_window_, {kN1, 0});
iCounter1--;
}
......@@ -204,30 +261,7 @@ struct FusedMoeGemmPipeline_General
block_sync_lds();
gemm_1(o_acc, y, d);
}
auto o = cast_tile<ODataType>(o_acc);
store_tile(o_window_, o);
// store_tile(o_window_, a_dram_block);
#if 0
//check a matrix gather right or not
constexpr auto a_spans = decltype(a_dram_block)::get_distributed_spans();
int counter = 0;
sweep_tile_span(a_spans[number<0>{}], [&](auto idxm) {
sweep_tile_span(a_spans[number<1>{}], [&](auto idxk) {
constexpr auto i_j_idx = make_tuple(idxm, idxk);
if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
{
counter = counter + 1;
index_t idm_0 = idxm.impl_.at(0);
index_t idn_0 = idxk.impl_.at(0);
printf("in A idm is %d , idn_ is %d , counter is %d, value is: %f \n",
idm_0,
idn_0,
counter,
ck_tile::type_convert<float>(a_dram_block(i_j_idx)));
}
});
});
#endif
}
};
......
......@@ -173,18 +173,26 @@ struct FusedMoeGemmPipelineGeneralPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_G()
{
using WG = decltype(GetWarpGemm0<Problem>());
using S_ = typename Problem::BlockShape;
constexpr auto g_outer_dstr_enc = tile_distribution_encoding<
sequence<>,
tuple<sequence<S_::Repeat_N0, S_::WarpPerBlock_N0>, sequence<S_::Repeat_K0>>,
tuple<sequence<1>>,
tuple<sequence<1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto g_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
g_outer_dstr_enc, typename WG::BWarpDstrEncoding{});
// using WG = decltype(GetWarpGemm0<Problem>());
// using S_ = typename Problem::BlockShape;
// static_assert(S_::WarpPerBlock_N0==4);
// constexpr auto g_outer_dstr_enc = tile_distribution_encoding<
// sequence<S_::WarpPerBlock_M0>,
// tuple<sequence<S_::Repeat_N0, S_::WarpPerBlock_N0>, sequence<S_::Repeat_K0>>,
// tuple<sequence<0, 1>>,
// tuple<sequence<0, 1>>,
// sequence<1, 2>,
// sequence<0, 0>>{};
// constexpr auto g_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
// g_outer_dstr_enc, typename WG::BWarpDstrEncoding{});
constexpr auto g_block_dstr_encode = tile_distribution_encoding<
sequence<1>,
tuple<sequence<1, 4, 32>, sequence<4, 2, 4>>,
tuple<sequence<0, 1>, sequence<2, 1>>,
tuple<sequence<0, 1>, sequence<0, 0>>,
sequence<1, 2, 2>,
sequence<0, 0, 2>>{};
return make_static_tile_distribution(g_block_dstr_encode);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment