Commit 96047cab authored by coderfeli's avatar coderfeli
Browse files

impl e swizzel

parent 56cc306d
...@@ -191,11 +191,11 @@ int main(int argc, char* argv[]) ...@@ -191,11 +191,11 @@ int main(int argc, char* argv[])
// experts = 8 // experts = 8
// per expert: // per expert:
// GEMM shape // GEMM shape
ck::index_t N = 6144; ck::index_t N = 14336 * 2;
ck::index_t K = 8192; ck::index_t K = 4096;
ck::index_t experts = 8; ck::index_t experts = 8;
ck::index_t sorted_tile_num = 8; ck::index_t sorted_tile_num = 16;
ck::index_t valid_tile_num = 8; ck::index_t valid_tile_num = 13;
ck::index_t sorted_size = sorted_tile_num * MPerBlock; ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock; ck::index_t valid_size = valid_tile_num * MPerBlock;
ck::index_t tokens = 64; ck::index_t tokens = 64;
...@@ -243,10 +243,11 @@ int main(int argc, char* argv[]) ...@@ -243,10 +243,11 @@ int main(int argc, char* argv[])
// const ck::index_t experts = 8; // const ck::index_t experts = 8;
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({sorted_tile_num}, {1})); Tensor<ck::index_t> expert_ids(HostTensorDescriptor({sorted_tile_num}, {1}));
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1})); Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1 + sorted_tile_num}));
max_token_id.mData[0] = valid_size; max_token_id.mData = {valid_size, 2, 2, 1, 1, 2, 2, 2,2, 2, 2, 1, 2,2,0,0,0};
int eids[] = {0, 0,1, 2,3, 3, 4,4, 5, 5, 6, 7, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
for (int i = 0; i < sorted_tile_num; i++) { for (int i = 0; i < sorted_tile_num; i++) {
expert_ids.mData[i] = i; expert_ids.mData[i] = eids[i];
} }
int token_per_tile = tokens * topk / valid_tile_num; int token_per_tile = tokens * topk / valid_tile_num;
int tokenid = 0; int tokenid = 0;
......
...@@ -186,20 +186,27 @@ int main(int argc, char* argv[]) ...@@ -186,20 +186,27 @@ int main(int argc, char* argv[])
// experts = 8 // experts = 8
// per expert: // per expert:
// GEMM shape // GEMM shape
ck::index_t N = 6144; ck::index_t N = 4096;
ck::index_t K = 8192; ck::index_t K = 14336;
ck::index_t experts = 8; ck::index_t experts = 8;
ck::index_t sorted_tile_num = 10; ck::index_t sorted_tile_num = 16;
ck::index_t valid_tile_num = 8; ck::index_t valid_tile_num = 13;
ck::index_t sorted_size = sorted_tile_num * MPerBlock; ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock; ck::index_t valid_size = valid_tile_num * MPerBlock;
ck::index_t tokens = 64; ck::index_t tokens = 512;
ck::index_t topk = 2; ck::index_t topk = 2;
if(argc == 1) if(argc == 1)
{ {
// use default case // use default case
} }
else if(argc == 3)
{
// use default case
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 7) else if(argc == 7)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
...@@ -233,8 +240,9 @@ int main(int argc, char* argv[]) ...@@ -233,8 +240,9 @@ int main(int argc, char* argv[])
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1})); Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1}));
max_token_id.mData[0] = valid_size; max_token_id.mData[0] = valid_size;
int eids[] = {0, 0,1, 2,3, 3, 4,4, 5, 5, 6, 7, 7, 3, 3, 3};
for (int i = 0; i < sorted_tile_num; i++) { for (int i = 0; i < sorted_tile_num; i++) {
expert_ids.mData[i] = i; expert_ids.mData[i] = eids[i];
} }
if (tokens * topk > valid_size) if (tokens * topk > valid_size)
{ {
......
...@@ -346,27 +346,27 @@ struct DeviceMoeGemm ...@@ -346,27 +346,27 @@ struct DeviceMoeGemm
// } // }
// else // else
{ {
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) // if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{ // {
if constexpr (IsGatherGemm) { // if constexpr (IsGatherGemm) {
const auto kernel = kernel_moe_gemm_gather< // const auto kernel = kernel_moe_gemm_gather<
GridwiseGemm, // GridwiseGemm,
true, // true,
InMemoryDataOperationEnum::Set, // InMemoryDataOperationEnum::Set,
minimum_occupancy, // minimum_occupancy,
TailNumber::Odd>; // TailNumber::Odd>;
RunKernel(kernel); // RunKernel(kernel);
} else { // } else {
const auto kernel = kernel_moe_gemm_scatter< // const auto kernel = kernel_moe_gemm_scatter<
GridwiseGemm, // GridwiseGemm,
true, // true,
InMemoryDataOperationEnum::AtomicAdd, // InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy, // minimum_occupancy,
TailNumber::Odd>; // TailNumber::Odd>;
RunKernel(kernel); // RunKernel(kernel);
} // }
} // }
else // else
{ {
if constexpr (IsGatherGemm) { if constexpr (IsGatherGemm) {
const auto kernel = kernel_moe_gemm_gather< const auto kernel = kernel_moe_gemm_gather<
......
...@@ -197,8 +197,8 @@ struct GridwiseMoeGemmGather ...@@ -197,8 +197,8 @@ struct GridwiseMoeGemmGather
__host__ static auto CalculateGridSize(index_t M, index_t N) __host__ static auto CalculateGridSize(index_t M, index_t N)
{ {
return std::make_tuple(math::integer_divide_ceil(N, NPerBlock), return std::make_tuple(math::integer_divide_ceil(N, NPerBlock) * math::integer_divide_ceil(M, MPerBlock),
math::integer_divide_ceil(M, MPerBlock), 1,
1); 1);
} }
...@@ -1140,7 +1140,6 @@ struct GridwiseMoeGemmGather ...@@ -1140,7 +1140,6 @@ struct GridwiseMoeGemmGather
ignore = b_element_op; ignore = b_element_op;
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.NumTokens, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); problem.NumTokens, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
const auto b_grid_desc_bpreshuffled = const auto b_grid_desc_bpreshuffled =
MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled); MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>( const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
...@@ -1150,11 +1149,21 @@ struct GridwiseMoeGemmGather ...@@ -1150,11 +1149,21 @@ struct GridwiseMoeGemmGather
const auto c_grid_desc_mblock_mperblock_nblock_nperblock = const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, problem.MBlock, problem.NBlock); c_grid_desc_m_n, problem.MBlock, problem.NBlock);
const index_t block_n_id = __builtin_amdgcn_readfirstlane(blockIdx.x);
const index_t block_m_id = __builtin_amdgcn_readfirstlane(blockIdx.y);
const index_t expert_id = __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[block_m_id]);
const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]); const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
// constexpr int expert_tile_cnt[8] = {2, 1, 1, 2, 2, 2, 1, 2};
const index_t expert_block_id = blockIdx.x / problem.NBlock;
// const index_t b_block_id = blockIdx.x % problem.NBlock;
const index_t expert_id = __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
const index_t es = __builtin_amdgcn_readfirstlane(p_max_token_id[expert_block_id + 1]);
const index_t expert_swizzle = es > 0 ? es : 1; //p_max_token_id[expert_id + 1];
const index_t expert_block_swizzle = expert_block_id / expert_swizzle;
const index_t b_block_id_swizzle = blockIdx.x % (problem.NBlock * expert_swizzle);
const index_t block_n_id = __builtin_amdgcn_readfirstlane(b_block_id_swizzle % 8 + b_block_id_swizzle / (8 * expert_swizzle) * 8);
const index_t block_m_id = __builtin_amdgcn_readfirstlane(expert_block_swizzle * expert_swizzle + b_block_id_swizzle / 8 % expert_swizzle);
if (threadIdx.x==0) {
printf("bid %d, eid %d, es %d, esi %d, bsi %d, m %d, n %d\n", blockIdx.x, expert_id, expert_swizzle, expert_block_swizzle, b_block_id_swizzle, block_m_id, block_n_id);
}
const index_t token0 = __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff); const index_t token0 = __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
// constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
......
...@@ -573,9 +573,11 @@ struct MoeSortingKernel ...@@ -573,9 +573,11 @@ struct MoeSortingKernel
{ {
int e_start = cumsum[tid]; int e_start = cumsum[tid];
int e_end = cumsum[tid + 1]; int e_end = cumsum[tid + 1];
int e_size = unit_size_mdiv.div(e_end - e_start + unit_size_mdiv.divisor - 1);
for(int i = e_start; i < e_end; i += unit_size_mdiv.divisor) for(int i = e_start; i < e_end; i += unit_size_mdiv.divisor)
{ {
p_sorted_expert_ids[unit_size_mdiv.div(i)] = tid; p_sorted_expert_ids[unit_size_mdiv.div(i)] = tid;
p_sorted_expert_cnts[]
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment