Commit 8c44b1fa authored by lishen
Browse files

fix:解决高吞吐压测出现的精度bug

parent 23e211e6
...@@ -1848,9 +1848,6 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_ ...@@ -1848,9 +1848,6 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
if (tidx < token_end_idx && lane_id < kNumRDMARanks) { if (tidx < token_end_idx && lane_id < kNumRDMARanks) {
int expected_head = ld_nc_global(combined_rdma_head + tidx * kNumRDMARanks + lane_id); int expected_head = ld_nc_global(combined_rdma_head + tidx * kNumRDMARanks + lane_id);
cached_expected_head[u] = expected_head; cached_expected_head[u] = expected_head;
// 更新 rdma_receiver_rdma_head(coordinator 需要)
(expected_head < 0) ? (rdma_receiver_rdma_head[target_rank][lane_id] = -expected_head - 1)
: (rdma_receiver_rdma_head[target_rank][lane_id] = expected_head);
if (expected_head > max_expected_head) max_expected_head = expected_head; if (expected_head > max_expected_head) max_expected_head = expected_head;
} }
} }
...@@ -1890,6 +1887,11 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_ ...@@ -1890,6 +1887,11 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
combined_topk_weights + tidx * num_topk, combined_topk_weights + tidx * num_topk,
num_max_rdma_chunked_recv_tokens, num_max_rdma_chunked_recv_tokens,
get_addr_fn, recv_tw_fn); get_addr_fn, recv_tw_fn);
if (lane_id < kNumRDMARanks) {
rdma_receiver_rdma_head[target_rank][lane_id] =
expected_head < 0 ? -expected_head - 1 : expected_head;
}
} }
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment