Unverified commit 1da73be0, authored by Zhiyi Hu, committed by GitHub
Browse files

fix combine timeout due to delayed forwarder min head update (#353)



* fix combine timeout due to forwarder min head update

* Update head before and after combine_token; add assertion for nvl_buffer_size_per_rdma_rank

---------
Co-authored-by: zhiyi Hu <zhiyihu@U-NYQQMGK0-2250.local>
parent ab484794
...@@ -1647,8 +1647,10 @@ combine(int4* combined_x, float* combined_topk_weights, ...@@ -1647,8 +1647,10 @@ combine(int4* combined_x, float* combined_topk_weights,
// Read expected head // Read expected head
EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA peers"); EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA peers");
int expected_head = -1; int expected_head = -1;
if (lane_id < NUM_MAX_NVL_PEERS) if (lane_id < NUM_MAX_NVL_PEERS) {
expected_head = ld_nc_global(combined_nvl_head + token_idx * NUM_MAX_NVL_PEERS + lane_id); expected_head = ld_nc_global(combined_nvl_head + token_idx * NUM_MAX_NVL_PEERS + lane_id);
expected_head < 0 ? (forwarder_nvl_head[warp_id][lane_id] = -expected_head - 1) : (forwarder_nvl_head[warp_id][lane_id] = expected_head);
}
// Wait lanes to be ready // Wait lanes to be ready
start_time = clock64(); start_time = clock64();
...@@ -1851,6 +1853,7 @@ void combine(cudaDataType_t type, ...@@ -1851,6 +1853,7 @@ void combine(cudaDataType_t type,
EP_HOST_ASSERT(num_forwarder_warps > NUM_MAX_NVL_PEERS and num_forwarder_warps % num_rdma_ranks == 0); EP_HOST_ASSERT(num_forwarder_warps > NUM_MAX_NVL_PEERS and num_forwarder_warps % num_rdma_ranks == 0);
EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens % num_rdma_ranks == 0); EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens % num_rdma_ranks == 0);
EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens / num_rdma_ranks > std::max(num_max_rdma_chunked_send_tokens, num_max_nvl_chunked_send_tokens)); EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens / num_rdma_ranks > std::max(num_max_rdma_chunked_send_tokens, num_max_nvl_chunked_send_tokens));
EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens / num_rdma_ranks - num_warps_per_forwarder >= num_max_nvl_chunked_send_tokens);
EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens >= num_warps_per_forwarder); EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens >= num_warps_per_forwarder);
EP_HOST_ASSERT(type == CUDA_R_16BF); EP_HOST_ASSERT(type == CUDA_R_16BF);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment