Commit d4f34978 authored by Chenggang Zhao's avatar Chenggang Zhao
Browse files

Improve layout kernel performance

parent 01f49071
......@@ -121,9 +121,9 @@ void get_dispatch_layout(const int64_t* topk_idx,
int* num_tokens_per_expert, bool* is_token_in_rank,
int num_tokens, int num_topk, int num_ranks, int num_experts,
cudaStream_t stream) {
constexpr int kNumThreads = 256, kNumExpertsPerSM = 32, kNumRanksPerSM = 8;
constexpr int kNumThreads = 256, kNumExpertsPerSM = 4, kNumRanksPerSM = 8;
int num_sms = ((num_experts + kNumExpertsPerSM - 1) / kNumExpertsPerSM) + (num_ranks + kNumRanksPerSM - 1) / kNumRanksPerSM;
EP_STATIC_ASSERT(kNumExpertsPerSM % NUM_MAX_NVL_PEERS == 0, "Invalid number of experts per SM");
EP_STATIC_ASSERT(kNumRanksPerSM % NUM_MAX_NVL_PEERS == 0, "Invalid number of ranks per SM");
SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream);
LAUNCH_KERNEL(&cfg, (get_dispatch_layout<kNumThreads, kNumExpertsPerSM, kNumRanksPerSM>),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment