Commit eaf8e616 authored by letaoqin's avatar letaoqin
Browse files

Write matrix A data to LDS

parent 3b51749a
...@@ -59,6 +59,21 @@ auto shuffle_moe_weight(const ck_tile::HostTensor<T>& t, std::string mfma_dtype, ...@@ -59,6 +59,21 @@ auto shuffle_moe_weight(const ck_tile::HostTensor<T>& t, std::string mfma_dtype,
return t; return t;
} }
// Debug helper: dumps the leading m x n entries of a 2-D host tensor to std::cout.
//
// Output starts with an empty line; every row is then written as
// "Line <row>" followed by its n values, all tab-separated, with each value
// converted to float through ck_tile::type_convert before printing.
//
// data : tensor indexed as data(row, col)
// m    : number of rows to print
// n    : number of columns to print per row
template <typename IndexType>
void output_matrix_2d(ck_tile::HostTensor<IndexType>& data, int m, int n)
{
    std::cout << std::endl;

    int row = 0;
    while(row < m)
    {
        // Row header, then the row's values on the same line.
        std::cout << "Line " << row << "\t";
        for(int col = 0; col < n; ++col)
        {
            const float value = ck_tile::type_convert<float>(data(row, col));
            std::cout << value << "\t";
        }
        std::cout << std::endl;
        ++row;
    }
}
template <typename IndexType> template <typename IndexType>
void topid_unique_gen( void topid_unique_gen(
std::vector<IndexType>& host_tensor, int tokens, int topk, int num_expert, int seed) std::vector<IndexType>& host_tensor, int tokens, int topk, int num_expert, int seed)
...@@ -256,6 +271,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -256,6 +271,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
// } // }
// std::cout << std::endl; // std::cout << std::endl;
// } // }
output_matrix_2d(a_host, tokens, hidden_size);
// std::cout << sorted_token_ids_host << std::endl; // std::cout << sorted_token_ids_host << std::endl;
// std::cout << num_sorted_tiles_host << std::endl; // std::cout << num_sorted_tiles_host << std::endl;
// std::cout << sorted_expert_ids_host << std::endl; // std::cout << sorted_expert_ids_host << std::endl;
...@@ -277,6 +293,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -277,6 +293,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::DeviceMem sorted_weight_buf(sorted_weight_host); ck_tile::DeviceMem sorted_weight_buf(sorted_weight_host);
ck_tile::DeviceMem sorted_expert_ids_buf(sorted_expert_ids_host); ck_tile::DeviceMem sorted_expert_ids_buf(sorted_expert_ids_host);
ck_tile::DeviceMem num_sorted_tiles_buf(num_sorted_tiles_host); ck_tile::DeviceMem num_sorted_tiles_buf(num_sorted_tiles_host);
o_buf.SetZero();
fused_moegemm_traits traits{prec_i, fused_moegemm_traits traits{prec_i,
prec_w, prec_w,
...@@ -363,6 +380,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -363,6 +380,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
pass &= ck_tile::check_err( pass &= ck_tile::check_err(
o_dev, o_host, std::string("OUT Error: Incorrect results!"), rtol, atol); o_dev, o_host, std::string("OUT Error: Incorrect results!"), rtol, atol);
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush; std::cout << ", valid:" << (pass ? "y" : "n") << std::flush;
output_matrix_2d(o_dev, tokens, hidden_size);
} }
std::cout << std::flush << std::endl; std::cout << std::flush << std::endl;
......
...@@ -70,15 +70,15 @@ struct FusedMoeGemmPipeline_General ...@@ -70,15 +70,15 @@ struct FusedMoeGemmPipeline_General
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{ {
// // matrix a or tokens smem // matrix a or tokens smem
// constexpr index_t smem_mat_a = constexpr index_t smem_mat_a =
// BlockShape::Block_M0 * BlockShape::Block_K0 * sizeof(ADataType); BlockShape::Block_M0 * BlockShape::Block_K0 * sizeof(ADataType);
// // shuffle C matrix // shuffle C matrix
// constexpr index_t smem_bridge = constexpr index_t smem_bridge =
// BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType); BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType);
// return max(smem_mat_a, smem_bridge); return max(smem_mat_a, smem_bridge);
return Policy::template GetSmemSize<Problem>(); //return Policy::template GetSmemSize<Problem>();
} }
// this is the thread-offset along row/col // this is the thread-offset along row/col
...@@ -105,35 +105,46 @@ struct FusedMoeGemmPipeline_General ...@@ -105,35 +105,46 @@ struct FusedMoeGemmPipeline_General
ignore = hidden_size; ignore = hidden_size;
ignore = intermediate_size; ignore = intermediate_size;
// CK_TILE_LDS_ADDR ADataType* smem_0 = reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem); CK_TILE_LDS_ADDR ADataType* smem_0 = reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem);
// auto a_lds_view = make_tensor_view<address_space_enum::lds>( auto a_lds_view = make_tensor_view<address_space_enum::lds>(
// smem_0, Policy::template MakeLdsStoreDesc_A<Problem>()); smem_0, Policy::template MakeLdsStoreDesc_A<Problem>());
// auto a_lds_win = make_tile_window(a_lds_view, make_tuple(number<BlockShape::Block_M0>{}, number<BlockShape::Block_K0>{}), {0, 0}); auto a_lds_win = make_tile_window(a_lds_view, make_tuple(number<BlockShape::Block_M0>{}, number<BlockShape::Block_K0>{}), {0, 0});
auto a_global_to_dram_window = make_tile_window( auto a_global_to_dram_window = make_tile_window(
a_window_.get_bottom_tensor_view(), a_window_.get_bottom_tensor_view(),
make_tuple(number<BlockShape::Block_M0>{}, number<BlockShape::Block_K0>{}), make_tuple(number<BlockShape::Block_M0>{}, number<BlockShape::Block_K0>{}),
a_window_.get_window_origin(), a_window_.get_window_origin(),
Policy::template MakeGlobalTileDistribution_A<Problem>()); Policy::template MakeGlobalTileDistribution_A<Problem>());
// auto o_win = make_tile_window_linear(
// o_window_, Policy::template MakeGlobalTileDistribution_O<Problem>());
auto a_dram_block = load_tile(a_global_to_dram_window); auto a_dram_block = load_tile(a_global_to_dram_window);
//store_tile(a_lds_win, a_dram_block); store_tile(a_lds_win, a_dram_block);
ignore = a_dram_block; store_tile(o_window_, a_dram_block);
#if 0 #if 0
//check a matrix gather right or not //check a matrix gather right or not
constexpr auto a_spans = decltype(a_dram)::get_distributed_spans(); constexpr auto a_spans = decltype(a_dram_block)::get_distributed_spans();
int counter = 0; int counter = 0;
sweep_tile_span(a_spans[number<0>{}], [&](auto idxm) { sweep_tile_span(a_spans[number<0>{}], [&](auto idxm) {
sweep_tile_span(a_spans[number<1>{}], [&](auto idxk){ sweep_tile_span(a_spans[number<1>{}], [&](auto idxk) {
constexpr auto i_j_idx = make_tuple(idxm, idxk); constexpr auto i_j_idx = make_tuple(idxm, idxk);
if(threadIdx.x == 65 && blockIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0){ if(threadIdx.x == 65 && blockIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0)
counter = counter + 1; {
index_t idm_0 = idxm.impl_.at(0); counter = counter + 1;
index_t idn_0 = idxk.impl_.at(0); index_t idm_0 = idxm.impl_.at(0);
printf("in A idm is %d , idn_ is %d , counter is %d, value is: %f \n", idm_0, idn_0, counter, ck_tile::type_convert<float>(a_dram(i_j_idx))); index_t idn_0 = idxk.impl_.at(0);
} printf("in A idm is %d , idn_ is %d , counter is %d, value is: %f \n",
}); idm_0,
idn_0,
counter,
ck_tile::type_convert<float>(a_dram_block(i_j_idx)));
}
}); });
});
#endif #endif
} }
}; };
......
...@@ -232,53 +232,6 @@ struct FusedMoeGemmPipelineGeneralPolicy ...@@ -232,53 +232,6 @@ struct FusedMoeGemmPipelineGeneralPolicy
} }
} }
#if 0
    // Disabled code path, kept for reference.
    //
    // Builds a static tile distribution for a weight tile of NPerBlock x KPerBlock
    // split across WavesPerBlock_N x WavesPerBlock_K waves.
    // Caution: this will require global memory pre-shuffled to follow the mfma layout
    //
    // Only FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten is handled; no
    // distribution is returned for any other PermuteEnum value.
    //
    // Alignment is used as the per-lane vector length along K (Kv) and must be a
    // multiple of WarpGemm::WarpGemmAttribute::Impl::kABKPerLane.
    template <index_t NPerBlock,
              index_t KPerBlock,
              index_t WavesPerBlock_N,
              index_t WavesPerBlock_K,
              typename WarpGemm,
              index_t Alignment,
              FusedMoeGemmWeightPermuteEnum PermuteEnum =
                  FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten>
    CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_MatrixCore_Swizzled()
    {
        static_assert(Alignment % WarpGemm::WarpGemmAttribute::Impl::kABKPerLane == 0);

        if constexpr(PermuteEnum == FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten)
        {
            constexpr index_t Kv = Alignment;
            constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
            constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;

            // Kr below divides KPerBlock by Kv * Kw, so require that division to be
            // exact. (The assert must use names declared in this scope.)
            static_assert(KPerBlock % (Kv * Kw) == 0);

            constexpr index_t Nr = NPerBlock / Nw;
            constexpr index_t Kr = KPerBlock / (Kv * Kw);

            // Per-block repeats are split into a wave part (_p) and a remaining part (_y).
            constexpr index_t Nr_p = WavesPerBlock_N;
            constexpr index_t Kr_p = WavesPerBlock_K;
            constexpr index_t Nr_y = Nr / Nr_p;
            constexpr index_t Kr_y = Kr / Kr_p;

            // clang-format off
            return make_static_tile_distribution(
                tile_distribution_encoding<
                    sequence<1>, // 0
                    // major       1                     2                     3
                    // minor       0     1               0     1               0   1   2
                    tuple<sequence<Nr_y, Nr_p>, sequence<Kr_y, Kr_p>, sequence<Kw, Nw, Kv>>,
                    //            Nr_p, Kr_p         Kw Nw
                    tuple<sequence<1, 2>, sequence<3, 3>>,
                    tuple<sequence<1, 1>, sequence<0, 1>>,
                    //       Nr_y Kr_y Kv
                    sequence<1, 2, 3>,
                    sequence<0, 0, 2>>{});
            // clang-format on
        }
    }
#endif
template <index_t WarpPerBlock_N_, template <index_t WarpPerBlock_N_,
index_t WarpPerBlock_K_, index_t WarpPerBlock_K_,
index_t Repeat_N_, index_t Repeat_N_,
...@@ -414,11 +367,11 @@ struct FusedMoeGemmPipelineGeneralPolicy ...@@ -414,11 +367,11 @@ struct FusedMoeGemmPipelineGeneralPolicy
constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor( constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
lds_block_desc_0, lds_block_desc_0,
make_tuple( make_tuple(
make_pass_through_transform(number<NumIssues>{}), // make_pass_through_transform(),
make_merge_transform(make_tuple(number<wavesPerM>{}, number<wavesPerK>{})), make_merge_transform(make_tuple(number<NumIssues>{}, number<wavesPerM>{})),
make_merge_transform(make_tuple(number<warpSize>{}, number<KVector>{}))), make_merge_transform(make_tuple(number<wavesPerK>{}, number<warpSize>{}, number<KVector>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}), make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); make_tuple(sequence<0>{}, sequence<1>{}));
return lds_block_desc_issues_warps_lanes; return lds_block_desc_issues_warps_lanes;
} }
...@@ -446,12 +399,13 @@ struct FusedMoeGemmPipelineGeneralPolicy ...@@ -446,12 +399,13 @@ struct FusedMoeGemmPipelineGeneralPolicy
constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor( constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
lds_block_desc_0, lds_block_desc_0,
make_tuple(make_pass_through_transform(number<NumIssues>{}), make_tuple(
make_pass_through_transform(number<NumWarps>{}), //make_pass_through_transform(number<NumIssues>{}),
make_merge_transform(make_tuple( //make_pass_through_transform(number<NumWarps>{}),
number<LaneGroups>{}, number<LanesPerK>{}, number<KVector>{}))), make_merge_transform(make_tuple(number<NumIssues>{},number<LaneGroups>{}, number<NumWarps>{})),
make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}), make_merge_transform(make_tuple(number<LanesPerK>{}, number<KVector>{}))),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); make_tuple(sequence<0, 1, 2>{}, sequence<3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return lds_block_desc_issues_warps_lanes; return lds_block_desc_issues_warps_lanes;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment