Commit e1283972 authored by lijian6's avatar lijian6
Browse files

Fix sync mode error.


Signed-off-by: lijian6's avatarlijian <lijian6@sugon.com>
parent 5563b6d0
......@@ -9,6 +9,5 @@ export UCX_ROCM_IPC_SIGPOOL_MAX_ELEMS=16384
export UCX_NET_DEVICES=mlx5_2:1,mlx5_4:1,mlx5_6:1,mlx5_8:1
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export ROCSHMEM_HEAP_SIZE=10737418240
export PYTHONPATH=/work/Tmp/DeepEP:$PYTHONPATH
export PYTHONPATH=/public/home/lishen/Tmp/DeepEP:$PYTHONPATH
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 tests/test_internode.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 tests/internode_lj.py
......@@ -9,6 +9,5 @@ export UCX_ROCM_IPC_SIGPOOL_MAX_ELEMS=16384
export UCX_NET_DEVICES=mlx5_2:1,mlx5_4:1,mlx5_6:1,mlx5_8:1
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export ROCSHMEM_HEAP_SIZE=10737418240
export PYTHONPATH=/work/Tmp/DeepEP:$PYTHONPATH
export PYTHONPATH=/public/home/lishen/Tmp/DeepEP:$PYTHONPATH
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 tests/test_internode.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 tests/internode_lj.py
......@@ -2059,810 +2059,6 @@ void combine(hipDataType type, void *combined_x, float *combined_topk_weights,
#undef COMBINE_LAUNCH_CASE
}
/*template <int kNumRanks, typename dtype_t, int kMaxNumRanks, typename ReceiveFn, typename ReceiveTWFn>
__device__ int combine_token(bool is_token_in_rank, int head_idx,
int lane_id, int hidden_int4, int num_topk,
int4* combined_row, float* combined_topk_weights,
int num_max_recv_tokens, const ReceiveFn& recv_fn, const ReceiveTWFn& recv_tw_fn) {
constexpr auto kDtypePerInt4 = sizeof(int4) / sizeof(dtype_t);
// Broadcast current heads
// Lane `i` holds the head of rank `i` and `is_token_in_rank`
EP_STATIC_ASSERT(kMaxNumRanks <= kWarpSize, "Too many ranks");
int num_topk_ranks = 0, topk_ranks[kMaxNumRanks], slot_indices[kMaxNumRanks];
#pragma unroll
for (int i = 0; i < kNumRanks; ++ i) if (shfl_sync(is_token_in_rank, i)) {
slot_indices[num_topk_ranks] = shfl_sync(head_idx, i) % num_max_recv_tokens;
topk_ranks[num_topk_ranks ++] = i;
}
EP_DEVICE_ASSERT(num_topk_ranks <= kMaxNumRanks);
// Reduce data
#pragma unroll
for (int i = lane_id; i < hidden_int4; i += kWarpSize) {
// Read buffers
// TODO: maybe too many registers here
int4 recv_value_int4[kMaxNumRanks];
#pragma unroll
for (int j = 0; j < num_topk_ranks; ++ j)
recv_value_int4[j] = recv_fn(topk_ranks[j], slot_indices[j], i);
// Reduce all-to-all results
float values[kDtypePerInt4] = {0};
#pragma unroll
for (int j = 0; j < num_topk_ranks; ++ j) {
auto recv_value_dtypes = reinterpret_cast<const dtype_t*>(&recv_value_int4[j]);
#pragma unroll
for (int k = 0; k < kDtypePerInt4; ++ k)
values[k] += static_cast<float>(recv_value_dtypes[k]);
}
// Cast back to `dtype_t` and write
int4 out_int4;
auto out_dtypes = reinterpret_cast<dtype_t*>(&out_int4);
#pragma unroll
for (int j = 0; j < kDtypePerInt4; ++ j)
out_dtypes[j] = static_cast<dtype_t>(values[j]);
st_na_global(combined_row + i, out_int4);
}
// Reduce `topk_weights`
if (lane_id < num_topk) {
float value = 0;
#pragma unroll
for (int i = 0; i < num_topk_ranks; ++ i)
value += recv_tw_fn(topk_ranks[i], slot_indices[i], lane_id);
st_na_global(combined_topk_weights + lane_id, value);
}
// Return the minimum top-k rank
return topk_ranks[0];
}
template <bool kLowLatencyMode,
int kNumRDMARanks,
typename dtype_t,
int kNumCombineForwarderWarps,
int kNumTopkRDMARanks,
int kNumWarpsPerForwarder,
int kNumForwarders,
int kNumRDMAReceivers>
__device__ void combine_kNVL_block_kernel_impl(int4* combined_x,
float* combined_topk_weights,
const bool* is_combined_token_in_rank,
const int4* x,
const float* topk_weights,
const int* combined_rdma_head,
const int* combined_nvl_head,
const SourceMeta* src_meta,
const int* rdma_channel_prefix_matrix,
const int* rdma_rank_prefix_sum,
const int* gbl_channel_prefix_matrix,
int num_tokens,
int num_combined_tokens,
int hidden,
int num_topk,
void* rdma_buffer_ptr,
int num_max_rdma_chunked_send_tokens,
int num_max_rdma_chunked_recv_tokens,
void** buffer_ptrs,
int num_max_nvl_chunked_send_tokens,
int num_max_nvl_chunked_recv_tokens,
int rank,
int num_ranks) {
const auto sm_id = static_cast<int>(blockIdx.x);
const auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / kWarpSize;
const auto thread_id = static_cast<int>(threadIdx.x), lane_id = get_lane_id();
const auto num_channels = static_cast<int>(gridDim.x) / 3,
channel_id = sm_id / 3;
const auto hidden_int4 = hidden / (sizeof(int4) / sizeof(dtype_t));
// NOTES: we decouple a channel into 2 SMs
const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
auto warp_id = thread_id / kWarpSize;
if(warp_id >= NUM_MAX_NVL_PEERS) {
return;
}
const auto dst_nvl_rank = (warp_id + channel_id) % NUM_MAX_NVL_PEERS;
// auto warp_role = WarpRole::kNVLSender;
auto num_max_nvl_chunked_recv_tokens_per_rdma = num_max_nvl_chunked_recv_tokens / kNumRDMARanks;
// NVL layouts
// NOTES: to avoid deadlocks, we use separate NVL buffers for different RDMA sources
auto dst_buffer_ptr = buffer_ptrs[dst_nvl_rank], local_buffer_ptr = buffer_ptrs[nvl_rank];
auto nvl_channel_x = AsymBuffer<int4>(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank)
.advance_also(local_buffer_ptr);
auto nvl_channel_src_meta = AsymBuffer<SourceMeta>(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank)
.advance_also(local_buffer_ptr);
auto nvl_channel_topk_weights =
AsymBuffer<float>(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank)
.advance_also(local_buffer_ptr);
auto nvl_channel_head =
AsymBuffer<int>(local_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, dst_nvl_rank).advance_also(dst_buffer_ptr);
auto nvl_channel_tail =
AsymBuffer<int>(dst_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr);
// Get tasks for each RDMA lane
int token_start_idx = 0, token_end_idx = 0;
if(lane_id < kNumRDMARanks) {
int prefix_idx = (lane_id * NUM_MAX_NVL_PEERS + dst_nvl_rank) * num_channels + channel_id;
token_start_idx = gbl_channel_prefix_matrix[prefix_idx];
token_end_idx = (prefix_idx == num_channels * num_ranks - 1) ? num_tokens : gbl_channel_prefix_matrix[prefix_idx + 1];
}
syncwarp();
// NOTES: here the cached value of each lane is only responsible for a single RDMA buffer
int cached_channel_head_idx = 0, cached_channel_tail_idx = 0;
EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "Invalid number of RDMA peers");
// Iterate over all tokens and send by chunks
while(true) {
// Exit if possible
if(__all_sync(kFullWarpMask, token_start_idx >= token_end_idx))
break;
// Decide next RDMA buffer to send
bool is_lane_ready = false;
auto start_time = clock64();
while(true) {
int num_used_slots = cached_channel_tail_idx - cached_channel_head_idx;
is_lane_ready = lane_id < kNumRDMARanks and token_start_idx < token_end_idx and
num_max_nvl_chunked_recv_tokens_per_rdma - num_used_slots >= num_max_nvl_chunked_send_tokens;
if(__any_sync(kFullWarpMask, is_lane_ready))
break;
// Retry
if(lane_id < kNumRDMARanks and token_start_idx < token_end_idx)
cached_channel_head_idx = ld_volatile_global(nvl_channel_head.buffer() + lane_id);
// Timeout check
if(clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) {
printf("DeepEP combine NVL sender timeout, channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, "
"RDMA lane: %d, head: %d, tail: %d, start: %d, end: %d\n",
channel_id,
rdma_rank,
nvl_rank,
dst_nvl_rank,
lane_id,
ld_volatile_global(nvl_channel_head.buffer() + lane_id),
cached_channel_tail_idx,
token_start_idx,
token_end_idx);
trap();
}
}
// Sync token start index and count
for(int current_rdma_idx = 0; current_rdma_idx < kNumRDMARanks; ++current_rdma_idx) {
if(shfl_sync((token_start_idx >= token_end_idx) or (not is_lane_ready), current_rdma_idx))
continue;
// Sync token start index
auto token_idx = static_cast<int64_t>(shfl_sync(token_start_idx, current_rdma_idx));
int num_tokens_in_chunk = shfl_sync(min(num_max_nvl_chunked_send_tokens, token_end_idx - token_start_idx), current_rdma_idx);
// Send by chunk
for(int chunk_idx = 0; chunk_idx < num_tokens_in_chunk; ++chunk_idx, ++token_idx) {
// Get an empty slot
int dst_slot_idx = 0;
if(lane_id == current_rdma_idx) {
dst_slot_idx = (cached_channel_tail_idx++) % num_max_nvl_chunked_recv_tokens_per_rdma;
dst_slot_idx = current_rdma_idx * num_max_nvl_chunked_recv_tokens_per_rdma + dst_slot_idx;
}
dst_slot_idx = shfl_sync(dst_slot_idx, current_rdma_idx);
// Copy data
auto shifted_x_buffers = nvl_channel_x.buffer() + dst_slot_idx * hidden_int4;
auto shifted_x = x + token_idx * hidden_int4;
UNROLLED_WARP_COPY(5, lane_id, hidden_int4, shifted_x_buffers, shifted_x, ld_nc_global, st_na_global);
// Copy source meta
if(lane_id == num_topk)
st_na_global(nvl_channel_src_meta.buffer() + dst_slot_idx, ld_nc_global(src_meta + token_idx));
// Copy `topk_weights`
if(lane_id < num_topk)
st_na_global(nvl_channel_topk_weights.buffer() + dst_slot_idx * num_topk + lane_id,
ld_nc_global(topk_weights + token_idx * num_topk + lane_id));
}
lane_id == current_rdma_idx ? (token_start_idx = static_cast<int>(token_idx)) : 0;
}
// Move queue tail
syncwarp();
if(lane_id < kNumRDMARanks and is_lane_ready) {
st_relaxed_sys_global(nvl_channel_tail.buffer() + lane_id, cached_channel_tail_idx);
}
}
}
template <bool kLowLatencyMode,
int kNumRDMARanks,
typename dtype_t,
int kNumCombineForwarderWarps,
int kNumTopkRDMARanks,
int kNumWarpsPerForwarder,
int kNumForwarders,
int kNumRDMAReceivers>
__device__ void combine_kRDMAAndNVL_block_kernel_impl(int4* combined_x,
float* combined_topk_weights,
const bool* is_combined_token_in_rank,
const int4* x,
const float* topk_weights,
const int* combined_rdma_head,
const int* combined_nvl_head,
const SourceMeta* src_meta,
const int* rdma_channel_prefix_matrix,
const int* rdma_rank_prefix_sum,
const int* gbl_channel_prefix_matrix,
int num_tokens,
int num_combined_tokens,
int hidden,
int num_topk,
void* rdma_buffer_ptr,
int num_max_rdma_chunked_send_tokens,
int num_max_rdma_chunked_recv_tokens,
void** buffer_ptrs,
int num_max_nvl_chunked_send_tokens,
int num_max_nvl_chunked_recv_tokens,
int rank,
int num_ranks, rocshmem::rocshmem_ctx_t ctx) {
const auto sm_id = static_cast<int>(blockIdx.x);
const auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / kWarpSize;
const auto thread_id = static_cast<int>(threadIdx.x), lane_id = get_lane_id();
const auto num_channels = static_cast<int>(gridDim.x) / 3,
channel_id = sm_id / 3;
const auto hidden_int4 = hidden / (sizeof(int4) / sizeof(dtype_t));
// NOTES: we decouple a channel into 2 SMs
const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
////////////////////////////////////////////
// EP_DEVICE_ASSERT(kNumForwarders == 8);
if((thread_id / kWarpSize) > kNumForwarders) {
return;
}
enum class WarpRole {
kNVLAndRDMAForwarder,
kCoordinator
};
auto role_meta = [=]() -> std::pair<WarpRole, int> {
auto warp_id = thread_id / kWarpSize;
if(warp_id < kNumForwarders) {
auto shuffled_warp_id = (warp_id + channel_id) % kNumForwarders;
return {WarpRole::kNVLAndRDMAForwarder, shuffled_warp_id};
} else {
return {WarpRole::kCoordinator, 0};
}
}();
auto warp_role = role_meta.first;
auto warp_id = role_meta.second;
auto num_max_nvl_chunked_recv_tokens_per_rdma = num_max_nvl_chunked_recv_tokens / kNumRDMARanks;
// Combiners and coordinators
// RDMA symmetric layout
auto hidden_bytes = hidden_int4 * sizeof(int4);
auto num_bytes_per_rdma_token = get_num_bytes_per_rdma_token(hidden_int4, 0, 0, num_topk);
auto rdma_channel_data = SymBuffer<int8_t>(rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token, kNumRDMARanks, channel_id, num_channels);
auto rdma_channel_head = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
auto rdma_channel_tail = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
// NVL layouts
void* local_nvl_buffer = buffer_ptrs[nvl_rank];
void* nvl_buffers[NUM_MAX_NVL_PEERS];
#pragma unroll
for(int i = 0; i < NUM_MAX_NVL_PEERS; ++i)
nvl_buffers[i] = buffer_ptrs[i];
auto nvl_channel_x = AsymBuffer<int4>(local_nvl_buffer, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels)
.advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
auto nvl_channel_src_meta = AsymBuffer<SourceMeta>(local_nvl_buffer, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels)
.advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
auto nvl_channel_topk_weights = AsymBuffer<float>(local_nvl_buffer, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels)
.advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
auto nvl_channel_head =
AsymBuffer<int, NUM_MAX_NVL_PEERS>(nvl_buffers, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_nvl_buffer);
auto nvl_channel_tail =
AsymBuffer<int>(local_nvl_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
// Combiner warp synchronization
__shared__ volatile int forwarder_nvl_head[kNumForwarders][NUM_MAX_NVL_PEERS];
__shared__ volatile bool forwarder_retired[kNumForwarders];
// auto sync_forwarder_smem = [=]() { asm volatile("bar.sync 0, %0;" ::"r"((kNumForwarders + 1) * kWarpSize)); };
if(warp_role == WarpRole::kNVLAndRDMAForwarder) {
// Receive from NVL ranks and forward to RDMA ranks
// NOTES: this part is using "large warps" for each RDMA ranks
const auto dst_rdma_rank = warp_id / kNumWarpsPerForwarder;
const auto sub_warp_id = warp_id % kNumWarpsPerForwarder;
auto send_buffer = dst_rdma_rank == rdma_rank ? rdma_channel_data.recv_buffer(dst_rdma_rank) : rdma_channel_data.send_buffer(dst_rdma_rank);
auto sync_large_warp = [=]() {
if(kNumWarpsPerForwarder == 1) {
syncwarp();
} else {
// asm volatile("bar.sync %0, %1;" ::"r"(dst_rdma_rank + 2), "r"(kNumWarpsPerForwarder * kWarpSize));
// __syncthreads();
syncwarp();
}
};
EP_STATIC_ASSERT(kNumWarpsPerForwarder == 1 or kNumRDMARanks + 2 <= kNumCombineForwarderWarps, "Barriers are not enough");
// Advance to the corresponding NVL buffer, 基于原本指针进行的地址偏移
nvl_channel_x.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma * hidden_int4);
nvl_channel_src_meta.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma);
nvl_channel_topk_weights.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma * num_topk);
nvl_channel_head.advance(dst_rdma_rank);
nvl_channel_tail.advance(dst_rdma_rank);
// Clean shared memory and sync
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= kWarpSize, "Invalid number of NVL peers");
lane_id < NUM_MAX_NVL_PEERS ? (forwarder_nvl_head[warp_id][lane_id] = 0) : 0;
lane_id == 0 ? (forwarder_retired[warp_id] = false) : false;
// sync_forwarder_smem();
__syncthreads();
// Get count and cached head
int cached_nvl_channel_tail_idx = 0;
int num_tokens_to_combine = rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id];
int num_tokens_prefix = channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1];
num_tokens_to_combine -= num_tokens_prefix;
num_tokens_prefix += dst_rdma_rank == 0 ? 0 : rdma_rank_prefix_sum[dst_rdma_rank - 1];
combined_nvl_head += num_tokens_prefix * NUM_MAX_NVL_PEERS;
// Iterate over all tokens and combine by chunks
for(int token_start_idx = 0; token_start_idx < num_tokens_to_combine; token_start_idx += num_max_rdma_chunked_send_tokens) {
// Check destination queue emptiness, or wait a buffer to be released
auto token_end_idx = min(token_start_idx + num_max_rdma_chunked_send_tokens, num_tokens_to_combine);
auto num_chunked_tokens = token_end_idx - token_start_idx;
auto start_time = clock64();
while(sub_warp_id == 0 and lane_id == 0) {
// Inequality: `num_max_rdma_chunked_recv_tokens - (tail - head) >= num_chunked_tokens`
// Here, `token_start_idx` is the actual tail
int num_used_slots = token_start_idx - ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank));
if(num_max_rdma_chunked_recv_tokens - num_used_slots >= num_chunked_tokens)
break;
// Timeout check
if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
printf("DeepEP combine forwarder (RDMA check) timeout, channel: %d, RDMA: %d, nvl: %d, dst RDMA: %d, head: %ld, tail: %d, chunked: %d\n",
channel_id, rdma_rank, nvl_rank, dst_rdma_rank, ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank)), token_start_idx, num_chunked_tokens);
trap();
}
}
sync_large_warp();
// Combine and write to the RDMA buffer
for(int token_idx = token_start_idx + sub_warp_id; token_idx < token_end_idx; token_idx += kNumWarpsPerForwarder) {
// Read expected head
EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "Invalid number of RDMA peers");
int expected_head = -1;
if(lane_id < NUM_MAX_NVL_PEERS)
expected_head = ld_nc_global(combined_nvl_head + token_idx * NUM_MAX_NVL_PEERS + lane_id);
// Wait lanes to be ready
start_time = clock64();
while(cached_nvl_channel_tail_idx <= expected_head) {
cached_nvl_channel_tail_idx = ld_acquire_sys_global(nvl_channel_tail.buffer(lane_id));
// Timeout check
if (clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < NUM_MAX_NVL_PEERS) {
printf("DeepEP combine forwarder (NVL check) timeout, channel: %d, RDMA: %d, nvl: %d, src NVL: %d, dst RDMA: %d, tail: %d, waiting: %d, total: %d, sub: %d, large: %d, expected: %d\n",
channel_id, rdma_rank, nvl_rank, lane_id, dst_rdma_rank, cached_nvl_channel_tail_idx, token_idx, num_tokens_to_combine, sub_warp_id, kNumWarpsPerForwarder, expected_head);
trap();
}
}
// Combine current token
auto rdma_slot_idx = token_idx % num_max_rdma_chunked_recv_tokens;
void* shifted = send_buffer + rdma_slot_idx * num_bytes_per_rdma_token;
auto recv_fn = [&](int src_nvl_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(nvl_channel_x.buffer(src_nvl_rank) + slot_idx * hidden_int4 + hidden_int4_idx); };
auto recv_tw_fn = [&](int src_nvl_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(nvl_channel_topk_weights.buffer(src_nvl_rank) + slot_idx * num_topk + topk_idx); };
combine_token<NUM_MAX_NVL_PEERS, dtype_t, NUM_MAX_NVL_PEERS>(expected_head >= 0,
expected_head, lane_id,
hidden_int4, num_topk,
reinterpret_cast<int4*>(shifted),
reinterpret_cast<float*>(reinterpret_cast<int8_t*>(shifted) + hidden_bytes + sizeof(SourceMeta)),
num_max_nvl_chunked_recv_tokens_per_rdma, recv_fn, recv_tw_fn);
// Update head
if(lane_id < NUM_MAX_NVL_PEERS) {
expected_head < 0 ? (forwarder_nvl_head[warp_id][lane_id] = -expected_head - 1)
: (forwarder_nvl_head[warp_id][lane_id] = expected_head + 1);
}
}
sync_large_warp();
// Issue RDMA send
if(sub_warp_id == kNumWarpsPerForwarder - 1) {
if(dst_rdma_rank != rdma_rank) {
auto rdma_slot_idx = token_start_idx % num_max_rdma_chunked_recv_tokens;
rocshmem::rocshmem_ctx_schar_put_nbi_wave(
ctx,
rdma_channel_data.recv_buffer(rdma_rank) +
rdma_slot_idx * num_bytes_per_rdma_token,
rdma_channel_data.send_buffer(dst_rdma_rank) +
rdma_slot_idx * num_bytes_per_rdma_token,
num_chunked_tokens * num_bytes_per_rdma_token,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
rocshmem::rocshmem_ctx_quiet(ctx);
} else {
memory_fence();
}
// Write new RDMA tail
syncwarp();
if(lane_id == 0) {
rocshmem::rocshmem_ctx_ulong_atomic_add(
ctx, rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
}
}
}
// Retired
syncwarp();
if(lane_id == 0) {
forwarder_retired[warp_id] = true;
}
} else {
// Coordinator
// Sync shared memory status
// sync_forwarder_smem();
__syncthreads();
constexpr int num_warps_per_rdma_rank = kNumForwarders / kNumRDMARanks;
int last_nvl_head[kNumRDMARanks] = {0};
int dst_nvl_rank = lane_id < NUM_MAX_NVL_PEERS ? lane_id : 0;
EP_STATIC_ASSERT(kNumCombineForwarderWarps <= kWarpSize, "Invalid number of forwarder warps");
while(true) {
// Retired
if(__all(lane_id >= kNumForwarders or forwarder_retired[lane_id]))
break;
{
// Find minimum head for NVL ranks
#pragma unroll
for(int i = 0; i < kNumRDMARanks; ++i) {
int min_head = std::numeric_limits<int>::max();
#pragma unroll
for(int j = 0; j < num_warps_per_rdma_rank; ++j)
if(not forwarder_retired[i * num_warps_per_rdma_rank + j])
min_head = min(min_head, forwarder_nvl_head[i * num_warps_per_rdma_rank + j][dst_nvl_rank]);
if(min_head != std::numeric_limits<int>::max() and min_head > last_nvl_head[i] and lane_id < NUM_MAX_NVL_PEERS) {
st_relaxed_sys_global(nvl_channel_head.buffer_by(dst_nvl_rank) + i, last_nvl_head[i] = min_head);
}
}
}
// Nanosleep and let other warps work
__builtin_amdgcn_s_sleep(NUM_WAIT_CYCLES_TIMES_64);
}
}
}
template <bool kLowLatencyMode,
int kNumRDMARanks,
typename dtype_t,
int kNumCombineForwarderWarps,
int kNumTopkRDMARanks,
int kNumWarpsPerForwarder,
int kNumForwarders,
int kNumRDMAReceivers>
__device__ void combine_kRDMA_block_kernel_impl(int4* combined_x,
float* combined_topk_weights,
const bool* is_combined_token_in_rank,
const int4* x,
const float* topk_weights,
const int* combined_rdma_head,
const int* combined_nvl_head,
const SourceMeta* src_meta,
const int* rdma_channel_prefix_matrix,
const int* rdma_rank_prefix_sum,
const int* gbl_channel_prefix_matrix,
int num_tokens,
int num_combined_tokens,
int hidden,
int num_topk,
void* rdma_buffer_ptr,
int num_max_rdma_chunked_send_tokens,
int num_max_rdma_chunked_recv_tokens,
void** buffer_ptrs,
int num_max_nvl_chunked_send_tokens,
int num_max_nvl_chunked_recv_tokens,
int rank,
int num_ranks, rocshmem::rocshmem_ctx_t ctx) {
const auto sm_id = static_cast<int>(blockIdx.x);
const auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / kWarpSize;
const auto thread_id = static_cast<int>(threadIdx.x), lane_id = get_lane_id();
const auto num_channels = static_cast<int>(gridDim.x) / 3,
channel_id = sm_id / 3;
const auto hidden_int4 = hidden / (sizeof(int4) / sizeof(dtype_t));
// NOTES: we decouple a channel into 2 SMs
const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
////////////////////////////////////////////
if((thread_id / kWarpSize) > kNumForwarders) {
return;
}
enum class WarpRole {
kRDMAReceiver,
kCoordinator
};
auto role_meta = [=]() -> std::pair<WarpRole, int> {
auto warp_id = thread_id / kWarpSize;
{
if(warp_id < kNumForwarders) {
return {WarpRole::kRDMAReceiver, warp_id};
} else {
return {WarpRole::kCoordinator, 0};
}
}
}();
auto warp_role = role_meta.first;
auto warp_id = role_meta.second;
// Combiners and coordinators
// RDMA symmetric layout
auto hidden_bytes = hidden_int4 * sizeof(int4);
auto num_bytes_per_rdma_token = get_num_bytes_per_rdma_token(hidden_int4, 0, 0, num_topk);
auto rdma_channel_data = SymBuffer<int8_t>(rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token, kNumRDMARanks, channel_id, num_channels);
auto rdma_channel_head = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
auto rdma_channel_tail = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
// Combiner warp synchronization
__shared__ volatile int rdma_receiver_rdma_head[kNumRDMAReceivers][kNumRDMARanks];
__shared__ volatile bool rdma_receiver_retired[kNumRDMAReceivers];
// auto sync_rdma_receiver_smem = [=]() { asm volatile("bar.sync 1, %0;" ::"r"((kNumRDMAReceivers + 1) * kWarpSize)); };
if(warp_role == WarpRole::kRDMAReceiver) {
// Receive from RDMA ranks and write to the output tensor
// Clean shared memory and sync
EP_DEVICE_ASSERT(kNumRDMARanks <= kWarpSize);
lane_id < kNumRDMARanks ? (rdma_receiver_rdma_head[warp_id][lane_id] = 0) : 0;
lane_id == 0 ? (rdma_receiver_retired[warp_id] = false) : 0;
// sync_rdma_receiver_smem();
__syncthreads();
// The same tokens as the dispatch process
int token_start_idx, token_end_idx;
get_channel_task_range(num_combined_tokens, num_channels, channel_id, token_start_idx, token_end_idx);
// Iterate over all tokens and combine
int cached_channel_tail_idx = 0;
for(int64_t token_idx = token_start_idx + warp_id; token_idx < token_end_idx; token_idx += kNumRDMAReceivers) {
// Read expected head
EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "Invalid number of RDMA peers");
int expected_head = -1;
if(lane_id < kNumRDMARanks) {
expected_head = ld_nc_global(combined_rdma_head + token_idx * kNumRDMARanks + lane_id);
(expected_head < 0) ? (rdma_receiver_rdma_head[warp_id][lane_id] = -expected_head - 1)
: (rdma_receiver_rdma_head[warp_id][lane_id] = expected_head);
}
// Wait lanes to be ready
auto start_time = clock64();
while (cached_channel_tail_idx <= expected_head) {
cached_channel_tail_idx = static_cast<int>(ld_acquire_sys_global(rdma_channel_tail.buffer(lane_id)));
// Timeout check
if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
printf("DeepEP combine RDMA receiver timeout, channel: %d, RDMA: %d, nvl: %d, src RDMA: %d, tail: %d, waiting: %ld, expect: %d\n",
channel_id, rdma_rank, nvl_rank, lane_id, cached_channel_tail_idx, token_idx, expected_head);
trap();
}
}
syncwarp();
// Combine current token
auto recv_fn = [&](int src_rdma_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(reinterpret_cast<const int4*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token) + hidden_int4_idx);};
auto recv_tw_fn = [&](int src_rdma_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(reinterpret_cast<const float*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token + hidden_bytes + sizeof(SourceMeta)) + topk_idx);};
combine_token<kNumRDMARanks, dtype_t, kNumTopkRDMARanks>(expected_head >= 0,
expected_head, lane_id,
hidden_int4, num_topk,
combined_x + token_idx * hidden_int4,
combined_topk_weights + token_idx * num_topk,
num_max_rdma_chunked_recv_tokens, recv_fn, recv_tw_fn);
}
// Retired
syncwarp();
if(lane_id == 0) {
rdma_receiver_retired[warp_id] = true;
}
} else {
// Coordinator
// Sync shared memory status
// sync_rdma_receiver_smem();
__syncthreads();
const auto num_warps_per_rdma_rank = kNumForwarders / kNumRDMARanks;
int last_rdma_head = 0;
int last_nvl_head[kNumRDMARanks] = {0};
int dst_rdma_rank = lane_id < kNumRDMARanks ? lane_id : 0;
int dst_nvl_rank = lane_id < NUM_MAX_NVL_PEERS ? lane_id : 0;
EP_STATIC_ASSERT(kNumCombineForwarderWarps <= kWarpSize, "Invalid number of forwarder warps");
while(true) {
// Retired
if(__all_sync(kFullWarpMask, lane_id >= kNumRDMAReceivers or rdma_receiver_retired[lane_id]))
break;
// Find minimum head for RDMA ranks
{
int min_head = std::numeric_limits<int>::max();
#pragma unroll
for(int i = 0; i < kNumRDMAReceivers; ++i)
if(not rdma_receiver_retired[i])
min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]);
if(min_head != std::numeric_limits<int>::max() and min_head > last_rdma_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) {
rocshmem::rocshmem_ctx_ulong_atomic_add(
ctx, rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
last_rdma_head = min_head;
}
}
// Nanosleep and let other warps work
__builtin_amdgcn_s_sleep(NUM_WAIT_CYCLES_TIMES_64);
}
}
}
template <bool kLowLatencyMode,
int kNumRDMARanks,
typename dtype_t,
int kNumCombineForwarderWarps,
int kNumTopkRDMARanks = get_num_topk_rdma_ranks(kNumRDMARanks),
int kNumWarpsPerForwarder = 1, // (kNumCombineForwarderWarps / kNumRDMARanks > 0) ? kNumCombineForwarderWarps / kNumRDMARanks : 1,
int kNumForwarders = kNumRDMARanks * kNumWarpsPerForwarder,
int kNumRDMAReceivers = kNumForwarders>
__global__ void __launch_bounds__((1 + NUM_MAX_NVL_PEERS) * kWarpSize, 1) combine(int4* combined_x,
float* combined_topk_weights,
const bool* is_combined_token_in_rank,
const int4* x,
const float* topk_weights,
const int* combined_rdma_head,
const int* combined_nvl_head,
const SourceMeta* src_meta,
const int* rdma_channel_prefix_matrix,
const int* rdma_rank_prefix_sum,
const int* gbl_channel_prefix_matrix,
int num_tokens,
int num_combined_tokens,
int hidden,
int num_topk,
void* rdma_buffer_ptr,
int num_max_rdma_chunked_send_tokens,
int num_max_rdma_chunked_recv_tokens,
void** buffer_ptrs,
int num_max_nvl_chunked_send_tokens,
int num_max_nvl_chunked_recv_tokens,
int rank,
int num_ranks) {
__shared__ rocshmem::rocshmem_ctx_t ctx;
rocshmem::rocshmem_wg_ctx_create(0, &ctx);
const auto sm_id = static_cast<int>(blockIdx.x);
if(sm_id % 3 == 0) { // kNVLSender
combine_kNVL_block_kernel_impl<kLowLatencyMode,
kNumRDMARanks, dtype_t,
kNumCombineForwarderWarps,
kNumTopkRDMARanks,
kNumWarpsPerForwarder,
kNumForwarders,
kNumRDMAReceivers>(combined_x, combined_topk_weights,
is_combined_token_in_rank,
x, topk_weights,
combined_rdma_head, combined_nvl_head,
src_meta, rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix,
num_tokens, num_combined_tokens, hidden, num_topk,
rdma_buffer_ptr, num_max_rdma_chunked_send_tokens, num_max_rdma_chunked_recv_tokens,
buffer_ptrs, num_max_nvl_chunked_send_tokens, num_max_nvl_chunked_recv_tokens,
rank, num_ranks);
} else if(sm_id % 3 == 1) { // kNVLAndRDMAForwarder + kCoordinator 一部分
combine_kRDMAAndNVL_block_kernel_impl<kLowLatencyMode,
kNumRDMARanks,
dtype_t,
kNumCombineForwarderWarps,
kNumTopkRDMARanks,
kNumWarpsPerForwarder,
kNumForwarders,
kNumRDMAReceivers>(combined_x, combined_topk_weights,
is_combined_token_in_rank,
x, topk_weights,
combined_rdma_head, combined_nvl_head,
src_meta, rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix,
num_tokens, num_combined_tokens, hidden, num_topk,
rdma_buffer_ptr, num_max_rdma_chunked_send_tokens, num_max_rdma_chunked_recv_tokens,
buffer_ptrs, num_max_nvl_chunked_send_tokens, num_max_nvl_chunked_recv_tokens,
rank, num_ranks, ctx);
} else { // kRDMAReceiver + kCoordinator 另一部分
combine_kRDMA_block_kernel_impl<kLowLatencyMode,
kNumRDMARanks,
dtype_t,
kNumCombineForwarderWarps,
kNumTopkRDMARanks,
kNumWarpsPerForwarder,
kNumForwarders,
kNumRDMAReceivers>(combined_x, combined_topk_weights,
is_combined_token_in_rank,
x, topk_weights,
combined_rdma_head, combined_nvl_head,
src_meta, rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix,
num_tokens, num_combined_tokens, hidden, num_topk,
rdma_buffer_ptr, num_max_rdma_chunked_send_tokens, num_max_rdma_chunked_recv_tokens,
buffer_ptrs, num_max_nvl_chunked_send_tokens, num_max_nvl_chunked_recv_tokens,
rank, num_ranks, ctx);
}
rocshmem::rocshmem_wg_ctx_destroy(&ctx);
}
void combine(hipDataType type, void *combined_x, float *combined_topk_weights,
const bool *is_combined_token_in_rank, const void *x, const float *topk_weights,
const void *bias_0, const void *bias_1, const int *combined_rdma_head,
const int *combined_nvl_head, const void *src_meta,
const int *rdma_channel_prefix_matrix, const int *rdma_rank_prefix_sum,
const int *gbl_channel_prefix_matrix, int num_tokens, int num_combined_tokens,
int hidden, int num_topk, void *rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens,
int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank,
int num_ranks, hipStream_t stream, int num_channels, bool low_latency_mode) {
constexpr int kNumCombineForwarderWarps = 8;
#define COMBINE_LAUNCH_CASE(num_rdma_ranks) \
{ \
auto combine_func = \
low_latency_mode \
? combine<true, num_rdma_ranks, hip_bfloat16, kNumCombineForwarderWarps> \
: combine<false, num_rdma_ranks, hip_bfloat16, kNumCombineForwarderWarps>; \
LAUNCH_KERNEL_NON_COOPERATIVE( \
&cfg, combine_func, reinterpret_cast<int4 *>(combined_x), combined_topk_weights, \
is_combined_token_in_rank, reinterpret_cast<const int4 *>(x), topk_weights, \
combined_rdma_head, combined_nvl_head, reinterpret_cast<const SourceMeta *>(src_meta), \
rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix, \
num_tokens, num_combined_tokens, hidden, num_topk, rdma_buffer_ptr, \
num_max_rdma_chunked_send_tokens, num_max_rdma_chunked_recv_tokens, buffer_ptrs, \
num_max_nvl_chunked_send_tokens, num_max_nvl_chunked_recv_tokens, rank, num_ranks); \
} \
break
int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
auto num_warps_per_forwarder = ::max(kNumCombineForwarderWarps / num_rdma_ranks, 1);
int num_forwarder_warps = num_rdma_ranks * num_warps_per_forwarder;
EP_HOST_ASSERT(num_forwarder_warps > 0 and num_forwarder_warps % num_rdma_ranks == 0);
EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens % num_rdma_ranks == 0);
EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens / num_rdma_ranks >
::max(num_max_rdma_chunked_send_tokens, num_max_nvl_chunked_send_tokens));
EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens / num_rdma_ranks - num_warps_per_forwarder >=
num_max_nvl_chunked_send_tokens);
EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens >= num_warps_per_forwarder);
EP_HOST_ASSERT(type == HIP_R_16BF);
SETUP_LAUNCH_CONFIG(num_channels * 3,
(NUM_MAX_NVL_PEERS + 1) * kWarpSize,
stream);
SWITCH_RDMA_RANKS(COMBINE_LAUNCH_CASE);
#undef COMBINE_LAUNCH_CASE
}*/
} // namespace internode
} // namespace deep_ep
......
......@@ -162,7 +162,7 @@ def test_main(args: argparse.Namespace, num_sms: int,
# print("lijian test dipatch end and combine start.")
bias_0 = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
bias_1 = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
combine_args = {'x': recv_x, 'handle': handle, 'config': config}
combine_args = {'x': recv_x, 'handle': handle, 'config': config, 'async_finish': async_mode}
if with_topk:
combine_args.update({'topk_weights': recv_topk_weights})
if previous_mode:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment