Commit e15c6f2d authored by coderfeli
Browse files

skip out of bound rowid

parent 20ac5ef9
......@@ -346,10 +346,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
else
{
// for(int i = 0; i < static_cast<int>(topk_ids_host.mData.size()); i++) {
// topk_ids_host.mData[i] = 0;
// }
topid_unique_gen<IndexDataType>(topk_ids_host.mData, tokens, topk, experts, 11913);
for(int i = 0; i < static_cast<int>(topk_ids_host.mData.size()); i++) {
topk_ids_host.mData[i] = i % 4;
}
// topid_unique_gen<IndexDataType>(topk_ids_host.mData, tokens, topk, experts, 11913);
}
// leave it here for future debug purpose
......
......@@ -240,7 +240,7 @@ struct FusedMoeGemmKernel
{
if constexpr(UseUK)
{
__shared__ CK_TILE_LDS_ADDR ADataType smem[GetSmemSize()];
__shared__ CK_TILE_LDS_ADDR char smem[GetSmemSize()];
IndexDataType num_sorted_tiles = __builtin_amdgcn_readfirstlane(
*reinterpret_cast<const IndexDataType*>(kargs.num_sorted_tiles_ptr));
......@@ -275,7 +275,7 @@ struct FusedMoeGemmKernel
index_t expert_stride_0 = kargs.intermediate_size * hidden_radio_0 * kargs.hidden_size;
index_t expert_stride_1 = kargs.intermediate_size * kargs.hidden_size;
__shared__ CK_TILE_LDS_ADDR ADataType smem[GetSmemSize()];
__shared__ CK_TILE_LDS_ADDR char smem[GetSmemSize()];
// note this is in unit of tile, need multiple tile size to get the index
const auto [sorted_tile_id, intermediate_tile_id] =
......
......@@ -70,13 +70,21 @@ struct FusedMoeGemmPipeline_FlatmmUk
// Compile-time shared-memory size reported by this pipeline.
// NOTE(review): this text is a diff hunk rendered without +/- markers, so removed and added
// lines appear interleaved. Read literally, the first `return 32768;` is the one that takes
// effect and every statement after it is unreachable.
// NOTE(review): the unit is presumably bytes, since smem_bridge below is scaled by
// sizeof(YDataType) -- confirm against the callers that size their buffers with this value.
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
// Fixed size; as written this short-circuits the computed value further down.
return 32768;
// constexpr index_t smem_0 = Policy::template GetUK_0<Problem>().GetSmemSize();
// constexpr index_t smem_1 = Policy::template GetUK_1<Problem>().GetSmemSize();
// constexpr index_t smem_bridge =
// BlockShape::Block_M0 * BlockShape::Block_N0;
// NOTE(review): looks like the pre-change line of the diff -- confirm it is not in the real file.
return 32768;//max(smem_0, max(smem_1, smem_bridge));
// return 32768;//max(smem_0, max(smem_1, smem_bridge));
// Sizes requested by the two micro-kernels obtained from Policy (GetUK_0 / GetUK_1).
constexpr index_t smem_0 = Policy::template GetUK_0<Problem>().GetSmemSize();
constexpr index_t smem_1 = Policy::template GetUK_1<Problem>().GetSmemSize();
// Size of a Block_M0 x Block_N0 tile of YDataType elements.
constexpr index_t smem_bridge =
BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType);
// return max(smem_0, max(smem_1, smem_bridge));
// Larger of (both micro-kernel sizes added together) and the bridge tile size.
return max(smem_0 + smem_1, smem_bridge);
}
// this is the thread-offset along row/col
CK_TILE_HOST_DEVICE static auto GetACoord()
{
......@@ -199,7 +207,10 @@ struct FusedMoeGemmPipeline_FlatmmUk
threadIdx.x % (BlockShape::Block_K0 / kAlignmentA) * kAlignmentA;
},
number<row_ids_a.size()>{});
// if(threadIdx.x==0)
// printf("row id %d\n", row_ids_a[0]);
if (row_ids_a.at(0) >= kargs.num_tokens)
return;
auto a_res =
make_wave_buffer_resource(reinterpret_cast<const ADataType*>(kargs.a_ptr),
kargs.num_tokens * kargs.stride_token * sizeof(ADataType));
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment