You need to sign in or sign up before continuing.
Unverified commit bc118b24, authored by Chenggang Zhao and committed by GitHub
Browse files

Add the transaction window data structure for RDMA senders (#245)

* Add draft

* Add fast-debugging flags

* Fix several bugs

* Add sender timeout checks

* Fix stuck

* Fix bugs

* Fix bugs
parent 9eb2f84b
...@@ -6,6 +6,7 @@ set(CMAKE_VERBOSE_MAKEFILE ON) ...@@ -6,6 +6,7 @@ set(CMAKE_VERBOSE_MAKEFILE ON)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O3 -fPIC") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O3 -fPIC")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC")
set(CUDA_SEPARABLE_COMPILATION ON) set(CUDA_SEPARABLE_COMPILATION ON)
list(APPEND CUDA_NVCC_FLAGS "-DENABLE_FAST_DEBUG")
list(APPEND CUDA_NVCC_FLAGS "-O3") list(APPEND CUDA_NVCC_FLAGS "-O3")
list(APPEND CUDA_NVCC_FLAGS "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage") list(APPEND CUDA_NVCC_FLAGS "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage")
......
...@@ -7,9 +7,15 @@ ...@@ -7,9 +7,15 @@
#define NUM_BUFFER_ALIGNMENT_BYTES 128 #define NUM_BUFFER_ALIGNMENT_BYTES 128
#define FINISHED_SUM_TAG 1024 #define FINISHED_SUM_TAG 1024
#define NUM_WAIT_NANOSECONDS 500
#ifndef ENABLE_FAST_DEBUG
#define NUM_CPU_TIMEOUT_SECS 100 #define NUM_CPU_TIMEOUT_SECS 100
#define NUM_TIMEOUT_CYCLES 200000000000ull // 200G cycles ~= 100s #define NUM_TIMEOUT_CYCLES 200000000000ull // 200G cycles ~= 100s
#define NUM_WAIT_NANOSECONDS 500 #else
#define NUM_CPU_TIMEOUT_SECS 10
#define NUM_TIMEOUT_CYCLES 20000000000ull // 20G cycles ~= 10s
#endif
#define LOW_LATENCY_SEND_PHASE 1 #define LOW_LATENCY_SEND_PHASE 1
#define LOW_LATENCY_RECV_PHASE 2 #define LOW_LATENCY_RECV_PHASE 2
......
...@@ -365,7 +365,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv ...@@ -365,7 +365,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
const bool is_forwarder = sm_id % 2 == 0; const bool is_forwarder = sm_id % 2 == 0;
const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS; const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe == num_channels || ibgda_get_state()->num_rc_per_pe >= num_sms); EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe == num_channels or ibgda_get_state()->num_rc_per_pe >= num_sms);
const auto role_meta = [=]() -> std::pair<WarpRole, int> { const auto role_meta = [=]() -> std::pair<WarpRole, int> {
if (is_forwarder) { if (is_forwarder) {
...@@ -419,9 +419,11 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv ...@@ -419,9 +419,11 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
auto nvl_channel_tail = AsymBuffer<int>(ws_rr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr); auto nvl_channel_tail = AsymBuffer<int>(ws_rr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
// RDMA sender warp synchronization // RDMA sender warp synchronization
__shared__ volatile int rdma_send_next_token_idx; // NOTES: `rdma_send_channel_tail` means the latest released tail
__shared__ volatile int rdma_send_channel_tail[kNumRDMARanks]; // NOTES: `rdma_send_channel_window` means the ongoing 32 transactions' status
__shared__ volatile int rdma_send_channel_next_tail[kNumRDMARanks]; __shared__ int rdma_send_channel_lock[kNumRDMARanks];
__shared__ int rdma_send_channel_tail[kNumRDMARanks];
__shared__ uint32_t rdma_send_channel_window[kNumRDMARanks];
auto sync_rdma_sender_smem = []() { asm volatile("bar.sync 0, %0;" :: "r"((kNumDispatchRDMASenderWarps + 1) * 32)); }; auto sync_rdma_sender_smem = []() { asm volatile("bar.sync 0, %0;" :: "r"((kNumDispatchRDMASenderWarps + 1) * 32)); };
// Forward warp synchronization // Forward warp synchronization
...@@ -434,12 +436,6 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv ...@@ -434,12 +436,6 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
int token_start_idx, token_end_idx; int token_start_idx, token_end_idx;
get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx, token_end_idx); get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx, token_end_idx);
// Clean shared memory
EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA ranks");
(warp_id == 0 and lane_id == 0) ? (rdma_send_next_token_idx = token_start_idx) : 0;
(warp_id == 0 and lane_id < kNumRDMARanks) ? (rdma_send_channel_tail[lane_id] = 0) : 0;
(warp_id == 0 and lane_id < kNumRDMARanks) ? (rdma_send_channel_next_tail[lane_id] = 0) : 0;
// Send number of tokens in this channel by `-value - 1` // Send number of tokens in this channel by `-value - 1`
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * 2 + 2 <= 32, "Invalid number of NVL peers"); EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * 2 + 2 <= 32, "Invalid number of NVL peers");
for (int dst_rdma_rank = warp_id; dst_rdma_rank < kNumRDMARanks; dst_rdma_rank += kNumDispatchRDMASenderWarps) { for (int dst_rdma_rank = warp_id; dst_rdma_rank < kNumRDMARanks; dst_rdma_rank += kNumDispatchRDMASenderWarps) {
...@@ -468,24 +464,33 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv ...@@ -468,24 +464,33 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
// Iterate over tokens and copy into buffer // Iterate over tokens and copy into buffer
int64_t token_idx; int64_t token_idx;
int cached_rdma_channel_head = 0, last_rdma_tail_idx = -1; int cached_rdma_channel_head = 0, global_rdma_tail_idx = 0;
auto send_buffer = lane_id == rdma_rank ? rdma_channel_data.recv_buffer(lane_id) : rdma_channel_data.send_buffer(lane_id); auto send_buffer = lane_id == rdma_rank ? rdma_channel_data.recv_buffer(lane_id) : rdma_channel_data.send_buffer(lane_id);
for (token_idx = token_start_idx + warp_id; token_idx < token_end_idx; token_idx += kNumDispatchRDMASenderWarps) { for (token_idx = token_start_idx; token_idx < token_end_idx; ++ token_idx) {
// Read RDMA rank existence // Read RDMA rank existence
uint64_t is_token_in_rank_uint64 = 0; uint64_t is_token_in_rank_uint64 = 0;
if (lane_id < kNumRDMARanks) if (lane_id < kNumRDMARanks) {
is_token_in_rank_uint64 = *reinterpret_cast<const uint64_t*>(is_token_in_rank + token_idx * num_ranks + lane_id * NUM_MAX_NVL_PEERS); is_token_in_rank_uint64 = __ldg(reinterpret_cast<const uint64_t*>(is_token_in_rank + token_idx * num_ranks + lane_id * NUM_MAX_NVL_PEERS));
global_rdma_tail_idx += (is_token_in_rank_uint64 != 0);
// Acquire sequential lock }
while (lane_id == 0 and rdma_send_next_token_idx != token_idx);
__syncwarp(); __syncwarp();
// Acquire next tail // Skip the token which does not belong to this warp
int rdma_tail_idx = -1; if ((token_idx - token_start_idx) % kNumDispatchRDMASenderWarps != warp_id)
if (is_token_in_rank_uint64 != 0) { continue;
rdma_tail_idx = rdma_send_channel_next_tail[lane_id] ++; auto rdma_tail_idx = is_token_in_rank_uint64 == 0 ? -1 : global_rdma_tail_idx - 1;
while (rdma_tail_idx - cached_rdma_channel_head >= num_max_rdma_chunked_recv_tokens)
cached_rdma_channel_head = static_cast<int>(ld_volatile_global(rdma_channel_head.buffer(lane_id))); // Wait the remote buffer to be released
auto start_time = clock64();
while (is_token_in_rank_uint64 != 0 and rdma_tail_idx - cached_rdma_channel_head >= num_max_rdma_chunked_recv_tokens) {
cached_rdma_channel_head = static_cast<int>(ld_volatile_global(rdma_channel_head.buffer(lane_id)));
// Timeout check
if (clock64() - start_time >= NUM_TIMEOUT_CYCLES) {
printf("DeepEP dispatch RDMA sender timeout, channel: %d, RDMA: %d, nvl: %d, dst RDMA lane: %d, head: %d, tail: %d\n",
channel_id, rdma_rank, nvl_rank, lane_id, cached_rdma_channel_head, rdma_tail_idx);
trap();
}
} }
__syncwarp(); __syncwarp();
...@@ -493,14 +498,6 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv ...@@ -493,14 +498,6 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
if (lane_id < kNumRDMARanks and not kCachedMode) if (lane_id < kNumRDMARanks and not kCachedMode)
send_rdma_head[token_idx * kNumRDMARanks + lane_id] = rdma_tail_idx; send_rdma_head[token_idx * kNumRDMARanks + lane_id] = rdma_tail_idx;
// Update last token tail
if (last_rdma_tail_idx >= 0)
st_release_cta(const_cast<const int *>(rdma_send_channel_tail + lane_id), last_rdma_tail_idx + 1);
last_rdma_tail_idx = rdma_tail_idx;
// Release sequential lock
lane_id == 0 ? (rdma_send_next_token_idx += 1) : 0;
// Broadcast tails // Broadcast tails
SourceMeta src_meta; SourceMeta src_meta;
int num_topk_ranks = 0, topk_ranks[kNumTopkRDMARanks]; int num_topk_ranks = 0, topk_ranks[kNumTopkRDMARanks];
...@@ -557,24 +554,46 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv ...@@ -557,24 +554,46 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
st_na_global(reinterpret_cast<int*>(dst_send_buffers[rank_idx]) + copy_idx, idx_value); st_na_global(reinterpret_cast<int*>(dst_send_buffers[rank_idx]) + copy_idx, idx_value);
st_na_global(reinterpret_cast<float*>(dst_send_buffers[rank_idx]) + num_topk + copy_idx, weight_value); st_na_global(reinterpret_cast<float*>(dst_send_buffers[rank_idx]) + num_topk + copy_idx, weight_value);
} }
} __syncwarp();
// Epilogue
// Acquire sequential lock
while (lane_id == 0 and rdma_send_next_token_idx != token_idx);
__syncwarp();
// Update last token tail // Release the transaction in the window
if (last_rdma_tail_idx >= 0) if (is_token_in_rank_uint64 != 0) {
st_release_cta(const_cast<const int*>(rdma_send_channel_tail + lane_id), last_rdma_tail_idx + 1); // Acquire lock first
__syncwarp(); acquire_lock(rdma_send_channel_lock + lane_id);
// Release the transaction slot
auto rdy_window = rdma_send_channel_window[lane_id];
auto latest_tail = rdma_send_channel_tail[lane_id];
auto offset = rdma_tail_idx - latest_tail;
// The same effect with `EP_DEVICE_ASSERT(offset < 32);`
EP_STATIC_ASSERT(kNumDispatchRDMASenderWarps < 32, "Invalid warps");
// Erase bit and move the ones if possible
rdy_window ^= 1u << offset;
if (offset == 0) {
EP_DEVICE_ASSERT(rdy_window & 1);
auto num_empty_slots = __ffs(~rdy_window) - 1;
st_release_cta(rdma_send_channel_tail + lane_id, latest_tail + num_empty_slots);
rdy_window >>= num_empty_slots;
}
rdma_send_channel_window[lane_id] = rdy_window;
// Release sequential lock // Release lock
lane_id == 0 ? (rdma_send_next_token_idx += 1) : 0; release_lock(rdma_send_channel_lock + lane_id);
}
__syncwarp();
}
} else if (warp_role == WarpRole::kRDMASenderCoordinator) { } else if (warp_role == WarpRole::kRDMASenderCoordinator) {
// NOTES: in case of splitting, the issued put at the end of the buffer // NOTES: in case of splitting, the issued put at the end of the buffer
EP_DEVICE_ASSERT(num_max_rdma_chunked_recv_tokens % num_max_rdma_chunked_send_tokens == 0); EP_DEVICE_ASSERT(num_max_rdma_chunked_recv_tokens % num_max_rdma_chunked_send_tokens == 0);
// Clean shared memory
EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA ranks");
(lane_id < kNumRDMARanks) ? (rdma_send_channel_lock[lane_id] = 0) : 0;
(lane_id < kNumRDMARanks) ? (rdma_send_channel_tail[lane_id] = 0) : 0;
(lane_id < kNumRDMARanks) ? (rdma_send_channel_window[lane_id] = 0) : 0;
// Synchronize shared memory // Synchronize shared memory
sync_rdma_sender_smem(); sync_rdma_sender_smem();
...@@ -592,10 +611,12 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv ...@@ -592,10 +611,12 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
while (__any_sync(0xffffffff, num_tokens_to_send > 0)) { while (__any_sync(0xffffffff, num_tokens_to_send > 0)) {
// Timeout check // Timeout check
if (clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) { if (clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) {
printf("DeepEP RDMA sender coordinator timeout, channel: %d, IB: %d, nvl %d, dst IB: %d, tail %d, num_tokens_to_send %d\n", printf("DeepEP RDMA sender coordinator timeout, channel: %d, IB: %d, nvl %d, dst IB: %d, tail: %d, remaining: %d\n",
channel_id, rdma_rank, nvl_rank, lane_id, last_issued_tail, num_tokens_to_send); channel_id, rdma_rank, nvl_rank, lane_id, last_issued_tail, num_tokens_to_send);
trap(); trap();
} }
// TODO: try thread-level `put_nbi`?
for (int i = 0, synced_num_tokens_to_send; i < kNumRDMARanks; ++ i) { for (int i = 0, synced_num_tokens_to_send; i < kNumRDMARanks; ++ i) {
// To mitigate incast congestion, shuffle the starting index of target rank for different ranks and channels // To mitigate incast congestion, shuffle the starting index of target rank for different ranks and channels
int dst_rdma_rank = (i + channel_id + rdma_rank) % kNumRDMARanks; int dst_rdma_rank = (i + channel_id + rdma_rank) % kNumRDMARanks;
...@@ -603,9 +624,10 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv ...@@ -603,9 +624,10 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
if (synced_num_tokens_to_send == 0) if (synced_num_tokens_to_send == 0)
continue; continue;
// Read progress // Read the latest progress
auto synced_last_issued_tail = __shfl_sync(0xffffffff, last_issued_tail, dst_rdma_rank); // NOTES: `rdma_send_channel_tail` does not need to be protected by lock
auto processed_tail = __shfl_sync(0xffffffff, ld_acquire_cta(const_cast<const int*>(rdma_send_channel_tail + dst_rdma_rank)), 0); auto processed_tail = __shfl_sync(0xffffffff, ld_acquire_cta(const_cast<const int*>(rdma_send_channel_tail + dst_rdma_rank)), 0);
auto synced_last_issued_tail = __shfl_sync(0xffffffff, last_issued_tail, dst_rdma_rank);
auto num_tokens_processed = processed_tail - synced_last_issued_tail; auto num_tokens_processed = processed_tail - synced_last_issued_tail;
if (num_tokens_processed != synced_num_tokens_to_send and num_tokens_processed < num_max_rdma_chunked_send_tokens) if (num_tokens_processed != synced_num_tokens_to_send and num_tokens_processed < num_max_rdma_chunked_send_tokens)
continue; continue;
...@@ -625,9 +647,9 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv ...@@ -625,9 +647,9 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
// Lighter fence for local RDMA rank // Lighter fence for local RDMA rank
memory_fence(); memory_fence();
} }
__syncwarp();
// Update tails // Update tails
__syncwarp();
if (lane_id == dst_rdma_rank) { if (lane_id == dst_rdma_rank) {
last_issued_tail += num_tokens_to_issue; last_issued_tail += num_tokens_to_issue;
num_tokens_to_send -= num_tokens_to_issue; num_tokens_to_send -= num_tokens_to_issue;
......
...@@ -58,12 +58,9 @@ cfg.dynamicSmemBytes = smem_size; ...@@ -58,12 +58,9 @@ cfg.dynamicSmemBytes = smem_size;
#define SWITCH_RDMA_RANKS(case_macro) \ #define SWITCH_RDMA_RANKS(case_macro) \
switch (num_ranks / NUM_MAX_NVL_PEERS) { \ switch (num_ranks / NUM_MAX_NVL_PEERS) { \
case 2: case_macro(2); \ case 2: case_macro(2); \
case 3: case_macro(3); \
case 4: case_macro(4); \ case 4: case_macro(4); \
case 8: case_macro(8); \ case 8: case_macro(8); \
case 16: case_macro(16); \ case 16: case_macro(16); \
case 18: case_macro(18); \
case 20: case_macro(20); \
default: EP_HOST_ASSERT(false and "Unsupported RDMA ranks"); \ default: EP_HOST_ASSERT(false and "Unsupported RDMA ranks"); \
} while (false) } while (false)
...@@ -78,7 +75,6 @@ cfg.dynamicSmemBytes = smem_size; ...@@ -78,7 +75,6 @@ cfg.dynamicSmemBytes = smem_size;
#define SWITCH_TYPES(case_macro) \ #define SWITCH_TYPES(case_macro) \
switch (type) { \ switch (type) { \
case CUDA_R_16BF: case_macro(nv_bfloat16); \ case CUDA_R_16BF: case_macro(nv_bfloat16); \
case CUDA_R_32F: case_macro(float); \
default: EP_HOST_ASSERT(false && "Unsupported type"); \ default: EP_HOST_ASSERT(false && "Unsupported type"); \
} while (false) } while (false)
......
...@@ -466,4 +466,26 @@ barrier_block(int** barrier_signal_ptrs, int rank) { ...@@ -466,4 +466,26 @@ barrier_block(int** barrier_signal_ptrs, int rank) {
__syncthreads(); __syncthreads();
} }
// Atomic 32-bit compare-and-swap with `acquire` ordering at CTA scope.
// `addr` must point into CTA shared memory, because the instruction is qualified with `.shared::cta`.
// If the word at `addr` equals `x`, it is replaced by `y`; otherwise it is left untouched.
// Returns the value that was stored at `addr` before the operation.
__forceinline__ __device__ int atomic_cas_cta_acquire(int* addr, int x, int y) {
    int ret;
    // NOTES: a state-space-qualified instruction expects an address inside that state space,
    // so convert the generic pointer explicitly instead of relying on implicit truncation
    const auto shared_addr = __cvta_generic_to_shared(addr);
    asm volatile("atom.acquire.cta.shared::cta.cas.b32 %0, [%1], %2, %3;" : "=r"(ret) : "l"(shared_addr), "r"(x), "r"(y) : "memory");
    return ret;
}
// Atomic 32-bit exchange with `release` ordering at CTA scope.
// `addr` must point into CTA shared memory, because the instruction is qualified with `.shared::cta`.
// Stores `x` at `addr` and returns the value that was stored there before the operation.
__forceinline__ __device__ int atomic_exch_cta_release(int* addr, int x) {
    int ret;
    // NOTES: a state-space-qualified instruction expects an address inside that state space,
    // so convert the generic pointer explicitly instead of relying on implicit truncation
    const auto shared_addr = __cvta_generic_to_shared(addr);
    asm volatile("atom.release.cta.shared::cta.exch.b32 %0, [%1], %2;" : "=r"(ret) : "l"(shared_addr), "r"(x) : "memory");
    return ret;
}
// Spin until the calling thread owns the lock word at `mutex`.
// The word is `0` while the lock is free and `1` while it is held.
__forceinline__ __device__ void acquire_lock(int* mutex) {
    // The CAS uses `acquire` semantics so that memory operations issued after
    // taking the lock are ordered behind it
    for (;;) {
        const int previous = atomic_cas_cta_acquire(mutex, 0, 1);
        // A previous value of `0` means this thread flipped the word from free to held
        if (previous == 0)
            return;
    }
}
// Hand the lock word at `mutex` back by resetting it to `0` (free).
__forceinline__ __device__ void release_lock(int* mutex) {
    constexpr int kUnlocked = 0;
    // The exchange uses `release` semantics so that memory operations issued while
    // holding the lock become visible to other threads; the previous value is not needed
    static_cast<void>(atomic_exch_cta_release(mutex, kUnlocked));
}
} // namespace deep_ep } // namespace deep_ep
...@@ -220,7 +220,7 @@ def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: in ...@@ -220,7 +220,7 @@ def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: in
def test_loop(local_rank: int, num_local_ranks: int): def test_loop(local_rank: int, num_local_ranks: int):
num_nodes = int(os.getenv('WORLD_SIZE', 1)) num_nodes = int(os.getenv('WORLD_SIZE', 1))
rank, num_ranks, group = init_dist(local_rank, num_local_ranks) rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
test_ll_compatibility = True test_ll_compatibility = os.getenv('EP_TEST_LL_COMPATIBILITY', False)
if test_ll_compatibility: if test_ll_compatibility:
ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9 ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment