Commit 1078d229 authored by coderfeli's avatar coderfeli
Browse files

add logics and debug

parent d4b8f1e3
...@@ -245,7 +245,7 @@ int main(int argc, char* argv[]) ...@@ -245,7 +245,7 @@ int main(int argc, char* argv[])
int tile_off = i % sorted_tile_size; int tile_off = i % sorted_tile_size;
if(tile_off < token_per_tile) if(tile_off < token_per_tile)
{ {
sorted_token_ids.mData[i] = (tokenid % batch) & ((tokenid / batch) << 24); sorted_token_ids.mData[i] = (tokenid % batch) | ((tokenid / batch) << 24);
tokenid++; tokenid++;
} }
else else
...@@ -389,17 +389,20 @@ int main(int argc, char* argv[]) ...@@ -389,17 +389,20 @@ int main(int argc, char* argv[])
for(int m = 0; m < SORTED_SIZE; ++m) for(int m = 0; m < SORTED_SIZE; ++m)
{ {
const int fuse_t = sorted_token_ids(m); const int fuse_t = sorted_token_ids.mData[m];
const int t = fuse_t & 0xffffff; const int t = fuse_t & 0xffffff;
const int topk_id = (fuse_t & 0xff000000) >> 24;
printf("m %d fuset %d %d %d\n",m, fuse_t, t, topk_id);
if (t >= tokens) if (t >= tokens)
{ {
continue; continue;
} }
const int topk_id = (fuse_t & 0xff000000) >> 24;
const int e = expert_ids(m / sorted_tile_size); const int e = expert_ids(m / sorted_tile_size);
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
cde_element_op(e_t_n_host_result(t, topk_id, n), c_t_k_n(m, topk_id, n), d0_t_n(t, n), d1_e_n(e, n)); cde_element_op(e_t_n_host_result(t, topk_id, n), c_t_k_n(m, topk_id, n), d0_t_n(t, n), d1_e_n(e, n));
printf("m %d fuset %d %d %d %f %f\n",m, topk_id, t, n, e_t_n_host_result(t, topk_id, n), c_t_k_n(m, topk_id, n));
} }
} }
......
...@@ -535,6 +535,7 @@ struct GridwiseMoeGemmGather ...@@ -535,6 +535,7 @@ struct GridwiseMoeGemmGather
struct Problem struct Problem
{ {
__host__ __device__ Problem(index_t NumTokens_, __host__ __device__ Problem(index_t NumTokens_,
index_t TopK_,
index_t M_, index_t M_,
index_t N_, index_t N_,
index_t K_, index_t K_,
......
...@@ -74,6 +74,8 @@ struct ReferenceMoeGemm : public device::BaseOperator ...@@ -74,6 +74,8 @@ struct ReferenceMoeGemm : public device::BaseOperator
AccDataType v_acc{0}; AccDataType v_acc{0};
ComputeTypeA v_a{0}; ComputeTypeA v_a{0};
ComputeTypeB v_b{0}; ComputeTypeB v_b{0};
if(m >= max_sorted_num)
return;
const int t = arg.sorted_token_ids_(m) & 0xffffff; const int t = arg.sorted_token_ids_(m) & 0xffffff;
const int topk_id = (arg.sorted_token_ids_(m) & 0xff000000) >> 24; const int topk_id = (arg.sorted_token_ids_(m) & 0xff000000) >> 24;
const int e = arg.expert_ids_(m / arg.sorted_tile_size_); const int e = arg.expert_ids_(m / arg.sorted_tile_size_);
...@@ -105,17 +107,16 @@ struct ReferenceMoeGemm : public device::BaseOperator ...@@ -105,17 +107,16 @@ struct ReferenceMoeGemm : public device::BaseOperator
v_acc += v_acc +=
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b); ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
} }
} CDataType v_c{0};
CDataType v_c{0};
arg.c_element_op_(v_c, v_acc); arg.c_element_op_(v_c, v_acc);
arg.c_t_k_n_(t, topk_id, n) = v_c; arg.c_t_k_n_(t, topk_id, n) = v_c;
}
}; };
make_ParallelTensorFunctor( make_ParallelTensorFunctor(
f_mk_kn_mn, arg.sorted_tile_size_, arg.c_t_k_n_.mDesc.GetLengths()[2])( f_mk_kn_mn, arg.sorted_token_ids_.GetLengths()[0], arg.c_t_k_n_.mDesc.GetLengths()[2])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment