Commit 9af0e0d0 (unverified), authored by Yizhi Wang, committed by GitHub
Browse files

support topk10 in low latency kernel (#403)

parent e9d053c2
......@@ -347,7 +347,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
bool use_fp8, bool round_scale, bool use_ue8m0,
void* workspace, int num_device_sms,
cudaStream_t stream, int phases) {
-constexpr int kNumMaxTopK = 9;
+constexpr int kNumMaxTopK = 11;
const int num_warp_groups = ceil_div(num_experts, num_device_sms);
const int num_warps_per_group = 32 / num_warp_groups;
EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0);
......@@ -928,7 +928,7 @@ void combine(void* combined_x,
bool use_logfmt,
void* workspace, int num_device_sms,
cudaStream_t stream, int phases, bool zero_copy) {
-constexpr int kNumMaxTopk = 9;
+constexpr int kNumMaxTopk = 11;
const int num_warp_groups = ceil_div(num_experts, num_device_sms);
const int num_warps_per_group = 32 / num_warp_groups;
const int num_recv_per_sm = ceil_div(num_combined_tokens, num_device_sms);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment