Unverified Commit b7fef496 authored by sky's avatar sky Committed by GitHub
Browse files

Fix: avoid floating point exception (#379)



* Fix: avoid floating point exception.
Signed-off-by: default avatarwangfakang <fakangwang@gmail.com>

* simplify the code.
Signed-off-by: default avatarwangfakang <fakangwang@gmail.com>

---------
Signed-off-by: default avatarwangfakang <fakangwang@gmail.com>
parent 1da73be0
...@@ -932,10 +932,11 @@ void combine(void* combined_x, ...@@ -932,10 +932,11 @@ void combine(void* combined_x,
const int num_warp_groups = ceil_div(num_experts, num_device_sms); const int num_warp_groups = ceil_div(num_experts, num_device_sms);
const int num_warps_per_group = 32 / num_warp_groups; const int num_warps_per_group = 32 / num_warp_groups;
const int num_recv_per_sm = ceil_div(num_combined_tokens, num_device_sms); const int num_recv_per_sm = ceil_div(num_combined_tokens, num_device_sms);
EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0 and ((num_combined_tokens == 0) or (num_recv_per_sm > 0))); EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0 and num_recv_per_sm >= 0);
const auto num_warps = num_warp_groups * num_warps_per_group; const auto num_warps = num_warp_groups * num_warps_per_group;
const auto num_sms = max(ceil_div(num_experts, num_warp_groups), ceil_div(num_combined_tokens, num_recv_per_sm)); const auto num_sms = max(ceil_div(num_experts, num_warp_groups),
num_recv_per_sm == 0 ? 1 : ceil_div(num_combined_tokens, num_recv_per_sm));
// Check workspace // Check workspace
auto atomic_clean_flag = static_cast<int*>(workspace); auto atomic_clean_flag = static_cast<int*>(workspace);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment