Commit 4b8d4b15 authored by lijian6's avatar lijian6
Browse files

fix cached_notify err when sm greater than 32.


Signed-off-by: lijian6's avatarlijian <lijian6@sugon.com>
parent e45581db
...@@ -1213,61 +1213,56 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i ...@@ -1213,61 +1213,56 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
// Barrier again // Barrier again
barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank); barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
} else if (sm_id == 1) { } else {
if (is_cached_dispatch) if (is_cached_dispatch)
return; return;
EP_DEVICE_ASSERT(num_warps >= num_channels); constexpr int num_clean_sms = 1;
EP_DEVICE_ASSERT(num_rdma_ranks <= kWarpSize); const int logical_block = sm_id - num_clean_sms;
const int total_blocks = gridDim.x - num_clean_sms;
// Iterate in reverse order if (logical_block < 0) return;
if (lane_id < num_rdma_ranks and warp_id < num_channels) { if (combined_rdma_head != nullptr) {
EP_DEVICE_ASSERT(num_rdma_ranks <= kWarpSize);
for (int chan = logical_block; chan < num_channels; chan += total_blocks) {
int token_start_idx, token_end_idx; int token_start_idx, token_end_idx;
get_channel_task_range(num_combined_tokens, num_channels, warp_id, token_start_idx, get_channel_task_range(num_combined_tokens, num_channels, chan, token_start_idx, token_end_idx);
token_end_idx);
// NOTES: `1 << 25` is a heuristic large number for (int token_idx = token_end_idx - 1 - warp_id; token_idx >= token_start_idx; token_idx -= num_warps) {
int last_head = 1 << 25; int last_head = 1 << 25;
for (int token_idx = token_end_idx - 1; token_idx >= token_start_idx; --token_idx) { if (lane_id < num_rdma_ranks) {
auto current_head = auto ptr = combined_rdma_head + token_idx * num_rdma_ranks + lane_id;
__ldg(combined_rdma_head + token_idx * num_rdma_ranks + lane_id); int current_head = __ldg(ptr);
if (current_head < 0) { if (current_head < 0) {
combined_rdma_head[token_idx * num_rdma_ranks + lane_id] = -last_head - 1; *ptr = -last_head - 1;
} else { } else {
last_head = current_head; last_head = current_head;
} }
} }
} }
} else { }
if (is_cached_dispatch) }
return; if (combined_nvl_head != nullptr) {
EP_DEVICE_ASSERT(rdma_channel_prefix_matrix != nullptr);
EP_DEVICE_ASSERT(num_warps >= num_channels); EP_DEVICE_ASSERT(rdma_rank_prefix_sum != nullptr);
EP_DEVICE_ASSERT(rdma_channel_prefix_matrix != nullptr and
rdma_rank_prefix_sum != nullptr);
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= kWarpSize, "Too many NVL peers"); EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= kWarpSize, "Too many NVL peers");
constexpr int num_clean_sms = 2;
if (lane_id < NUM_MAX_NVL_PEERS and warp_id < num_channels) {
for (int dst_rdma_rank = sm_id - num_clean_sms; dst_rdma_rank < num_rdma_ranks;
dst_rdma_rank += num_channels * 2 - num_clean_sms) {
// Iterate in reverse order
int token_start_idx =
warp_id == 0
? 0
: rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_id - 1];
int token_end_idx =
rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_id];
int shift = dst_rdma_rank == 0 ? 0 : rdma_rank_prefix_sum[dst_rdma_rank - 1];
token_start_idx += shift, token_end_idx += shift;
// NOTES: `1 << 25` is a heuristic large number for (int chan = logical_block; chan < num_channels; chan += total_blocks) {
for (int dst_rdma_rank = 0; dst_rdma_rank < num_rdma_ranks; ++dst_rdma_rank) {
int token_start_idx = (chan == 0) ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + chan - 1];
int token_end_idx = rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + chan];
int shift = (dst_rdma_rank == 0) ? 0 : rdma_rank_prefix_sum[dst_rdma_rank - 1];
token_start_idx += shift;
token_end_idx += shift;
for (int token_idx = token_end_idx - 1 - warp_id; token_idx >= token_start_idx; token_idx -= num_warps) {
int last_head = 1 << 25; int last_head = 1 << 25;
for (int token_idx = token_end_idx - 1; token_idx >= token_start_idx; --token_idx) { if (lane_id < NUM_MAX_NVL_PEERS) {
auto current_head = auto ptr = combined_nvl_head +
__ldg(combined_nvl_head + token_idx * NUM_MAX_NVL_PEERS + lane_id); token_idx * NUM_MAX_NVL_PEERS + lane_id;
int current_head = __ldg(ptr);
if (current_head < 0) { if (current_head < 0) {
combined_nvl_head[token_idx * NUM_MAX_NVL_PEERS + lane_id] = -last_head - 1; *ptr = -last_head - 1;
} else { } else {
last_head = current_head; last_head = current_head;
} }
...@@ -1275,6 +1270,8 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i ...@@ -1275,6 +1270,8 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
} }
} }
} }
}
}
} }
void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights, void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment