Commit e15c6f2d authored by coderfeli
Browse files

Skip out-of-bounds row IDs

parent 20ac5ef9
...@@ -346,10 +346,10 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -346,10 +346,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
} }
else else
{ {
// for(int i = 0; i < static_cast<int>(topk_ids_host.mData.size()); i++) { for(int i = 0; i < static_cast<int>(topk_ids_host.mData.size()); i++) {
// topk_ids_host.mData[i] = 0; topk_ids_host.mData[i] = i % 4;
// } }
topid_unique_gen<IndexDataType>(topk_ids_host.mData, tokens, topk, experts, 11913); // topid_unique_gen<IndexDataType>(topk_ids_host.mData, tokens, topk, experts, 11913);
} }
// leave it here for future debug purpose // leave it here for future debug purpose
......
...@@ -240,7 +240,7 @@ struct FusedMoeGemmKernel ...@@ -240,7 +240,7 @@ struct FusedMoeGemmKernel
{ {
if constexpr(UseUK) if constexpr(UseUK)
{ {
__shared__ CK_TILE_LDS_ADDR ADataType smem[GetSmemSize()]; __shared__ CK_TILE_LDS_ADDR char smem[GetSmemSize()];
IndexDataType num_sorted_tiles = __builtin_amdgcn_readfirstlane( IndexDataType num_sorted_tiles = __builtin_amdgcn_readfirstlane(
*reinterpret_cast<const IndexDataType*>(kargs.num_sorted_tiles_ptr)); *reinterpret_cast<const IndexDataType*>(kargs.num_sorted_tiles_ptr));
...@@ -275,7 +275,7 @@ struct FusedMoeGemmKernel ...@@ -275,7 +275,7 @@ struct FusedMoeGemmKernel
index_t expert_stride_0 = kargs.intermediate_size * hidden_radio_0 * kargs.hidden_size; index_t expert_stride_0 = kargs.intermediate_size * hidden_radio_0 * kargs.hidden_size;
index_t expert_stride_1 = kargs.intermediate_size * kargs.hidden_size; index_t expert_stride_1 = kargs.intermediate_size * kargs.hidden_size;
__shared__ CK_TILE_LDS_ADDR ADataType smem[GetSmemSize()]; __shared__ CK_TILE_LDS_ADDR char smem[GetSmemSize()];
// note this is in unit of tile, need multiple tile size to get the index // note this is in unit of tile, need multiple tile size to get the index
const auto [sorted_tile_id, intermediate_tile_id] = const auto [sorted_tile_id, intermediate_tile_id] =
......
...@@ -70,13 +70,21 @@ struct FusedMoeGemmPipeline_FlatmmUk ...@@ -70,13 +70,21 @@ struct FusedMoeGemmPipeline_FlatmmUk
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{ {
return 32768;
// constexpr index_t smem_0 = Policy::template GetUK_0<Problem>().GetSmemSize(); // constexpr index_t smem_0 = Policy::template GetUK_0<Problem>().GetSmemSize();
// constexpr index_t smem_1 = Policy::template GetUK_1<Problem>().GetSmemSize(); // constexpr index_t smem_1 = Policy::template GetUK_1<Problem>().GetSmemSize();
// constexpr index_t smem_bridge = // constexpr index_t smem_bridge =
// BlockShape::Block_M0 * BlockShape::Block_N0; // BlockShape::Block_M0 * BlockShape::Block_N0;
return 32768;//max(smem_0, max(smem_1, smem_bridge)); // return 32768;//max(smem_0, max(smem_1, smem_bridge));
constexpr index_t smem_0 = Policy::template GetUK_0<Problem>().GetSmemSize();
constexpr index_t smem_1 = Policy::template GetUK_1<Problem>().GetSmemSize();
constexpr index_t smem_bridge =
BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType);
// return max(smem_0, max(smem_1, smem_bridge));
return max(smem_0 + smem_1, smem_bridge);
} }
// this is the thread-offset along row/col // this is the thread-offset along row/col
CK_TILE_HOST_DEVICE static auto GetACoord() CK_TILE_HOST_DEVICE static auto GetACoord()
{ {
...@@ -199,7 +207,10 @@ struct FusedMoeGemmPipeline_FlatmmUk ...@@ -199,7 +207,10 @@ struct FusedMoeGemmPipeline_FlatmmUk
threadIdx.x % (BlockShape::Block_K0 / kAlignmentA) * kAlignmentA; threadIdx.x % (BlockShape::Block_K0 / kAlignmentA) * kAlignmentA;
}, },
number<row_ids_a.size()>{}); number<row_ids_a.size()>{});
// if(threadIdx.x==0)
// printf("row id %d\n", row_ids_a[0]);
if (row_ids_a.at(0) >= kargs.num_tokens)
return;
auto a_res = auto a_res =
make_wave_buffer_resource(reinterpret_cast<const ADataType*>(kargs.a_ptr), make_wave_buffer_resource(reinterpret_cast<const ADataType*>(kargs.a_ptr),
kargs.num_tokens * kargs.stride_token * sizeof(ADataType)); kargs.num_tokens * kargs.stride_token * sizeof(ADataType));
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment