Commit fcc2c867 authored by coderfeli's avatar coderfeli
Browse files

impl gemm2 swizzle

parent aecd6a38
...@@ -236,7 +236,7 @@ int main(int argc, char* argv[]) ...@@ -236,7 +236,7 @@ int main(int argc, char* argv[])
// const ck::index_t experts = 8; // const ck::index_t experts = 8;
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({experts}, {1})); Tensor<ck::index_t> expert_ids(HostTensorDescriptor({sorted_tile_num}, {1}));
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1})); Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1}));
max_token_id.mData[0] = valid_size; max_token_id.mData[0] = valid_size;
......
...@@ -197,8 +197,11 @@ struct GridwiseMoeGemmScatter ...@@ -197,8 +197,11 @@ struct GridwiseMoeGemmScatter
__host__ static auto CalculateGridSize(index_t M, index_t N) __host__ static auto CalculateGridSize(index_t M, index_t N)
{ {
return std::make_tuple(math::integer_divide_ceil(N, NPerBlock), // return std::make_tuple(math::integer_divide_ceil(N, NPerBlock),
math::integer_divide_ceil(M, MPerBlock), // math::integer_divide_ceil(M, MPerBlock),
// 1);
return std::make_tuple(math::integer_divide_ceil(N, NPerBlock) * math::integer_divide_ceil(M, MPerBlock),
1,
1); 1);
} }
...@@ -1149,10 +1152,22 @@ struct GridwiseMoeGemmScatter ...@@ -1149,10 +1152,22 @@ struct GridwiseMoeGemmScatter
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, problem.MBlock, problem.NBlock); c_grid_desc_m_n, problem.MBlock, problem.NBlock);
const index_t block_n_id = __builtin_amdgcn_readfirstlane(blockIdx.x);
const index_t block_m_id = __builtin_amdgcn_readfirstlane(blockIdx.y);
const index_t expert_id = __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[block_m_id]);
const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]); const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
const index_t expert_block_id = blockIdx.x / problem.NBlock;
// const index_t b_block_id = blockIdx.x % problem.NBlock;
const index_t expert_id = __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
const index_t es = __builtin_amdgcn_readfirstlane(p_max_token_id[expert_block_id + 1]);
const index_t expert_swizzle = es > 0 ? es : 1; //p_max_token_id[expert_id + 1];
const index_t expert_block_swizzle = expert_block_id / expert_swizzle;
const index_t b_block_id_swizzle = blockIdx.x % (problem.NBlock * expert_swizzle);
const index_t block_n_id = __builtin_amdgcn_readfirstlane(b_block_id_swizzle % 8 + b_block_id_swizzle / (8 * expert_swizzle) * 8);
const index_t block_m_id = __builtin_amdgcn_readfirstlane(expert_block_swizzle * expert_swizzle + b_block_id_swizzle / 8 % expert_swizzle);
// const index_t block_n_id = __builtin_amdgcn_readfirstlane(blockIdx.x);
// const index_t block_m_id = __builtin_amdgcn_readfirstlane(blockIdx.y);
// const index_t expert_id = __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[block_m_id]);
// const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
const index_t token0 = __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff); const index_t token0 = __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
// constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment