Commit 1a24c8b6 authored by lishen
Browse files

normal-combine深度优化

parent ab0afb04
......@@ -1357,20 +1357,14 @@ __device__ int combine_token(bool is_token_in_rank, int head_idx,
#pragma unroll
for (int i = lane_id; i < hidden_int4; i += kWarpSize) {
// Read buffers
// TODO: maybe too many registers here
int4 recv_value_int4[kMaxNumRanks];
float values[kDtypePerInt4] = {0}; // 8 × 4B = 32B
#pragma unroll
for (int j = 0; j < num_topk_ranks; ++ j)
recv_value_int4[j] = ld_nc_global(get_addr_fn(topk_ranks[j], slot_indices[j], i));
// Reduce all-to-all results
float values[kDtypePerInt4] = {0};
#pragma unroll
for (int j = 0; j < num_topk_ranks; ++ j) {
auto recv_value_dtypes = reinterpret_cast<const dtype_t*>(&recv_value_int4[j]);
for (int j = 0; j < num_topk_ranks; ++j) {
int4 recv_value = ld_nc_global(get_addr_fn(topk_ranks[j], slot_indices[j], i));
auto recv_dtypes = reinterpret_cast<const dtype_t*>(&recv_value);
#pragma unroll
for (int k = 0; k < kDtypePerInt4; ++ k)
values[k] += static_cast<float>(recv_value_dtypes[k]);
for (int k = 0; k < kDtypePerInt4; ++k)
values[k] += static_cast<float>(recv_dtypes[k]);
}
// Cast back to `dtype_t` and write
......@@ -1835,47 +1829,74 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
int token_start_idx, token_end_idx;
get_channel_task_range(num_combined_tokens, num_channels, channel_id, token_start_idx, token_end_idx);
// Iterate over all tokens and combine
// ==================== Token-level unroll x4 ====================
constexpr int kTokenUnroll = 4;
int cached_channel_tail_idx = 0;
for(int64_t token_idx = token_start_idx + target_rank; token_idx < token_end_idx; token_idx += kNumRDMAReceivers) {
// Read expected head
EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "Invalid number of RDMA peers");
int expected_head = -1;
if(lane_id < kNumRDMARanks) {
expected_head = ld_nc_global(combined_rdma_head + token_idx * kNumRDMARanks + lane_id);
(expected_head < 0) ? (rdma_receiver_rdma_head[target_rank][lane_id] = -expected_head - 1)
: (rdma_receiver_rdma_head[target_rank][lane_id] = expected_head);
}
// Wait lanes to be ready
auto start_time = wall_clock64();
while (cached_channel_tail_idx <= expected_head) {
cached_channel_tail_idx = static_cast<int>(ld_acquire_sys_global(rdma_channel_tail.buffer(lane_id)));
for (int64_t base = token_start_idx + target_rank;
base < token_end_idx;
base += (int64_t)kNumRDMAReceivers * kTokenUnroll) {
// ---- Phase 1: batch-prefetch expected_head for every token in the group ----
int cached_expected_head[kTokenUnroll];
int max_expected_head = -1;
#pragma unroll
for (int u = 0; u < kTokenUnroll; ++u) {
int64_t tidx = base + (int64_t)u * kNumRDMAReceivers;
cached_expected_head[u] = -1;
if (tidx < token_end_idx && lane_id < kNumRDMARanks) {
int expected_head = ld_nc_global(combined_rdma_head + tidx * kNumRDMARanks + lane_id);
cached_expected_head[u] = expected_head;
// Update rdma_receiver_rdma_head (needed by the coordinator)
(expected_head < 0) ? (rdma_receiver_rdma_head[target_rank][lane_id] = -expected_head - 1)
: (rdma_receiver_rdma_head[target_rank][lane_id] = expected_head);
if (expected_head > max_expected_head) max_expected_head = expected_head;
}
}
// Timeout check
if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
printf("DeepEP combine RDMA receiver timeout, channel: %d, RDMA: %d, nvl: %d, src RDMA: %d, tail: %d, waiting: %ld, expect: %d\n",
channel_id, rdma_rank, nvl_rank, lane_id, cached_channel_tail_idx, token_idx, expected_head);
trap();
// ---- Phase 2: wait once, covering every token in the group ----
if (max_expected_head >= 0) {
auto start_time = wall_clock64();
while (cached_channel_tail_idx <= max_expected_head) {
cached_channel_tail_idx = static_cast<int>(
ld_acquire_sys_global(rdma_channel_tail.buffer(lane_id)));
if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
printf("DeepEP combine RDMA receiver timeout (unroll x%d), "
"ch: %d, rdma: %d, nvl: %d, lane: %d, "
"tail: %d, wait: %d\n",
kTokenUnroll, channel_id, rdma_rank, nvl_rank,
lane_id, cached_channel_tail_idx, max_expected_head);
trap();
}
}
}
syncwarp();
// Combine current token
auto get_addr_fn = [&](int src_rdma_rank, int slot_idx, int hidden_int4_idx) -> int4* { return reinterpret_cast<int4*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token) + hidden_int4_idx; };
auto recv_tw_fn = [&](int src_rdma_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(reinterpret_cast<const float*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token + hidden_bytes + sizeof(SourceMeta)) + topk_idx);};
combine_token<kNumRDMARanks, dtype_t, kNumTopkRDMARanks, false>(expected_head >= 0,
expected_head, lane_id,
hidden_int4, num_topk,
combined_x + token_idx * hidden_int4,
combined_topk_weights + token_idx * num_topk,
num_max_rdma_chunked_recv_tokens,
get_addr_fn, recv_tw_fn);
// ---- Phase 3: process all ready tokens in the group as a batch ----
#pragma unroll
for (int u = 0; u < kTokenUnroll; ++u) {
int64_t tidx = base + (int64_t)u * kNumRDMAReceivers;
if (tidx < token_end_idx) {
int expected_head = cached_expected_head[u];
// Combine current token
auto get_addr_fn = [&](int src_rdma_rank, int slot_idx, int hidden_int4_idx) -> int4* { return reinterpret_cast<int4*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token) + hidden_int4_idx; };
auto recv_tw_fn = [&](int src_rdma_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(reinterpret_cast<const float*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token + hidden_bytes + sizeof(SourceMeta)) + topk_idx);};
combine_token<kNumRDMARanks, dtype_t, kNumTopkRDMARanks, false>(
expected_head >= 0, expected_head, lane_id,
hidden_int4, num_topk,
combined_x + tidx * hidden_int4,
combined_topk_weights + tidx * num_topk,
num_max_rdma_chunked_recv_tokens,
get_addr_fn, recv_tw_fn);
}
}
}
// Retired
syncwarp();
if(lane_id == 0) {
if (lane_id == 0) {
rdma_receiver_retired[target_rank] = true;
}
} else if(warp_role == WarpRole::kNVLCoordinator) {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment