Commit d4f34978 authored by Chenggang Zhao
Browse files

Improve layout kernel performance

parent 01f49071
...@@ -121,9 +121,9 @@ void get_dispatch_layout(const int64_t* topk_idx, ...@@ -121,9 +121,9 @@ void get_dispatch_layout(const int64_t* topk_idx,
int* num_tokens_per_expert, bool* is_token_in_rank, int* num_tokens_per_expert, bool* is_token_in_rank,
int num_tokens, int num_topk, int num_ranks, int num_experts, int num_tokens, int num_topk, int num_ranks, int num_experts,
cudaStream_t stream) { cudaStream_t stream) {
constexpr int kNumThreads = 256, kNumExpertsPerSM = 32, kNumRanksPerSM = 8; constexpr int kNumThreads = 256, kNumExpertsPerSM = 4, kNumRanksPerSM = 8;
int num_sms = ((num_experts + kNumExpertsPerSM - 1) / kNumExpertsPerSM) + (num_ranks + kNumRanksPerSM - 1) / kNumRanksPerSM; int num_sms = ((num_experts + kNumExpertsPerSM - 1) / kNumExpertsPerSM) + (num_ranks + kNumRanksPerSM - 1) / kNumRanksPerSM;
EP_STATIC_ASSERT(kNumExpertsPerSM % NUM_MAX_NVL_PEERS == 0, "Invalid number of experts per SM"); EP_STATIC_ASSERT(kNumRanksPerSM % NUM_MAX_NVL_PEERS == 0, "Invalid number of ranks per SM");
SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream); SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream);
LAUNCH_KERNEL(&cfg, (get_dispatch_layout<kNumThreads, kNumExpertsPerSM, kNumRanksPerSM>), LAUNCH_KERNEL(&cfg, (get_dispatch_layout<kNumThreads, kNumExpertsPerSM, kNumRanksPerSM>),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment