"vscode:/vscode.git/clone" did not exist on "57a79819b052ed627ea2e8497224fdf7475bb027"
Commit 7bc31426 authored by carlushuang's avatar carlushuang
Browse files

fix mock token id

parent c35bb816
...@@ -85,6 +85,24 @@ void reference_fused_moe( ...@@ -85,6 +85,24 @@ void reference_fused_moe(
ck_tile::index_t intermediate_size_0 = intermediate_size; ck_tile::index_t intermediate_size_0 = intermediate_size;
ck_tile::index_t intermediate_size_1 = intermediate_size / (gate_only ? 1 : 2); ck_tile::index_t intermediate_size_1 = intermediate_size / (gate_only ? 1 : 2);
ck_tile::HostTensor<AccDataType> out_topk_tokens({tokens, topk, hidden_size});
int max_num_tokens_padded = topk * tokens + experts * block_m - topk;
// assert();
auto f = [&](auto i_flatten) {
ck_tile::index_t i_tile = i_flatten / block_m;
if(i_tile >= num_sorted_tiles)
return;
ck_tile::index_t i_expert = sorted_expert_ids_host.mData[i_tile];
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
ck_tile::index_t i_token = sorted_token_ids_host.mData[i_flatten];
ck_tile::index_t i_topk = i_token >> 24;
i_token &= 0xffffff;
if(i_token >= tokens)
return;
(void)token_ids_host;
#else
// TODO: better remove this in the future, or modify the token_id value // TODO: better remove this in the future, or modify the token_id value
auto get_topk_id = [&](ck_tile::index_t token_id_, ck_tile::index_t expert_id_) { auto get_topk_id = [&](ck_tile::index_t token_id_, ck_tile::index_t expert_id_) {
for(ck_tile::index_t i_ = 0; i_ < topk; i_++) for(ck_tile::index_t i_ = 0; i_ < topk; i_++)
...@@ -95,20 +113,11 @@ void reference_fused_moe( ...@@ -95,20 +113,11 @@ void reference_fused_moe(
throw std::runtime_error("not correct token/expert pair\n"); throw std::runtime_error("not correct token/expert pair\n");
return -1; // TODO: not correct!! return -1; // TODO: not correct!!
}; };
ck_tile::HostTensor<AccDataType> out_topk_tokens({tokens, topk, hidden_size});
int max_num_tokens_padded = topk * tokens + experts * block_m - topk;
// assert();
auto f = [&](auto i_flatten) {
ck_tile::index_t i_tile = i_flatten / block_m;
if(i_tile >= num_sorted_tiles)
return;
ck_tile::index_t i_expert = sorted_expert_ids_host.mData[i_tile];
ck_tile::index_t i_token = sorted_token_ids_host.mData[i_flatten]; ck_tile::index_t i_token = sorted_token_ids_host.mData[i_flatten];
if(i_token >= tokens) if(i_token >= tokens)
return; return;
ck_tile::index_t i_topk = get_topk_id(i_token, i_expert); // TODO: ugly ck_tile::index_t i_topk = get_topk_id(i_token, i_expert); // TODO: ugly
#endif
auto weight = sorted_weight_host.mData[i_flatten]; auto weight = sorted_weight_host.mData[i_flatten];
ck_tile::HostTensor<AccDataType> acc_0({1, intermediate_size_0}); ck_tile::HostTensor<AccDataType> acc_0({1, intermediate_size_0});
......
...@@ -299,6 +299,9 @@ struct FusedMoeGemmKernel ...@@ -299,6 +299,9 @@ struct FusedMoeGemmKernel
index_t token_id = index_t token_id =
reinterpret_cast<const index_t*>(kargs.sorted_token_ids_ptr)[sorted_token_id]; reinterpret_cast<const index_t*>(kargs.sorted_token_ids_ptr)[sorted_token_id];
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
token_id &= 0xffffff;
#endif
auto topk_weight = reinterpret_cast<const TopkWeightDataType*>( auto topk_weight = reinterpret_cast<const TopkWeightDataType*>(
kargs.sorted_weight_ptr)[sorted_token_id]; kargs.sorted_weight_ptr)[sorted_token_id];
......
...@@ -125,6 +125,9 @@ struct FusedMoeGemmPipeline_FlatmmUk ...@@ -125,6 +125,9 @@ struct FusedMoeGemmPipeline_FlatmmUk
array<index_t, n_size> row_ids; array<index_t, n_size> row_ids;
static_for<0, n_size, 1>{}([&](auto i) { static_for<0, n_size, 1>{}([&](auto i) {
row_ids.at(i) = sorted_token_ids_ptr[coords[i]]; // base_coord + i * MLans; row_ids.at(i) = sorted_token_ids_ptr[coords[i]]; // base_coord + i * MLans;
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
row_ids.at(i) &= 0xffffff;
#endif
}); });
return row_ids; return row_ids;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment