Commit 1078d229 authored by coderfeli
Browse files

add logics and debug

parent d4b8f1e3
......@@ -245,7 +245,7 @@ int main(int argc, char* argv[])
int tile_off = i % sorted_tile_size;
if(tile_off < token_per_tile)
{
sorted_token_ids.mData[i] = (tokenid % batch) & ((tokenid / batch) << 24);
sorted_token_ids.mData[i] = (tokenid % batch) | ((tokenid / batch) << 24);
tokenid++;
}
else
......@@ -389,17 +389,20 @@ int main(int argc, char* argv[])
for(int m = 0; m < SORTED_SIZE; ++m)
{
const int fuse_t = sorted_token_ids(m);
const int fuse_t = sorted_token_ids.mData[m];
const int t = fuse_t & 0xffffff;
const int topk_id = (fuse_t & 0xff000000) >> 24;
printf("m %d fuset %d %d %d\n",m, fuse_t, t, topk_id);
if (t >= tokens)
{
continue;
}
const int topk_id = (fuse_t & 0xff000000) >> 24;
const int e = expert_ids(m / sorted_tile_size);
for(int n = 0; n < N; ++n)
{
cde_element_op(e_t_n_host_result(t, topk_id, n), c_t_k_n(m, topk_id, n), d0_t_n(t, n), d1_e_n(e, n));
printf("m %d fuset %d %d %d %f %f\n",m, topk_id, t, n, e_t_n_host_result(t, topk_id, n), c_t_k_n(m, topk_id, n));
}
}
......
......@@ -535,6 +535,7 @@ struct GridwiseMoeGemmGather
struct Problem
{
__host__ __device__ Problem(index_t NumTokens_,
index_t TopK_,
index_t M_,
index_t N_,
index_t K_,
......
......@@ -74,6 +74,8 @@ struct ReferenceMoeGemm : public device::BaseOperator
AccDataType v_acc{0};
ComputeTypeA v_a{0};
ComputeTypeB v_b{0};
if(m >= max_sorted_num)
return;
const int t = arg.sorted_token_ids_(m) & 0xffffff;
const int topk_id = (arg.sorted_token_ids_(m) & 0xff000000) >> 24;
const int e = arg.expert_ids_(m / arg.sorted_tile_size_);
......@@ -105,17 +107,16 @@ struct ReferenceMoeGemm : public device::BaseOperator
v_acc +=
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
}
}
CDataType v_c{0};
arg.c_element_op_(v_c, v_acc);
arg.c_t_k_n_(t, topk_id, n) = v_c;
}
};
make_ParallelTensorFunctor(
f_mk_kn_mn, arg.sorted_tile_size_, arg.c_t_k_n_.mDesc.GetLengths()[2])(
f_mk_kn_mn, arg.sorted_token_ids_.GetLengths()[0], arg.c_t_k_n_.mDesc.GetLengths()[2])(
std::thread::hardware_concurrency());
return 0;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment