Commit ef8e3620 authored by letaoqin
Browse files

gather and scatter right

parent eaf8e616
...@@ -60,7 +60,7 @@ auto shuffle_moe_weight(const ck_tile::HostTensor<T>& t, std::string mfma_dtype, ...@@ -60,7 +60,7 @@ auto shuffle_moe_weight(const ck_tile::HostTensor<T>& t, std::string mfma_dtype,
} }
template <typename IndexType> template <typename IndexType>
void output_matrix_2d(ck_tile::HostTensor<IndexType>& data, int m,int n) void output_matrix_2d(ck_tile::HostTensor<IndexType>& data, int m, int n)
{ {
std::cout << std::endl; std::cout << std::endl;
for(int i = 0; i < m; i++) for(int i = 0; i < m; i++)
...@@ -68,7 +68,7 @@ void output_matrix_2d(ck_tile::HostTensor<IndexType>& data, int m,int n) ...@@ -68,7 +68,7 @@ void output_matrix_2d(ck_tile::HostTensor<IndexType>& data, int m,int n)
std::cout << "Line " << i << "\t"; std::cout << "Line " << i << "\t";
for(int j = 0; j < n; j++) for(int j = 0; j < n; j++)
{ {
std::cout << ck_tile::type_convert<float>(data(i,j)) << "\t"; std::cout << ck_tile::type_convert<float>(data(i, j)) << "\t";
} }
std::cout << std::endl; std::cout << std::endl;
} }
...@@ -261,17 +261,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -261,17 +261,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
num_sorted_tiles_host.mData[0], num_sorted_tiles_host.mData[0],
experts, experts,
block_m); block_m);
// std::cout << std::endl;
// for(int i = 0; i < tokens; i++) // output_matrix_2d(a_host, tokens, hidden_size);
// {
// std::cout << "Line " << i << "\t";
// for(int j = 0; j < hidden_size; j++)
// {
// std::cout << ck_tile::type_convert<float>(a_host(i,j)) << "\t";
// }
// std::cout << std::endl;
// }
output_matrix_2d(a_host, tokens, hidden_size);
// std::cout << sorted_token_ids_host << std::endl; // std::cout << sorted_token_ids_host << std::endl;
// std::cout << num_sorted_tiles_host << std::endl; // std::cout << num_sorted_tiles_host << std::endl;
// std::cout << sorted_expert_ids_host << std::endl; // std::cout << sorted_expert_ids_host << std::endl;
...@@ -381,7 +372,17 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -381,7 +372,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
o_dev, o_host, std::string("OUT Error: Incorrect results!"), rtol, atol); o_dev, o_host, std::string("OUT Error: Incorrect results!"), rtol, atol);
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush; std::cout << ", valid:" << (pass ? "y" : "n") << std::flush;
output_matrix_2d(o_dev, tokens, hidden_size); // std::cout << std::endl;
// int count = 0;
// for(int i = 0; i < tokens; i++)
// {
// std::cout << "Line " << i << "\t";
// for(int j = 0; j < hidden_size; j++)
// {
// std::cout << ck_tile::type_convert<float>(o_dev(count++)) << "\t";
// }
// std::cout << std::endl;
// }
} }
std::cout << std::flush << std::endl; std::cout << std::flush << std::endl;
......
...@@ -80,7 +80,7 @@ struct indexing_adaptor ...@@ -80,7 +80,7 @@ struct indexing_adaptor
pre_up_index_ = idx_up[number<0>{}]; pre_up_index_ = idx_up[number<0>{}];
pre_low_index_ = idx_low(number<0>{}); pre_low_index_ = idx_low(number<0>{});
#if 0 #if 0
if(threadIdx.x == 65 && blockIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
{ {
printf("\n first index from %d to %d \n", idx_up[number<0>{}], idx_low(number<0>{})); printf("\n first index from %d to %d \n", idx_up[number<0>{}], idx_low(number<0>{}));
} }
...@@ -105,7 +105,7 @@ struct indexing_adaptor ...@@ -105,7 +105,7 @@ struct indexing_adaptor
pre_up_index_ = up_index; pre_up_index_ = up_index;
pre_low_index_ = low_index; pre_low_index_ = low_index;
#if 0 #if 0
if(threadIdx.x == 65 && blockIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
{ {
printf("\n index form %d to %d, diff from %d to %d \n", printf("\n index form %d to %d, diff from %d to %d \n",
up_index, up_index,
......
...@@ -78,7 +78,7 @@ struct FusedMoeGemmPipeline_General ...@@ -78,7 +78,7 @@ struct FusedMoeGemmPipeline_General
BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType); BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType);
return max(smem_mat_a, smem_bridge); return max(smem_mat_a, smem_bridge);
//return Policy::template GetSmemSize<Problem>(); // return Policy::template GetSmemSize<Problem>();
} }
// this is the thread-offset along row/col // this is the thread-offset along row/col
...@@ -108,7 +108,10 @@ struct FusedMoeGemmPipeline_General ...@@ -108,7 +108,10 @@ struct FusedMoeGemmPipeline_General
CK_TILE_LDS_ADDR ADataType* smem_0 = reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem); CK_TILE_LDS_ADDR ADataType* smem_0 = reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem);
auto a_lds_view = make_tensor_view<address_space_enum::lds>( auto a_lds_view = make_tensor_view<address_space_enum::lds>(
smem_0, Policy::template MakeLdsStoreDesc_A<Problem>()); smem_0, Policy::template MakeLdsStoreDesc_A<Problem>());
auto a_lds_win = make_tile_window(a_lds_view, make_tuple(number<BlockShape::Block_M0>{}, number<BlockShape::Block_K0>{}), {0, 0}); auto a_lds_win = make_tile_window(
a_lds_view,
make_tuple(number<BlockShape::Block_M0>{}, number<BlockShape::Block_K0>{}),
{0, 0});
auto a_global_to_dram_window = make_tile_window( auto a_global_to_dram_window = make_tile_window(
a_window_.get_bottom_tensor_view(), a_window_.get_bottom_tensor_view(),
...@@ -116,10 +119,6 @@ struct FusedMoeGemmPipeline_General ...@@ -116,10 +119,6 @@ struct FusedMoeGemmPipeline_General
a_window_.get_window_origin(), a_window_.get_window_origin(),
Policy::template MakeGlobalTileDistribution_A<Problem>()); Policy::template MakeGlobalTileDistribution_A<Problem>());
// auto o_win = make_tile_window_linear(
// o_window_, Policy::template MakeGlobalTileDistribution_O<Problem>());
auto a_dram_block = load_tile(a_global_to_dram_window); auto a_dram_block = load_tile(a_global_to_dram_window);
store_tile(a_lds_win, a_dram_block); store_tile(a_lds_win, a_dram_block);
...@@ -132,7 +131,7 @@ struct FusedMoeGemmPipeline_General ...@@ -132,7 +131,7 @@ struct FusedMoeGemmPipeline_General
sweep_tile_span(a_spans[number<0>{}], [&](auto idxm) { sweep_tile_span(a_spans[number<0>{}], [&](auto idxm) {
sweep_tile_span(a_spans[number<1>{}], [&](auto idxk) { sweep_tile_span(a_spans[number<1>{}], [&](auto idxk) {
constexpr auto i_j_idx = make_tuple(idxm, idxk); constexpr auto i_j_idx = make_tuple(idxm, idxk);
if(threadIdx.x == 65 && blockIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
{ {
counter = counter + 1; counter = counter + 1;
index_t idm_0 = idxm.impl_.at(0); index_t idm_0 = idxm.impl_.at(0);
......
...@@ -369,7 +369,8 @@ struct FusedMoeGemmPipelineGeneralPolicy ...@@ -369,7 +369,8 @@ struct FusedMoeGemmPipelineGeneralPolicy
make_tuple( make_tuple(
// make_pass_through_transform(), // make_pass_through_transform(),
make_merge_transform(make_tuple(number<NumIssues>{}, number<wavesPerM>{})), make_merge_transform(make_tuple(number<NumIssues>{}, number<wavesPerM>{})),
make_merge_transform(make_tuple(number<wavesPerK>{}, number<warpSize>{}, number<KVector>{}))), make_merge_transform(make_tuple(
number<wavesPerK>{}, number<warpSize>{}, number<KVector>{}))),
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}), make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{})); make_tuple(sequence<0>{}, sequence<1>{}));
...@@ -400,9 +401,10 @@ struct FusedMoeGemmPipelineGeneralPolicy ...@@ -400,9 +401,10 @@ struct FusedMoeGemmPipelineGeneralPolicy
constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor( constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
lds_block_desc_0, lds_block_desc_0,
make_tuple( make_tuple(
//make_pass_through_transform(number<NumIssues>{}), // make_pass_through_transform(number<NumIssues>{}),
//make_pass_through_transform(number<NumWarps>{}), // make_pass_through_transform(number<NumWarps>{}),
make_merge_transform(make_tuple(number<NumIssues>{},number<LaneGroups>{}, number<NumWarps>{})), make_merge_transform(
make_tuple(number<NumIssues>{}, number<LaneGroups>{}, number<NumWarps>{})),
make_merge_transform(make_tuple(number<LanesPerK>{}, number<KVector>{}))), make_merge_transform(make_tuple(number<LanesPerK>{}, number<KVector>{}))),
make_tuple(sequence<0, 1, 2>{}, sequence<3, 4>{}), make_tuple(sequence<0, 1, 2>{}, sequence<3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{})); make_tuple(sequence<0>{}, sequence<1>{}));
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment