Commit 45d1c52e authored by coderfeli's avatar coderfeli
Browse files

hotfix moegemm2 nswizzle

parent bca3f14c
......@@ -169,7 +169,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
CShuffleMXDLPerWave, 1, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, A0DataType>;
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, true, false, A0DataType>;
// kernel 2: 128->32x128x128
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>;
......
......@@ -1158,13 +1158,22 @@ struct GridwiseMoeGemm
const index_t expert_block_id = blockIdx.x / problem.NBlock;
// const index_t b_block_id = blockIdx.x % problem.NBlock;
const index_t expert_id = __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
const auto block_mn = [&]() -> std::pair<int, int> {
if constexpr (NSwizzle)
{
const index_t es = __builtin_amdgcn_readfirstlane(p_max_token_id[expert_block_id + 1]);
const index_t expert_swizzle = es > 0 ? es : 1; //p_max_token_id[expert_id + 1];
const index_t expert_block_swizzle = expert_block_id / expert_swizzle;
const index_t b_block_id_swizzle = blockIdx.x % (problem.NBlock * expert_swizzle);
const index_t block_n_id = __builtin_amdgcn_readfirstlane(b_block_id_swizzle % 8 + b_block_id_swizzle / (8 * expert_swizzle) * 8);
const index_t block_m_id = __builtin_amdgcn_readfirstlane(expert_block_swizzle * expert_swizzle + b_block_id_swizzle / 8 % expert_swizzle);
const index_t nid = __builtin_amdgcn_readfirstlane(b_block_id_swizzle % 8 + b_block_id_swizzle / (8 * expert_swizzle) * 8);
const index_t mid = __builtin_amdgcn_readfirstlane(expert_block_swizzle * expert_swizzle + b_block_id_swizzle / 8 % expert_swizzle);
return {nid, mid};
} else {
return {blockIdx.x, blockIdx.y};
}
}();
const index_t block_n_id = block_mn.first;
const index_t block_m_id = block_mn.second;
// if (threadIdx.x==0) {
// printf("bid %d, eid %d, es %d, esi %d, bsi %d, m %d, n %d\n", blockIdx.x, expert_id, expert_swizzle, expert_block_swizzle, b_block_id_swizzle, block_m_id, block_n_id);
// }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment