Commit 40182e94 authored by letaoqin
Browse files

Change w read (move the load of the w tile earlier in the pipeline)

parent cf01f064
......@@ -129,7 +129,7 @@ auto create_args(int argc, char* argv[])
ck_tile::ArgParser arg_parser;
arg_parser.insert("t", "128", "num input tokens")
.insert("e", "32", "num of experts")
.insert("k", "2", "topk")
.insert("k", "5", "topk")
.insert("h", "8192", "hidden_size of this model")
.insert("i", "8192", "intermediate_size between 2 gemms of FFN")
.insert("stride", "-1", "stride per row, if -1 then equal to hidden_size")
......@@ -285,6 +285,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::FillUniformDistribution<GScaleDataType>{-.5f, .5f, seed, true}(sg_host);
ck_tile::FillUniformDistribution<DScaleDataType>{-.5f, .5f, seed, true}(sd_host);
ck_tile::FillUniformDistribution<YSmoothScaleDataType>{-.5f, .5f, seed, true}(sy_host);
// ck_tile::FillConstant<TopkWeightDataType>{0.1}(topk_weight_host);
ck_tile::FillUniformDistribution<TopkWeightDataType>{0.0f, 1.0f, seed, true}(
topk_weight_host);
}
......@@ -308,6 +309,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::FillUniformDistribution<GScaleDataType>{-.5f, .5f, seed, true}(sg_host);
ck_tile::FillUniformDistribution<DScaleDataType>{-.5f, .5f, seed, true}(sd_host);
ck_tile::FillUniformDistribution<YSmoothScaleDataType>{-.5f, .5f, seed, true}(sy_host);
// ck_tile::FillConstant<TopkWeightDataType>{0.5}(topk_weight_host);
ck_tile::FillUniformDistribution<TopkWeightDataType>{0.0f, 1.0f, seed, true}(
topk_weight_host);
}
......@@ -397,11 +399,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
// output_matrix_2d(a_host, tokens, hidden_size);
std::cout << sorted_token_ids_host << std::endl;
std::cout << num_sorted_tiles_host << std::endl;
// std::cout << num_sorted_tiles_host << std::endl;
// output_matrix_3d(g_host, experts, shared_intermediate_size_0, hidden_size);
// output_matrix_3d(d_host, experts, hidden_size, shared_intermediate_size_1);
std::cout << sorted_expert_ids_host << std::endl;
// std::cout << topk_weight_host << std::endl;
// std::cout << sorted_expert_ids_host << std::endl;
std::cout << topk_weight_host << std::endl;
std::cout << sorted_weight_host << std::endl;
// done, preparing GPU buffer
......
......@@ -171,6 +171,7 @@ void reference_fused_moe(
// printf("in:%d, %f\t", i_n, acc);
acc_1(0, i_n) = acc * weight; // multiple weight here
}
(void)weight;
for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
{
......
......@@ -170,6 +170,16 @@ struct FusedMoeGemmPipeline_General
using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
auto s_acc = SaccBlockTileType{};
constexpr auto w_dstr =
make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding(
s_acc.get_tile_distribution().get_static_tile_distribution_encoding(),
sequence<1>{}));
auto w_global_to_dram_window = make_tile_window(w_window_.get_bottom_tensor_view(),
make_tuple(number<BlockShape::Block_M0>{}),
w_window_.get_window_origin(),
w_dstr);
auto w = load_tile(w_global_to_dram_window);
auto a_dram_block = load_tile(a_global_to_dram_window);
auto g_dram_block = load_tile(g_global_to_dram_window);
// block_sync_load_raw();
......@@ -250,16 +260,6 @@ struct FusedMoeGemmPipeline_General
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
auto o_acc = OaccBlockTileType{};
constexpr auto w_dstr =
make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding(
s_acc.get_tile_distribution().get_static_tile_distribution_encoding(),
sequence<1>{}));
auto w_global_to_dram_window = make_tile_window(w_window_.get_bottom_tensor_view(),
make_tuple(number<BlockShape::Block_M0>{}),
w_window_.get_window_origin(),
w_dstr);
auto w = load_tile(w_global_to_dram_window);
float weight = type_convert<float>(w.get_thread_buffer()[0]);
#if 0
constexpr index_t w_buffer_size = decltype(w)::get_thread_buffer_size();
if(threadIdx.x == 1 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
......@@ -294,7 +294,7 @@ struct FusedMoeGemmPipeline_General
// d data
auto d_global_to_dram_window = make_tile_window(
d_window_.get_bottom_tensor_view(),
make_tuple(number<BlockShape::Block_N0>{}, number<BlockShape::Block_K0>{}),
make_tuple(number<BlockShape::Block_N1>{}, number<BlockShape::Block_K1>{}),
d_window_.get_window_origin(),
Policy::template MakeGlobalTileDistribution_D<Problem>());
auto d = load_tile(d_global_to_dram_window);
......@@ -339,6 +339,8 @@ struct FusedMoeGemmPipeline_General
}
}
};
float weight = type_convert<float>(w.get_thread_buffer()[0]);
constexpr index_t kN1 = BlockShape::Block_N1;
const index_t n1_loops = ck_tile::integer_divide_ceil(hidden_size, kN1);
index_t iCounter1 = n1_loops - 1;
......@@ -382,6 +384,7 @@ struct FusedMoeGemmPipeline_General
#endif
}
// store_tile(o_window_, a_dram_block);
ignore = weight;
}
};
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment