Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
e1283972
Commit
e1283972
authored
Oct 20, 2025
by
lijian6
Browse files
Fix sync mode error.
Signed-off-by:
lijian
<
lijian6@sugon.com
>
parent
5563b6d0
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
3 additions
and
809 deletions
+3
-809
1.sh
1.sh
+1
-2
2.sh
2.sh
+1
-2
csrc/kernels/internode.hip
csrc/kernels/internode.hip
+0
-804
tests/test_internode.py
tests/test_internode.py
+1
-1
No files found.
1.sh
View file @
e1283972
...
@@ -9,6 +9,5 @@ export UCX_ROCM_IPC_SIGPOOL_MAX_ELEMS=16384
...
@@ -9,6 +9,5 @@ export UCX_ROCM_IPC_SIGPOOL_MAX_ELEMS=16384
export
UCX_NET_DEVICES
=
mlx5_2:1,mlx5_4:1,mlx5_6:1,mlx5_8:1
export
UCX_NET_DEVICES
=
mlx5_2:1,mlx5_4:1,mlx5_6:1,mlx5_8:1
export
HIP_VISIBLE_DEVICES
=
0,1,2,3,4,5,6,7
export
HIP_VISIBLE_DEVICES
=
0,1,2,3,4,5,6,7
export
ROCSHMEM_HEAP_SIZE
=
10737418240
export
ROCSHMEM_HEAP_SIZE
=
10737418240
export
PYTHONPATH
=
/
work
/Tmp/DeepEP:
$PYTHONPATH
export
PYTHONPATH
=
/
public/home/lishen
/Tmp/DeepEP:
$PYTHONPATH
torchrun
--nproc-per-node
=
1
--nnodes
=
2
--node-rank
=
0
--master-addr
=
"10.16.1.37"
--master-port
=
1234 tests/test_internode.py
torchrun
--nproc-per-node
=
1
--nnodes
=
2
--node-rank
=
0
--master-addr
=
"10.16.1.37"
--master-port
=
1234 tests/test_internode.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 tests/internode_lj.py
2.sh
View file @
e1283972
...
@@ -9,6 +9,5 @@ export UCX_ROCM_IPC_SIGPOOL_MAX_ELEMS=16384
...
@@ -9,6 +9,5 @@ export UCX_ROCM_IPC_SIGPOOL_MAX_ELEMS=16384
export
UCX_NET_DEVICES
=
mlx5_2:1,mlx5_4:1,mlx5_6:1,mlx5_8:1
export
UCX_NET_DEVICES
=
mlx5_2:1,mlx5_4:1,mlx5_6:1,mlx5_8:1
export
HIP_VISIBLE_DEVICES
=
0,1,2,3,4,5,6,7
export
HIP_VISIBLE_DEVICES
=
0,1,2,3,4,5,6,7
export
ROCSHMEM_HEAP_SIZE
=
10737418240
export
ROCSHMEM_HEAP_SIZE
=
10737418240
export
PYTHONPATH
=
/
work
/Tmp/DeepEP:
$PYTHONPATH
export
PYTHONPATH
=
/
public/home/lishen
/Tmp/DeepEP:
$PYTHONPATH
torchrun
--nproc-per-node
=
1
--nnodes
=
2
--node-rank
=
1
--master-addr
=
"10.16.1.37"
--master-port
=
1234 tests/test_internode.py
torchrun
--nproc-per-node
=
1
--nnodes
=
2
--node-rank
=
1
--master-addr
=
"10.16.1.37"
--master-port
=
1234 tests/test_internode.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 tests/internode_lj.py
csrc/kernels/internode.hip
View file @
e1283972
...
@@ -2059,810 +2059,6 @@ void combine(hipDataType type, void *combined_x, float *combined_topk_weights,
...
@@ -2059,810 +2059,6 @@ void combine(hipDataType type, void *combined_x, float *combined_topk_weights,
#undef COMBINE_LAUNCH_CASE
#undef COMBINE_LAUNCH_CASE
}
}
/*template <int kNumRanks, typename dtype_t, int kMaxNumRanks, typename ReceiveFn, typename ReceiveTWFn>
__device__ int combine_token(bool is_token_in_rank, int head_idx,
int lane_id, int hidden_int4, int num_topk,
int4* combined_row, float* combined_topk_weights,
int num_max_recv_tokens, const ReceiveFn& recv_fn, const ReceiveTWFn& recv_tw_fn) {
constexpr auto kDtypePerInt4 = sizeof(int4) / sizeof(dtype_t);
// Broadcast current heads
// Lane `i` holds the head of rank `i` and `is_token_in_rank`
EP_STATIC_ASSERT(kMaxNumRanks <= kWarpSize, "Too many ranks");
int num_topk_ranks = 0, topk_ranks[kMaxNumRanks], slot_indices[kMaxNumRanks];
#pragma unroll
for (int i = 0; i < kNumRanks; ++ i) if (shfl_sync(is_token_in_rank, i)) {
slot_indices[num_topk_ranks] = shfl_sync(head_idx, i) % num_max_recv_tokens;
topk_ranks[num_topk_ranks ++] = i;
}
EP_DEVICE_ASSERT(num_topk_ranks <= kMaxNumRanks);
// Reduce data
#pragma unroll
for (int i = lane_id; i < hidden_int4; i += kWarpSize) {
// Read buffers
// TODO: maybe too many registers here
int4 recv_value_int4[kMaxNumRanks];
#pragma unroll
for (int j = 0; j < num_topk_ranks; ++ j)
recv_value_int4[j] = recv_fn(topk_ranks[j], slot_indices[j], i);
// Reduce all-to-all results
float values[kDtypePerInt4] = {0};
#pragma unroll
for (int j = 0; j < num_topk_ranks; ++ j) {
auto recv_value_dtypes = reinterpret_cast<const dtype_t*>(&recv_value_int4[j]);
#pragma unroll
for (int k = 0; k < kDtypePerInt4; ++ k)
values[k] += static_cast<float>(recv_value_dtypes[k]);
}
// Cast back to `dtype_t` and write
int4 out_int4;
auto out_dtypes = reinterpret_cast<dtype_t*>(&out_int4);
#pragma unroll
for (int j = 0; j < kDtypePerInt4; ++ j)
out_dtypes[j] = static_cast<dtype_t>(values[j]);
st_na_global(combined_row + i, out_int4);
}
// Reduce `topk_weights`
if (lane_id < num_topk) {
float value = 0;
#pragma unroll
for (int i = 0; i < num_topk_ranks; ++ i)
value += recv_tw_fn(topk_ranks[i], slot_indices[i], lane_id);
st_na_global(combined_topk_weights + lane_id, value);
}
// Return the minimum top-k rank
return topk_ranks[0];
}
template <bool kLowLatencyMode,
int kNumRDMARanks,
typename dtype_t,
int kNumCombineForwarderWarps,
int kNumTopkRDMARanks,
int kNumWarpsPerForwarder,
int kNumForwarders,
int kNumRDMAReceivers>
__device__ void combine_kNVL_block_kernel_impl(int4* combined_x,
float* combined_topk_weights,
const bool* is_combined_token_in_rank,
const int4* x,
const float* topk_weights,
const int* combined_rdma_head,
const int* combined_nvl_head,
const SourceMeta* src_meta,
const int* rdma_channel_prefix_matrix,
const int* rdma_rank_prefix_sum,
const int* gbl_channel_prefix_matrix,
int num_tokens,
int num_combined_tokens,
int hidden,
int num_topk,
void* rdma_buffer_ptr,
int num_max_rdma_chunked_send_tokens,
int num_max_rdma_chunked_recv_tokens,
void** buffer_ptrs,
int num_max_nvl_chunked_send_tokens,
int num_max_nvl_chunked_recv_tokens,
int rank,
int num_ranks) {
const auto sm_id = static_cast<int>(blockIdx.x);
const auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / kWarpSize;
const auto thread_id = static_cast<int>(threadIdx.x), lane_id = get_lane_id();
const auto num_channels = static_cast<int>(gridDim.x) / 3,
channel_id = sm_id / 3;
const auto hidden_int4 = hidden / (sizeof(int4) / sizeof(dtype_t));
// NOTES: we decouple a channel into 2 SMs
const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
auto warp_id = thread_id / kWarpSize;
if(warp_id >= NUM_MAX_NVL_PEERS) {
return;
}
const auto dst_nvl_rank = (warp_id + channel_id) % NUM_MAX_NVL_PEERS;
// auto warp_role = WarpRole::kNVLSender;
auto num_max_nvl_chunked_recv_tokens_per_rdma = num_max_nvl_chunked_recv_tokens / kNumRDMARanks;
// NVL layouts
// NOTES: to avoid deadlocks, we use separate NVL buffers for different RDMA sources
auto dst_buffer_ptr = buffer_ptrs[dst_nvl_rank], local_buffer_ptr = buffer_ptrs[nvl_rank];
auto nvl_channel_x = AsymBuffer<int4>(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank)
.advance_also(local_buffer_ptr);
auto nvl_channel_src_meta = AsymBuffer<SourceMeta>(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank)
.advance_also(local_buffer_ptr);
auto nvl_channel_topk_weights =
AsymBuffer<float>(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank)
.advance_also(local_buffer_ptr);
auto nvl_channel_head =
AsymBuffer<int>(local_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, dst_nvl_rank).advance_also(dst_buffer_ptr);
auto nvl_channel_tail =
AsymBuffer<int>(dst_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr);
// Get tasks for each RDMA lane
int token_start_idx = 0, token_end_idx = 0;
if(lane_id < kNumRDMARanks) {
int prefix_idx = (lane_id * NUM_MAX_NVL_PEERS + dst_nvl_rank) * num_channels + channel_id;
token_start_idx = gbl_channel_prefix_matrix[prefix_idx];
token_end_idx = (prefix_idx == num_channels * num_ranks - 1) ? num_tokens : gbl_channel_prefix_matrix[prefix_idx + 1];
}
syncwarp();
// NOTES: here the cached value of each lane is only responsible for a single RDMA buffer
int cached_channel_head_idx = 0, cached_channel_tail_idx = 0;
EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "Invalid number of RDMA peers");
// Iterate over all tokens and send by chunks
while(true) {
// Exit if possible
if(__all_sync(kFullWarpMask, token_start_idx >= token_end_idx))
break;
// Decide next RDMA buffer to send
bool is_lane_ready = false;
auto start_time = clock64();
while(true) {
int num_used_slots = cached_channel_tail_idx - cached_channel_head_idx;
is_lane_ready = lane_id < kNumRDMARanks and token_start_idx < token_end_idx and
num_max_nvl_chunked_recv_tokens_per_rdma - num_used_slots >= num_max_nvl_chunked_send_tokens;
if(__any_sync(kFullWarpMask, is_lane_ready))
break;
// Retry
if(lane_id < kNumRDMARanks and token_start_idx < token_end_idx)
cached_channel_head_idx = ld_volatile_global(nvl_channel_head.buffer() + lane_id);
// Timeout check
if(clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) {
printf("DeepEP combine NVL sender timeout, channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, "
"RDMA lane: %d, head: %d, tail: %d, start: %d, end: %d\n",
channel_id,
rdma_rank,
nvl_rank,
dst_nvl_rank,
lane_id,
ld_volatile_global(nvl_channel_head.buffer() + lane_id),
cached_channel_tail_idx,
token_start_idx,
token_end_idx);
trap();
}
}
// Sync token start index and count
for(int current_rdma_idx = 0; current_rdma_idx < kNumRDMARanks; ++current_rdma_idx) {
if(shfl_sync((token_start_idx >= token_end_idx) or (not is_lane_ready), current_rdma_idx))
continue;
// Sync token start index
auto token_idx = static_cast<int64_t>(shfl_sync(token_start_idx, current_rdma_idx));
int num_tokens_in_chunk = shfl_sync(min(num_max_nvl_chunked_send_tokens, token_end_idx - token_start_idx), current_rdma_idx);
// Send by chunk
for(int chunk_idx = 0; chunk_idx < num_tokens_in_chunk; ++chunk_idx, ++token_idx) {
// Get an empty slot
int dst_slot_idx = 0;
if(lane_id == current_rdma_idx) {
dst_slot_idx = (cached_channel_tail_idx++) % num_max_nvl_chunked_recv_tokens_per_rdma;
dst_slot_idx = current_rdma_idx * num_max_nvl_chunked_recv_tokens_per_rdma + dst_slot_idx;
}
dst_slot_idx = shfl_sync(dst_slot_idx, current_rdma_idx);
// Copy data
auto shifted_x_buffers = nvl_channel_x.buffer() + dst_slot_idx * hidden_int4;
auto shifted_x = x + token_idx * hidden_int4;
UNROLLED_WARP_COPY(5, lane_id, hidden_int4, shifted_x_buffers, shifted_x, ld_nc_global, st_na_global);
// Copy source meta
if(lane_id == num_topk)
st_na_global(nvl_channel_src_meta.buffer() + dst_slot_idx, ld_nc_global(src_meta + token_idx));
// Copy `topk_weights`
if(lane_id < num_topk)
st_na_global(nvl_channel_topk_weights.buffer() + dst_slot_idx * num_topk + lane_id,
ld_nc_global(topk_weights + token_idx * num_topk + lane_id));
}
lane_id == current_rdma_idx ? (token_start_idx = static_cast<int>(token_idx)) : 0;
}
// Move queue tail
syncwarp();
if(lane_id < kNumRDMARanks and is_lane_ready) {
st_relaxed_sys_global(nvl_channel_tail.buffer() + lane_id, cached_channel_tail_idx);
}
}
}
template <bool kLowLatencyMode,
int kNumRDMARanks,
typename dtype_t,
int kNumCombineForwarderWarps,
int kNumTopkRDMARanks,
int kNumWarpsPerForwarder,
int kNumForwarders,
int kNumRDMAReceivers>
__device__ void combine_kRDMAAndNVL_block_kernel_impl(int4* combined_x,
float* combined_topk_weights,
const bool* is_combined_token_in_rank,
const int4* x,
const float* topk_weights,
const int* combined_rdma_head,
const int* combined_nvl_head,
const SourceMeta* src_meta,
const int* rdma_channel_prefix_matrix,
const int* rdma_rank_prefix_sum,
const int* gbl_channel_prefix_matrix,
int num_tokens,
int num_combined_tokens,
int hidden,
int num_topk,
void* rdma_buffer_ptr,
int num_max_rdma_chunked_send_tokens,
int num_max_rdma_chunked_recv_tokens,
void** buffer_ptrs,
int num_max_nvl_chunked_send_tokens,
int num_max_nvl_chunked_recv_tokens,
int rank,
int num_ranks, rocshmem::rocshmem_ctx_t ctx) {
const auto sm_id = static_cast<int>(blockIdx.x);
const auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / kWarpSize;
const auto thread_id = static_cast<int>(threadIdx.x), lane_id = get_lane_id();
const auto num_channels = static_cast<int>(gridDim.x) / 3,
channel_id = sm_id / 3;
const auto hidden_int4 = hidden / (sizeof(int4) / sizeof(dtype_t));
// NOTES: we decouple a channel into 2 SMs
const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
////////////////////////////////////////////
// EP_DEVICE_ASSERT(kNumForwarders == 8);
if((thread_id / kWarpSize) > kNumForwarders) {
return;
}
enum class WarpRole {
kNVLAndRDMAForwarder,
kCoordinator
};
auto role_meta = [=]() -> std::pair<WarpRole, int> {
auto warp_id = thread_id / kWarpSize;
if(warp_id < kNumForwarders) {
auto shuffled_warp_id = (warp_id + channel_id) % kNumForwarders;
return {WarpRole::kNVLAndRDMAForwarder, shuffled_warp_id};
} else {
return {WarpRole::kCoordinator, 0};
}
}();
auto warp_role = role_meta.first;
auto warp_id = role_meta.second;
auto num_max_nvl_chunked_recv_tokens_per_rdma = num_max_nvl_chunked_recv_tokens / kNumRDMARanks;
// Combiners and coordinators
// RDMA symmetric layout
auto hidden_bytes = hidden_int4 * sizeof(int4);
auto num_bytes_per_rdma_token = get_num_bytes_per_rdma_token(hidden_int4, 0, 0, num_topk);
auto rdma_channel_data = SymBuffer<int8_t>(rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token, kNumRDMARanks, channel_id, num_channels);
auto rdma_channel_head = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
auto rdma_channel_tail = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
// NVL layouts
void* local_nvl_buffer = buffer_ptrs[nvl_rank];
void* nvl_buffers[NUM_MAX_NVL_PEERS];
#pragma unroll
for(int i = 0; i < NUM_MAX_NVL_PEERS; ++i)
nvl_buffers[i] = buffer_ptrs[i];
auto nvl_channel_x = AsymBuffer<int4>(local_nvl_buffer, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels)
.advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
auto nvl_channel_src_meta = AsymBuffer<SourceMeta>(local_nvl_buffer, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels)
.advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
auto nvl_channel_topk_weights = AsymBuffer<float>(local_nvl_buffer, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels)
.advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
auto nvl_channel_head =
AsymBuffer<int, NUM_MAX_NVL_PEERS>(nvl_buffers, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_nvl_buffer);
auto nvl_channel_tail =
AsymBuffer<int>(local_nvl_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
// Combiner warp synchronization
__shared__ volatile int forwarder_nvl_head[kNumForwarders][NUM_MAX_NVL_PEERS];
__shared__ volatile bool forwarder_retired[kNumForwarders];
// auto sync_forwarder_smem = [=]() { asm volatile("bar.sync 0, %0;" ::"r"((kNumForwarders + 1) * kWarpSize)); };
if(warp_role == WarpRole::kNVLAndRDMAForwarder) {
// Receive from NVL ranks and forward to RDMA ranks
// NOTES: this part is using "large warps" for each RDMA ranks
const auto dst_rdma_rank = warp_id / kNumWarpsPerForwarder;
const auto sub_warp_id = warp_id % kNumWarpsPerForwarder;
auto send_buffer = dst_rdma_rank == rdma_rank ? rdma_channel_data.recv_buffer(dst_rdma_rank) : rdma_channel_data.send_buffer(dst_rdma_rank);
auto sync_large_warp = [=]() {
if(kNumWarpsPerForwarder == 1) {
syncwarp();
} else {
// asm volatile("bar.sync %0, %1;" ::"r"(dst_rdma_rank + 2), "r"(kNumWarpsPerForwarder * kWarpSize));
// __syncthreads();
syncwarp();
}
};
EP_STATIC_ASSERT(kNumWarpsPerForwarder == 1 or kNumRDMARanks + 2 <= kNumCombineForwarderWarps, "Barriers are not enough");
// Advance to the corresponding NVL buffer, 基于原本指针进行的地址偏移
nvl_channel_x.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma * hidden_int4);
nvl_channel_src_meta.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma);
nvl_channel_topk_weights.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma * num_topk);
nvl_channel_head.advance(dst_rdma_rank);
nvl_channel_tail.advance(dst_rdma_rank);
// Clean shared memory and sync
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= kWarpSize, "Invalid number of NVL peers");
lane_id < NUM_MAX_NVL_PEERS ? (forwarder_nvl_head[warp_id][lane_id] = 0) : 0;
lane_id == 0 ? (forwarder_retired[warp_id] = false) : false;
// sync_forwarder_smem();
__syncthreads();
// Get count and cached head
int cached_nvl_channel_tail_idx = 0;
int num_tokens_to_combine = rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id];
int num_tokens_prefix = channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1];
num_tokens_to_combine -= num_tokens_prefix;
num_tokens_prefix += dst_rdma_rank == 0 ? 0 : rdma_rank_prefix_sum[dst_rdma_rank - 1];
combined_nvl_head += num_tokens_prefix * NUM_MAX_NVL_PEERS;
// Iterate over all tokens and combine by chunks
for(int token_start_idx = 0; token_start_idx < num_tokens_to_combine; token_start_idx += num_max_rdma_chunked_send_tokens) {
// Check destination queue emptiness, or wait a buffer to be released
auto token_end_idx = min(token_start_idx + num_max_rdma_chunked_send_tokens, num_tokens_to_combine);
auto num_chunked_tokens = token_end_idx - token_start_idx;
auto start_time = clock64();
while(sub_warp_id == 0 and lane_id == 0) {
// Inequality: `num_max_rdma_chunked_recv_tokens - (tail - head) >= num_chunked_tokens`
// Here, `token_start_idx` is the actual tail
int num_used_slots = token_start_idx - ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank));
if(num_max_rdma_chunked_recv_tokens - num_used_slots >= num_chunked_tokens)
break;
// Timeout check
if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
printf("DeepEP combine forwarder (RDMA check) timeout, channel: %d, RDMA: %d, nvl: %d, dst RDMA: %d, head: %ld, tail: %d, chunked: %d\n",
channel_id, rdma_rank, nvl_rank, dst_rdma_rank, ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank)), token_start_idx, num_chunked_tokens);
trap();
}
}
sync_large_warp();
// Combine and write to the RDMA buffer
for(int token_idx = token_start_idx + sub_warp_id; token_idx < token_end_idx; token_idx += kNumWarpsPerForwarder) {
// Read expected head
EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "Invalid number of RDMA peers");
int expected_head = -1;
if(lane_id < NUM_MAX_NVL_PEERS)
expected_head = ld_nc_global(combined_nvl_head + token_idx * NUM_MAX_NVL_PEERS + lane_id);
// Wait lanes to be ready
start_time = clock64();
while(cached_nvl_channel_tail_idx <= expected_head) {
cached_nvl_channel_tail_idx = ld_acquire_sys_global(nvl_channel_tail.buffer(lane_id));
// Timeout check
if (clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < NUM_MAX_NVL_PEERS) {
printf("DeepEP combine forwarder (NVL check) timeout, channel: %d, RDMA: %d, nvl: %d, src NVL: %d, dst RDMA: %d, tail: %d, waiting: %d, total: %d, sub: %d, large: %d, expected: %d\n",
channel_id, rdma_rank, nvl_rank, lane_id, dst_rdma_rank, cached_nvl_channel_tail_idx, token_idx, num_tokens_to_combine, sub_warp_id, kNumWarpsPerForwarder, expected_head);
trap();
}
}
// Combine current token
auto rdma_slot_idx = token_idx % num_max_rdma_chunked_recv_tokens;
void* shifted = send_buffer + rdma_slot_idx * num_bytes_per_rdma_token;
auto recv_fn = [&](int src_nvl_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(nvl_channel_x.buffer(src_nvl_rank) + slot_idx * hidden_int4 + hidden_int4_idx); };
auto recv_tw_fn = [&](int src_nvl_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(nvl_channel_topk_weights.buffer(src_nvl_rank) + slot_idx * num_topk + topk_idx); };
combine_token<NUM_MAX_NVL_PEERS, dtype_t, NUM_MAX_NVL_PEERS>(expected_head >= 0,
expected_head, lane_id,
hidden_int4, num_topk,
reinterpret_cast<int4*>(shifted),
reinterpret_cast<float*>(reinterpret_cast<int8_t*>(shifted) + hidden_bytes + sizeof(SourceMeta)),
num_max_nvl_chunked_recv_tokens_per_rdma, recv_fn, recv_tw_fn);
// Update head
if(lane_id < NUM_MAX_NVL_PEERS) {
expected_head < 0 ? (forwarder_nvl_head[warp_id][lane_id] = -expected_head - 1)
: (forwarder_nvl_head[warp_id][lane_id] = expected_head + 1);
}
}
sync_large_warp();
// Issue RDMA send
if(sub_warp_id == kNumWarpsPerForwarder - 1) {
if(dst_rdma_rank != rdma_rank) {
auto rdma_slot_idx = token_start_idx % num_max_rdma_chunked_recv_tokens;
rocshmem::rocshmem_ctx_schar_put_nbi_wave(
ctx,
rdma_channel_data.recv_buffer(rdma_rank) +
rdma_slot_idx * num_bytes_per_rdma_token,
rdma_channel_data.send_buffer(dst_rdma_rank) +
rdma_slot_idx * num_bytes_per_rdma_token,
num_chunked_tokens * num_bytes_per_rdma_token,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
rocshmem::rocshmem_ctx_quiet(ctx);
} else {
memory_fence();
}
// Write new RDMA tail
syncwarp();
if(lane_id == 0) {
rocshmem::rocshmem_ctx_ulong_atomic_add(
ctx, rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
}
}
}
// Retired
syncwarp();
if(lane_id == 0) {
forwarder_retired[warp_id] = true;
}
} else {
// Coordinator
// Sync shared memory status
// sync_forwarder_smem();
__syncthreads();
constexpr int num_warps_per_rdma_rank = kNumForwarders / kNumRDMARanks;
int last_nvl_head[kNumRDMARanks] = {0};
int dst_nvl_rank = lane_id < NUM_MAX_NVL_PEERS ? lane_id : 0;
EP_STATIC_ASSERT(kNumCombineForwarderWarps <= kWarpSize, "Invalid number of forwarder warps");
while(true) {
// Retired
if(__all(lane_id >= kNumForwarders or forwarder_retired[lane_id]))
break;
{
// Find minimum head for NVL ranks
#pragma unroll
for(int i = 0; i < kNumRDMARanks; ++i) {
int min_head = std::numeric_limits<int>::max();
#pragma unroll
for(int j = 0; j < num_warps_per_rdma_rank; ++j)
if(not forwarder_retired[i * num_warps_per_rdma_rank + j])
min_head = min(min_head, forwarder_nvl_head[i * num_warps_per_rdma_rank + j][dst_nvl_rank]);
if(min_head != std::numeric_limits<int>::max() and min_head > last_nvl_head[i] and lane_id < NUM_MAX_NVL_PEERS) {
st_relaxed_sys_global(nvl_channel_head.buffer_by(dst_nvl_rank) + i, last_nvl_head[i] = min_head);
}
}
}
// Nanosleep and let other warps work
__builtin_amdgcn_s_sleep(NUM_WAIT_CYCLES_TIMES_64);
}
}
}
template <bool kLowLatencyMode,
int kNumRDMARanks,
typename dtype_t,
int kNumCombineForwarderWarps,
int kNumTopkRDMARanks,
int kNumWarpsPerForwarder,
int kNumForwarders,
int kNumRDMAReceivers>
__device__ void combine_kRDMA_block_kernel_impl(int4* combined_x,
float* combined_topk_weights,
const bool* is_combined_token_in_rank,
const int4* x,
const float* topk_weights,
const int* combined_rdma_head,
const int* combined_nvl_head,
const SourceMeta* src_meta,
const int* rdma_channel_prefix_matrix,
const int* rdma_rank_prefix_sum,
const int* gbl_channel_prefix_matrix,
int num_tokens,
int num_combined_tokens,
int hidden,
int num_topk,
void* rdma_buffer_ptr,
int num_max_rdma_chunked_send_tokens,
int num_max_rdma_chunked_recv_tokens,
void** buffer_ptrs,
int num_max_nvl_chunked_send_tokens,
int num_max_nvl_chunked_recv_tokens,
int rank,
int num_ranks, rocshmem::rocshmem_ctx_t ctx) {
const auto sm_id = static_cast<int>(blockIdx.x);
const auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / kWarpSize;
const auto thread_id = static_cast<int>(threadIdx.x), lane_id = get_lane_id();
const auto num_channels = static_cast<int>(gridDim.x) / 3,
channel_id = sm_id / 3;
const auto hidden_int4 = hidden / (sizeof(int4) / sizeof(dtype_t));
// NOTES: we decouple a channel into 2 SMs
const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
////////////////////////////////////////////
if((thread_id / kWarpSize) > kNumForwarders) {
return;
}
enum class WarpRole {
kRDMAReceiver,
kCoordinator
};
auto role_meta = [=]() -> std::pair<WarpRole, int> {
auto warp_id = thread_id / kWarpSize;
{
if(warp_id < kNumForwarders) {
return {WarpRole::kRDMAReceiver, warp_id};
} else {
return {WarpRole::kCoordinator, 0};
}
}
}();
auto warp_role = role_meta.first;
auto warp_id = role_meta.second;
// Combiners and coordinators
// RDMA symmetric layout
auto hidden_bytes = hidden_int4 * sizeof(int4);
auto num_bytes_per_rdma_token = get_num_bytes_per_rdma_token(hidden_int4, 0, 0, num_topk);
auto rdma_channel_data = SymBuffer<int8_t>(rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token, kNumRDMARanks, channel_id, num_channels);
auto rdma_channel_head = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
auto rdma_channel_tail = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
// Combiner warp synchronization
__shared__ volatile int rdma_receiver_rdma_head[kNumRDMAReceivers][kNumRDMARanks];
__shared__ volatile bool rdma_receiver_retired[kNumRDMAReceivers];
// auto sync_rdma_receiver_smem = [=]() { asm volatile("bar.sync 1, %0;" ::"r"((kNumRDMAReceivers + 1) * kWarpSize)); };
if(warp_role == WarpRole::kRDMAReceiver) {
// Receive from RDMA ranks and write to the output tensor
// Clean shared memory and sync
EP_DEVICE_ASSERT(kNumRDMARanks <= kWarpSize);
lane_id < kNumRDMARanks ? (rdma_receiver_rdma_head[warp_id][lane_id] = 0) : 0;
lane_id == 0 ? (rdma_receiver_retired[warp_id] = false) : 0;
// sync_rdma_receiver_smem();
__syncthreads();
// The same tokens as the dispatch process
int token_start_idx, token_end_idx;
get_channel_task_range(num_combined_tokens, num_channels, channel_id, token_start_idx, token_end_idx);
// Iterate over all tokens and combine
int cached_channel_tail_idx = 0;
for(int64_t token_idx = token_start_idx + warp_id; token_idx < token_end_idx; token_idx += kNumRDMAReceivers) {
// Read expected head
EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "Invalid number of RDMA peers");
int expected_head = -1;
if(lane_id < kNumRDMARanks) {
expected_head = ld_nc_global(combined_rdma_head + token_idx * kNumRDMARanks + lane_id);
(expected_head < 0) ? (rdma_receiver_rdma_head[warp_id][lane_id] = -expected_head - 1)
: (rdma_receiver_rdma_head[warp_id][lane_id] = expected_head);
}
// Wait lanes to be ready
auto start_time = clock64();
while (cached_channel_tail_idx <= expected_head) {
cached_channel_tail_idx = static_cast<int>(ld_acquire_sys_global(rdma_channel_tail.buffer(lane_id)));
// Timeout check
if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
printf("DeepEP combine RDMA receiver timeout, channel: %d, RDMA: %d, nvl: %d, src RDMA: %d, tail: %d, waiting: %ld, expect: %d\n",
channel_id, rdma_rank, nvl_rank, lane_id, cached_channel_tail_idx, token_idx, expected_head);
trap();
}
}
syncwarp();
// Combine current token
auto recv_fn = [&](int src_rdma_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(reinterpret_cast<const int4*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token) + hidden_int4_idx);};
auto recv_tw_fn = [&](int src_rdma_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(reinterpret_cast<const float*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token + hidden_bytes + sizeof(SourceMeta)) + topk_idx);};
combine_token<kNumRDMARanks, dtype_t, kNumTopkRDMARanks>(expected_head >= 0,
expected_head, lane_id,
hidden_int4, num_topk,
combined_x + token_idx * hidden_int4,
combined_topk_weights + token_idx * num_topk,
num_max_rdma_chunked_recv_tokens, recv_fn, recv_tw_fn);
}
// Retired
syncwarp();
if(lane_id == 0) {
rdma_receiver_retired[warp_id] = true;
}
} else {
// Coordinator
// Sync shared memory status
// sync_rdma_receiver_smem();
__syncthreads();
const auto num_warps_per_rdma_rank = kNumForwarders / kNumRDMARanks;
int last_rdma_head = 0;
int last_nvl_head[kNumRDMARanks] = {0};
int dst_rdma_rank = lane_id < kNumRDMARanks ? lane_id : 0;
int dst_nvl_rank = lane_id < NUM_MAX_NVL_PEERS ? lane_id : 0;
EP_STATIC_ASSERT(kNumCombineForwarderWarps <= kWarpSize, "Invalid number of forwarder warps");
while(true) {
// Retired
if(__all_sync(kFullWarpMask, lane_id >= kNumRDMAReceivers or rdma_receiver_retired[lane_id]))
break;
// Find minimum head for RDMA ranks
{
int min_head = std::numeric_limits<int>::max();
#pragma unroll
for(int i = 0; i < kNumRDMAReceivers; ++i)
if(not rdma_receiver_retired[i])
min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]);
if(min_head != std::numeric_limits<int>::max() and min_head > last_rdma_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) {
rocshmem::rocshmem_ctx_ulong_atomic_add(
ctx, rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
last_rdma_head = min_head;
}
}
// Nanosleep and let other warps work
__builtin_amdgcn_s_sleep(NUM_WAIT_CYCLES_TIMES_64);
}
}
}
template <bool kLowLatencyMode,
int kNumRDMARanks,
typename dtype_t,
int kNumCombineForwarderWarps,
int kNumTopkRDMARanks = get_num_topk_rdma_ranks(kNumRDMARanks),
int kNumWarpsPerForwarder = 1, // (kNumCombineForwarderWarps / kNumRDMARanks > 0) ? kNumCombineForwarderWarps / kNumRDMARanks : 1,
int kNumForwarders = kNumRDMARanks * kNumWarpsPerForwarder,
int kNumRDMAReceivers = kNumForwarders>
__global__ void __launch_bounds__((1 + NUM_MAX_NVL_PEERS) * kWarpSize, 1) combine(int4* combined_x,
float* combined_topk_weights,
const bool* is_combined_token_in_rank,
const int4* x,
const float* topk_weights,
const int* combined_rdma_head,
const int* combined_nvl_head,
const SourceMeta* src_meta,
const int* rdma_channel_prefix_matrix,
const int* rdma_rank_prefix_sum,
const int* gbl_channel_prefix_matrix,
int num_tokens,
int num_combined_tokens,
int hidden,
int num_topk,
void* rdma_buffer_ptr,
int num_max_rdma_chunked_send_tokens,
int num_max_rdma_chunked_recv_tokens,
void** buffer_ptrs,
int num_max_nvl_chunked_send_tokens,
int num_max_nvl_chunked_recv_tokens,
int rank,
int num_ranks) {
__shared__ rocshmem::rocshmem_ctx_t ctx;
rocshmem::rocshmem_wg_ctx_create(0, &ctx);
const auto sm_id = static_cast<int>(blockIdx.x);
if(sm_id % 3 == 0) { // kNVLSender
combine_kNVL_block_kernel_impl<kLowLatencyMode,
kNumRDMARanks, dtype_t,
kNumCombineForwarderWarps,
kNumTopkRDMARanks,
kNumWarpsPerForwarder,
kNumForwarders,
kNumRDMAReceivers>(combined_x, combined_topk_weights,
is_combined_token_in_rank,
x, topk_weights,
combined_rdma_head, combined_nvl_head,
src_meta, rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix,
num_tokens, num_combined_tokens, hidden, num_topk,
rdma_buffer_ptr, num_max_rdma_chunked_send_tokens, num_max_rdma_chunked_recv_tokens,
buffer_ptrs, num_max_nvl_chunked_send_tokens, num_max_nvl_chunked_recv_tokens,
rank, num_ranks);
} else if(sm_id % 3 == 1) { // kNVLAndRDMAForwarder + kCoordinator 一部分
combine_kRDMAAndNVL_block_kernel_impl<kLowLatencyMode,
kNumRDMARanks,
dtype_t,
kNumCombineForwarderWarps,
kNumTopkRDMARanks,
kNumWarpsPerForwarder,
kNumForwarders,
kNumRDMAReceivers>(combined_x, combined_topk_weights,
is_combined_token_in_rank,
x, topk_weights,
combined_rdma_head, combined_nvl_head,
src_meta, rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix,
num_tokens, num_combined_tokens, hidden, num_topk,
rdma_buffer_ptr, num_max_rdma_chunked_send_tokens, num_max_rdma_chunked_recv_tokens,
buffer_ptrs, num_max_nvl_chunked_send_tokens, num_max_nvl_chunked_recv_tokens,
rank, num_ranks, ctx);
} else { // kRDMAReceiver + kCoordinator 另一部分
combine_kRDMA_block_kernel_impl<kLowLatencyMode,
kNumRDMARanks,
dtype_t,
kNumCombineForwarderWarps,
kNumTopkRDMARanks,
kNumWarpsPerForwarder,
kNumForwarders,
kNumRDMAReceivers>(combined_x, combined_topk_weights,
is_combined_token_in_rank,
x, topk_weights,
combined_rdma_head, combined_nvl_head,
src_meta, rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix,
num_tokens, num_combined_tokens, hidden, num_topk,
rdma_buffer_ptr, num_max_rdma_chunked_send_tokens, num_max_rdma_chunked_recv_tokens,
buffer_ptrs, num_max_nvl_chunked_send_tokens, num_max_nvl_chunked_recv_tokens,
rank, num_ranks, ctx);
}
rocshmem::rocshmem_wg_ctx_destroy(&ctx);
}
void combine(hipDataType type, void *combined_x, float *combined_topk_weights,
const bool *is_combined_token_in_rank, const void *x, const float *topk_weights,
const void *bias_0, const void *bias_1, const int *combined_rdma_head,
const int *combined_nvl_head, const void *src_meta,
const int *rdma_channel_prefix_matrix, const int *rdma_rank_prefix_sum,
const int *gbl_channel_prefix_matrix, int num_tokens, int num_combined_tokens,
int hidden, int num_topk, void *rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens,
int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank,
int num_ranks, hipStream_t stream, int num_channels, bool low_latency_mode) {
constexpr int kNumCombineForwarderWarps = 8;
#define COMBINE_LAUNCH_CASE(num_rdma_ranks) \
{ \
auto combine_func = \
low_latency_mode \
? combine<true, num_rdma_ranks, hip_bfloat16, kNumCombineForwarderWarps> \
: combine<false, num_rdma_ranks, hip_bfloat16, kNumCombineForwarderWarps>; \
LAUNCH_KERNEL_NON_COOPERATIVE( \
&cfg, combine_func, reinterpret_cast<int4 *>(combined_x), combined_topk_weights, \
is_combined_token_in_rank, reinterpret_cast<const int4 *>(x), topk_weights, \
combined_rdma_head, combined_nvl_head, reinterpret_cast<const SourceMeta *>(src_meta), \
rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix, \
num_tokens, num_combined_tokens, hidden, num_topk, rdma_buffer_ptr, \
num_max_rdma_chunked_send_tokens, num_max_rdma_chunked_recv_tokens, buffer_ptrs, \
num_max_nvl_chunked_send_tokens, num_max_nvl_chunked_recv_tokens, rank, num_ranks); \
} \
break
int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
auto num_warps_per_forwarder = ::max(kNumCombineForwarderWarps / num_rdma_ranks, 1);
int num_forwarder_warps = num_rdma_ranks * num_warps_per_forwarder;
EP_HOST_ASSERT(num_forwarder_warps > 0 and num_forwarder_warps % num_rdma_ranks == 0);
EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens % num_rdma_ranks == 0);
EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens / num_rdma_ranks >
::max(num_max_rdma_chunked_send_tokens, num_max_nvl_chunked_send_tokens));
EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens / num_rdma_ranks - num_warps_per_forwarder >=
num_max_nvl_chunked_send_tokens);
EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens >= num_warps_per_forwarder);
EP_HOST_ASSERT(type == HIP_R_16BF);
SETUP_LAUNCH_CONFIG(num_channels * 3,
(NUM_MAX_NVL_PEERS + 1) * kWarpSize,
stream);
SWITCH_RDMA_RANKS(COMBINE_LAUNCH_CASE);
#undef COMBINE_LAUNCH_CASE
}*/
} // namespace internode
} // namespace internode
} // namespace deep_ep
} // namespace deep_ep
...
...
tests/test_internode.py
View file @
e1283972
...
@@ -162,7 +162,7 @@ def test_main(args: argparse.Namespace, num_sms: int,
...
@@ -162,7 +162,7 @@ def test_main(args: argparse.Namespace, num_sms: int,
# print("lijian test dipatch end and combine start.")
# print("lijian test dipatch end and combine start.")
bias_0
=
torch
.
ones
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
bias_0
=
torch
.
ones
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
bias_1
=
torch
.
randn
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
bias_1
=
torch
.
randn
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
combine_args
=
{
'x'
:
recv_x
,
'handle'
:
handle
,
'config'
:
config
}
combine_args
=
{
'x'
:
recv_x
,
'handle'
:
handle
,
'config'
:
config
,
'async_finish'
:
async_mode
}
if
with_topk
:
if
with_topk
:
combine_args
.
update
({
'topk_weights'
:
recv_topk_weights
})
combine_args
.
update
({
'topk_weights'
:
recv_topk_weights
})
if
previous_mode
:
if
previous_mode
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment