Unverified commit a2fa3b73, authored by Shangyan Zhou, committed by GitHub
Browse files

Use TMA to optimize internode dispatch. (#276)



* Add TMA buffer allocation

* Use TMA for forwarders and NVL receivers

* Use lane 31 to operate TMA.

* Change rdma buffer layout.

* Use TMA to transfer scales also.

* Increase the NVL recv buffer size.

* Disable early stopping.

* Apply similar optimizations on receiver warps.

* Prevent warp divergence.

* Disable aggressive ptx by default.

* Revert using TMA to transfer scales.

* Format.

* Change the layout of dispatch NVL buffer.

* Move topk transformation to recv warps.

* Use TMA to transfer all data in forward warps

* Use TMA to store scales.

* Code lint

---------
Co-authored-by: Chenggang Zhao <chenggangz@deepseek.com>
parent 7705f533
......@@ -44,21 +44,29 @@ int get_num_bytes_per_rdma_token(int hidden_int4, int num_scales, int num_topk_i
}
// Computes the region of the RDMA buffer that has to be cleaned, returned as an
// (offset, count) pair where both values are expressed in `int32_t` elements.
//   - first  (offset): size of the token data area, i.e. bytes per RDMA token
//                      times `num_rdma_recv_buffer_tokens`, `num_rdma_ranks`, 2 and the
//                      channel count, divided by `sizeof(int)`.
//   - second (count):  `(NUM_MAX_NVL_PEERS * 2 + 4)` ints per RDMA rank, times 2 and the
//                      channel count.
// Callable from both host and device code.
// NOTE(review): this listing interleaves the removed and the added side of the diff hunk;
// presumably the `num_sms` signature/expressions are the removed lines and the
// `num_channels` ones are the replacement — confirm against the real file.
// NOTE(review): the factor 2 and the `+ 4` presumably mirror the RDMA channel buffer
// layout (data/meta/head/tail); that layout is not fully visible here — TODO confirm.
__host__ __device__ __forceinline__
std::pair<int, int> get_rdma_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights, int num_rdma_ranks, int num_rdma_recv_buffer_tokens, int num_sms) {
std::pair<int, int> get_rdma_clean_meta(int hidden_int4, int num_scales, int num_topk_idx,
int num_topk_weights, int num_rdma_ranks, int num_rdma_recv_buffer_tokens,
int num_channels) {
// Return the offset and the count to clean, both in `int32_t` units
return {
(get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk_idx, num_topk_weights) * num_rdma_recv_buffer_tokens * num_rdma_ranks * 2 * num_sms) / sizeof(int),
(NUM_MAX_NVL_PEERS * 2 + 4) * num_rdma_ranks * 2 * num_sms
(get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk_idx, num_topk_weights) * num_rdma_recv_buffer_tokens * num_rdma_ranks * 2 * num_channels) / sizeof(int),
(NUM_MAX_NVL_PEERS * 2 + 4) * num_rdma_ranks * 2 * num_channels
};
}
// Computes the region of the NVL buffer that has to be cleaned, returned as an
// (offset, count) pair where both values are expressed in `int32_t` elements.
//   - first  (offset): size of the token data area, i.e. `num_nvl_recv_buffer_tokens`
//                      times the per-token byte size, `num_nvl_ranks` and the channel
//                      count, divided by `sizeof(int)`.
//   - second (count):  `(2 * num_rdma_ranks + 2)` ints per NVL rank, times the channel count.
// `is_dispatch` selects the per-token byte size: the RDMA token size from
// `get_num_bytes_per_rdma_token` when true, otherwise the explicit sum of hidden data,
// scales, top-k indices, top-k weights and one `SourceMeta`.
// Callable from both host and device code.
// NOTE(review): this listing interleaves the removed and the added side of the diff hunk;
// presumably the `num_sms` signature/expressions are the removed lines and the
// `num_channels` / `is_dispatch` ones are the replacement — confirm against the real file.
__host__ __device__ __forceinline__
std::pair<int, int> get_nvl_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights, int num_rdma_ranks, int num_nvl_ranks, int num_nvl_recv_buffer_tokens, int num_sms) {
std::pair<int, int> get_nvl_clean_meta(int hidden_int4, int num_scales, int num_topk_idx,
int num_topk_weights, int num_rdma_ranks, int num_nvl_ranks,
int num_nvl_recv_buffer_tokens, int num_channels, bool is_dispatch) {
// Return the offset and the count to clean, both in `int32_t` units
// TODO: remove `is_dispatch` after finishing combine refactor
// `SourceMeta` must be a whole number of ints so the byte size divides evenly into `int32_t` units
EP_STATIC_ASSERT(sizeof(SourceMeta) % sizeof(int) == 0, "Invalid size of `SourceMeta`");
// Dispatch uses the RDMA token size; the other path keeps the explicit per-field sum
const int num_bytes_per_token = is_dispatch ?
get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk_idx, num_topk_weights) :
(hidden_int4 * sizeof(int4) + num_scales * sizeof(float) + num_topk_idx * sizeof(int) + num_topk_weights * sizeof(float) + sizeof(SourceMeta));
return {
(num_nvl_recv_buffer_tokens * (hidden_int4 * sizeof(int4) + num_scales * sizeof(float) + num_topk_idx * sizeof(int) + num_topk_weights * sizeof(float) + sizeof(SourceMeta)) * num_nvl_ranks * num_sms) / sizeof(int),
num_nvl_ranks * (2 * num_rdma_ranks + 2) * num_sms,
(num_nvl_recv_buffer_tokens * num_bytes_per_token * num_nvl_ranks * num_channels) / sizeof(int),
num_nvl_ranks * (2 * num_rdma_ranks + 2) * num_channels,
};
}
......@@ -316,7 +324,7 @@ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mappe
// Get clean meta
auto rdma_clean_meta = get_rdma_clean_meta(hidden_int4, num_scales, num_topk, num_topk, num_rdma_ranks, num_max_rdma_chunked_recv_tokens, num_channels);
auto nvl_clean_meta = get_nvl_clean_meta(hidden_int4, num_scales, num_topk, num_topk, num_rdma_ranks, NUM_MAX_NVL_PEERS, num_max_nvl_chunked_recv_tokens, num_channels);
auto nvl_clean_meta = get_nvl_clean_meta(hidden_int4, num_scales, num_topk, num_topk, num_rdma_ranks, NUM_MAX_NVL_PEERS, num_max_nvl_chunked_recv_tokens, num_channels, true);
EP_HOST_ASSERT((rdma_clean_meta.first + rdma_clean_meta.second) * sizeof(int) <= num_rdma_bytes);
EP_HOST_ASSERT((nvl_clean_meta.first + nvl_clean_meta.second) * sizeof(int) <= num_nvl_bytes);
EP_HOST_ASSERT(num_rdma_bytes < std::numeric_limits<int>::max());
......@@ -333,7 +341,7 @@ constexpr int get_num_topk_rdma_ranks(int num_rdma_ranks) {
return num_rdma_ranks < 8 ? num_rdma_ranks : 8;
}
template <bool kLowLatencyMode, int kNumRDMARanks, bool kCachedMode,
template <bool kLowLatencyMode, int kNumRDMARanks, bool kCachedMode, int kNumTMABytesPerWarp,
int kNumDispatchRDMASenderWarps, int kNumTopkRDMARanks = get_num_topk_rdma_ranks(kNumRDMARanks)>
__global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS) * 32), 1)
dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv_topk_weights, SourceMeta* recv_src_meta,
......@@ -391,8 +399,10 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
// RDMA symmetric layout
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * sizeof(bool) == sizeof(uint64_t), "Invalid number of NVL peers");
auto hidden_bytes = hidden_int4 * sizeof(int4);
auto scale_bytes = num_scales * sizeof(float);
// TODO: rename `num_bytes_per_rdma_token` after combine refactor
auto num_bytes_per_rdma_token = get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk, num_topk);
auto rdma_channel_data = SymBuffer<int8_t>(rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token, kNumRDMARanks, channel_id, num_channels);
auto rdma_channel_data = SymBuffer<uint8_t>(rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token, kNumRDMARanks, channel_id, num_channels);
auto rdma_channel_meta = SymBuffer<int>(rdma_buffer_ptr, NUM_MAX_NVL_PEERS * 2 + 2, kNumRDMARanks, channel_id, num_channels);
auto rdma_channel_head = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
auto rdma_channel_tail = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
......@@ -407,11 +417,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
rs_wr_buffer_ptr = buffer_ptrs[target_rank], ws_rr_buffer_ptr = buffer_ptrs[nvl_rank], rs_wr_rank = target_rank, ws_rr_rank = nvl_rank;
// Allocate buffers
auto nvl_channel_x = AsymBuffer<int4>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
auto nvl_channel_src_meta = AsymBuffer<SourceMeta>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
auto nvl_channel_x_scales = AsymBuffer<float>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_scales, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
auto nvl_channel_topk_idx = AsymBuffer<int>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
auto nvl_channel_topk_weights = AsymBuffer<float>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
auto nvl_channel_x = AsymBuffer<uint8_t>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_bytes_per_rdma_token, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
auto nvl_channel_prefix_start = AsymBuffer<int>(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
auto nvl_channel_prefix_end = AsymBuffer<int>(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
auto nvl_channel_head = AsymBuffer<int>(rs_wr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, ws_rr_rank).advance_also(ws_rr_buffer_ptr);
......@@ -425,6 +431,19 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
__shared__ uint32_t rdma_send_channel_window[kNumRDMARanks];
auto sync_rdma_sender_smem = []() { asm volatile("bar.sync 0, %0;" :: "r"((kNumDispatchRDMASenderWarps + 1) * 32)); };
// TMA stuffs
extern __shared__ __align__(1024) uint8_t smem_tma_buffer[];
auto tma_buffer = smem_tma_buffer + target_rank * kNumTMABytesPerWarp;
auto tma_mbarrier = reinterpret_cast<uint64_t*>(tma_buffer + hidden_bytes);
uint32_t tma_phase = 0;
if ((warp_role == WarpRole::kRDMAAndNVLForwarder or warp_role == WarpRole::kNVLReceivers) and lane_id == 0) {
mbarrier_init(tma_mbarrier, 1);
fence_view_async_shared();
fence_barrier_init();
EP_DEVICE_ASSERT(num_bytes_per_rdma_token + sizeof(uint64_t) <= kNumTMABytesPerWarp);
}
__syncwarp();
// Forward warp synchronization
__shared__ volatile int forward_channel_head[NUM_MAX_NVL_PEERS][kNumRDMARanks];
__shared__ volatile bool forward_channel_retired[NUM_MAX_NVL_PEERS];
......@@ -524,13 +543,6 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
for (int i = 0; i < num_topk_ranks; ++ i)
dst_send_buffers[i] = reinterpret_cast<int4*>(dst_send_buffers[i]) + hidden_int4;
// Copy source metadata into symmetric send buffer
if (lane_id < num_topk_ranks)
st_na_global(reinterpret_cast<SourceMeta*>(dst_send_buffers[lane_id]), src_meta);
#pragma unroll
for (int i = 0; i < num_topk_ranks; ++ i)
dst_send_buffers[i] = reinterpret_cast<SourceMeta*>(dst_send_buffers[i]) + 1;
// Copy `x_scales` into symmetric send buffer
#pragma unroll
for (int i = lane_id; i < num_scales; i += 32) {
......@@ -544,6 +556,13 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
for (int i = 0; i < num_topk_ranks; ++ i)
dst_send_buffers[i] = reinterpret_cast<float*>(dst_send_buffers[i]) + num_scales;
// Copy source metadata into symmetric send buffer
if (lane_id < num_topk_ranks)
st_na_global(reinterpret_cast<SourceMeta*>(dst_send_buffers[lane_id]), src_meta);
#pragma unroll
for (int i = 0; i < num_topk_ranks; ++ i)
dst_send_buffers[i] = reinterpret_cast<SourceMeta*>(dst_send_buffers[i]) + 1;
// Copy `topk_idx` and `topk_weights` into symmetric send buffer
#pragma unroll
for (int i = lane_id; i < num_topk * num_topk_ranks; i += 32) {
......@@ -661,9 +680,6 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
} else if (warp_role == WarpRole::kRDMAAndNVLForwarder) {
// RDMA consumers and NVL producers
const auto dst_nvl_rank = target_rank;
const auto dst_rank = rdma_rank * NUM_MAX_NVL_PEERS + dst_nvl_rank;
const auto dst_rank_expert_begin = dst_rank * (num_experts / num_ranks);
const auto dst_rank_expert_end = dst_rank_expert_begin + (num_experts / num_ranks);
// Wait counters to arrive
int num_tokens_to_recv_from_rdma = 0, src_rdma_channel_prefix = 0;
......@@ -717,20 +733,19 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
while (__any_sync(0xffffffff, num_tokens_to_recv_from_rdma > 0)) {
// Check destination queue emptiness, or wait a buffer to be released
start_time = clock64();
while (lane_id == 0) {
int num_used_slots = cached_nvl_channel_tail - cached_nvl_channel_head;
while (true) {
const int num_used_slots = cached_nvl_channel_tail - cached_nvl_channel_head;
if (num_max_nvl_chunked_recv_tokens - num_used_slots >= num_max_nvl_chunked_send_tokens)
break;
cached_nvl_channel_head = ld_volatile_global(nvl_channel_head.buffer());
cached_nvl_channel_head = __shfl_sync(0xffffffffu, ld_volatile_global(nvl_channel_head.buffer()), 0);
// Timeout check
if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
if (lane_id == 0 and clock64() - start_time > NUM_TIMEOUT_CYCLES) {
printf("DeepEP dispatch forwarder timeout (NVL check), channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, head: %d, tail: %d\n",
channel_id, rdma_rank, nvl_rank, dst_nvl_rank, ld_volatile_global(nvl_channel_head.buffer()), cached_nvl_channel_tail);
trap();
}
}
__syncwarp();
// Find next source RDMA rank (round-robin)
start_time = clock64();
......@@ -756,8 +771,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
// Iterate over every token from the RDMA buffer
for (int i = src_rdma_head, num_tokens_sent = 0; i < src_rdma_tail; ++ i) {
auto rdma_slot_idx = i % num_max_rdma_chunked_recv_tokens;
void* shifted = rdma_channel_data.recv_buffer(src_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token;
auto src_meta = ld_nc_global(reinterpret_cast<SourceMeta*>(static_cast<int8_t*>(shifted) + hidden_bytes));
auto shifted = rdma_channel_data.recv_buffer(src_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token;
auto src_meta = ld_nc_global(reinterpret_cast<SourceMeta*>(shifted + hidden_bytes + scale_bytes));
lane_id == src_rdma_rank ? (num_tokens_to_recv_from_rdma -= 1) : 0;
bool is_in_dst_nvl_rank = src_meta.is_token_in_nvl_rank(dst_nvl_rank);
if (lane_id == src_rdma_rank) {
......@@ -771,44 +786,26 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
// Get an empty slot
int dst_slot_idx = (cached_nvl_channel_tail ++) % num_max_nvl_chunked_recv_tokens;
auto dst_shifted = nvl_channel_x.buffer() + dst_slot_idx * num_bytes_per_rdma_token;
// Copy data
UNROLLED_WARP_COPY(5, lane_id, hidden_int4,
nvl_channel_x.buffer() + dst_slot_idx * hidden_int4,
reinterpret_cast<int4*>(shifted),
ld_nc_global, st_na_global);
shifted = static_cast<int4*>(shifted) + hidden_int4;
// Copy source meta
if (lane_id == 0)
st_na_global(nvl_channel_src_meta.buffer() + dst_slot_idx, src_meta);
shifted = static_cast<SourceMeta*>(shifted) + 1;
// Copy `x_scales`
UNROLLED_WARP_COPY(1, lane_id, num_scales,
nvl_channel_x_scales.buffer() + dst_slot_idx * num_scales,
reinterpret_cast<float*>(shifted),
ld_nc_global, st_na_global);
shifted = static_cast<float*>(shifted) + num_scales;
// Copy `topk_idx` and `topk_weights`
// NOTES: do not use `shifted` after this `if`, because only several lanes are shifted
if (lane_id < num_topk) {
// Read
auto idx_value = ld_nc_global(static_cast<int*>(shifted) + lane_id);
shifted = static_cast<int*>(shifted) + num_topk;
auto weight_value = ld_nc_global(static_cast<float*>(shifted) + lane_id);
// Transform and write
idx_value = (idx_value >= dst_rank_expert_begin and idx_value < dst_rank_expert_end) ? idx_value - dst_rank_expert_begin : -1;
st_na_global(nvl_channel_topk_idx.buffer() + dst_slot_idx * num_topk + lane_id, idx_value);
weight_value = idx_value >= 0 ? weight_value : 0.0f;
st_na_global(nvl_channel_topk_weights.buffer() + dst_slot_idx * num_topk + lane_id, weight_value);
// Copy data
if (lane_id == 0) {
tma_load_1d(tma_buffer, shifted, tma_mbarrier, num_bytes_per_rdma_token, false);
mbarrier_arrive_and_expect_tx(tma_mbarrier, num_bytes_per_rdma_token);
}
__syncwarp();
mbarrier_wait(tma_mbarrier, tma_phase);
if (lane_id == 0)
tma_store_1d(tma_buffer, dst_shifted, num_bytes_per_rdma_token);
__syncwarp();
// In case of insufficient NVL buffers, early stopping
if ((++ num_tokens_sent) == num_max_nvl_chunked_send_tokens)
src_rdma_tail = i + 1;
// Wait TMA to be finished
tma_store_wait();
__syncwarp();
}
// Sync head index
......@@ -866,6 +863,9 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
// NVL consumers
// Retrieve rank offset from barrier results (each lane's register stores an RDMA rank)
int src_nvl_rank = target_rank, total_offset = 0;
const int local_expert_begin = rank * (num_experts / num_ranks);
const int local_expert_end = local_expert_begin + (num_experts / num_ranks);
EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA peers");
if (lane_id < kNumRDMARanks and lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank > 0)
total_offset = recv_gbl_rank_prefix_sum[lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank - 1];
......@@ -900,58 +900,81 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
while (num_tokens_to_recv > 0) {
// Check channel status by lane 0
start_time = clock64();
while (lane_id == 0) {
while (true) {
// Ready to copy
if (cached_channel_head_idx != cached_channel_tail_idx)
break;
cached_channel_tail_idx = ld_acquire_sys_global(nvl_channel_tail.buffer());
cached_channel_tail_idx = __shfl_sync(0xffffffff, ld_acquire_sys_global(nvl_channel_tail.buffer()), 0);
// Timeout check
if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
if (lane_id == 0 and clock64() - start_time > NUM_TIMEOUT_CYCLES) {
printf("DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src NVL: %d, head: %d, tail: %d\n",
channel_id, rdma_rank, nvl_rank, src_nvl_rank, cached_channel_head_idx, cached_channel_tail_idx);
trap();
}
}
// Sync queue tail
cached_channel_tail_idx = __shfl_sync(0xffffffff, cached_channel_tail_idx, 0);
// Copy data
int num_recv_tokens = cached_channel_tail_idx - cached_channel_head_idx;
for (int chunk_idx = 0; chunk_idx < num_recv_tokens; ++ chunk_idx, -- num_tokens_to_recv) {
int token_idx_in_buffer = (cached_channel_head_idx ++) % num_max_nvl_chunked_recv_tokens;
auto meta = ld_nc_global(nvl_channel_src_meta.buffer() + token_idx_in_buffer);
auto shifted = nvl_channel_x.buffer() + token_idx_in_buffer * num_bytes_per_rdma_token;
auto meta = ld_nc_global(reinterpret_cast<SourceMeta*>(shifted + hidden_bytes + scale_bytes));
int64_t recv_token_idx = __shfl_sync(0xffffffff, total_offset, meta.src_rdma_rank);
(lane_id == meta.src_rdma_rank) ? (total_offset += 1) : 0;
// Copy data
UNROLLED_WARP_COPY(5, lane_id, hidden_int4,
recv_x + recv_token_idx * hidden_int4,
nvl_channel_x.buffer() + token_idx_in_buffer * hidden_int4,
ld_nc_global, st_na_global);
bool scale_aligned = (scale_bytes % 16 == 0);
auto tma_load_bytes = hidden_bytes + (scale_aligned ? scale_bytes : 0);
// Copy data
if (lane_id == 0) {
tma_load_1d(tma_buffer, shifted, tma_mbarrier, tma_load_bytes);
mbarrier_arrive_and_expect_tx(tma_mbarrier, tma_load_bytes);
}
__syncwarp();
mbarrier_wait(tma_mbarrier, tma_phase);
if (lane_id == 0)
tma_store_1d(tma_buffer, recv_x + recv_token_idx * hidden_int4, hidden_bytes, false);
__syncwarp();
shifted += hidden_bytes;
// Copy scales
// TODO: make it as templated
if (scale_aligned) {
tma_store_1d(tma_buffer + hidden_bytes, recv_x_scales + recv_token_idx * num_scales, scale_bytes, false);
} else {
UNROLLED_WARP_COPY(1, lane_id, num_scales,
recv_x_scales + recv_token_idx * num_scales,
reinterpret_cast<float*>(shifted),
ld_nc_global, st_na_global);
}
shifted += scale_bytes;
// Copy source meta
if (lane_id == 0 and not kCachedMode)
st_na_global(recv_src_meta + recv_token_idx, meta);
// Copy scales
UNROLLED_WARP_COPY(1, lane_id, num_scales,
recv_x_scales + recv_token_idx * num_scales,
nvl_channel_x_scales.buffer() + token_idx_in_buffer * num_scales,
ld_nc_global, st_na_global);
shifted += sizeof(SourceMeta);
// Copy `topk_idx` and `topk_weights`
if (lane_id < num_topk) {
// Read
auto idx_value = static_cast<int64_t>(ld_nc_global(reinterpret_cast<int*>(shifted) + lane_id));
auto weight_value = ld_nc_global(reinterpret_cast<float*>(shifted + sizeof(int) * num_topk) + lane_id);
auto recv_idx = recv_token_idx * num_topk + lane_id;
auto buffer_idx = token_idx_in_buffer * num_topk + lane_id;
st_na_global(recv_topk_idx + recv_idx, static_cast<int64_t>(ld_nc_global(nvl_channel_topk_idx.buffer() + buffer_idx)));
st_na_global(recv_topk_weights + recv_idx, ld_nc_global(nvl_channel_topk_weights.buffer() + buffer_idx));
// Transform and write
idx_value = (idx_value >= local_expert_begin and idx_value < local_expert_end) ? idx_value - local_expert_begin : -1;
weight_value = idx_value >= 0 ? weight_value : 0.0f;
st_na_global(recv_topk_idx + recv_idx, idx_value);
st_na_global(recv_topk_weights + recv_idx, weight_value);
}
// Wait TMA to be finished
tma_store_wait();
__syncwarp();
}
// Move queue
__syncwarp();
if (lane_id == 0)
st_relaxed_sys_global(nvl_channel_head.buffer(), cached_channel_head_idx);
}
......@@ -972,14 +995,19 @@ void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float*
int rank, int num_ranks, bool is_cached_dispatch,
cudaStream_t stream, int num_channels, bool low_latency_mode) {
constexpr int kNumDispatchRDMASenderWarps = 7;
constexpr int kNumTMABytesPerWarp = 16384;
constexpr int smem_size = kNumTMABytesPerWarp * NUM_MAX_NVL_PEERS;
// Make sure never OOB
EP_HOST_ASSERT(static_cast<int64_t>(num_scales) * scale_hidden_stride < std::numeric_limits<int>::max());
#define DISPATCH_LAUNCH_CASE(num_rdma_ranks) { \
auto dispatch_func = low_latency_mode ? \
(is_cached_dispatch ? dispatch<true, num_rdma_ranks, true, kNumDispatchRDMASenderWarps> : dispatch<true, num_rdma_ranks, false, kNumDispatchRDMASenderWarps>) : \
(is_cached_dispatch ? dispatch<false, num_rdma_ranks, true, kNumDispatchRDMASenderWarps> : dispatch<false, num_rdma_ranks, false, kNumDispatchRDMASenderWarps>); \
(is_cached_dispatch ? dispatch<true, num_rdma_ranks, true, kNumTMABytesPerWarp, kNumDispatchRDMASenderWarps> : \
dispatch<true, num_rdma_ranks, false, kNumTMABytesPerWarp, kNumDispatchRDMASenderWarps>) : \
(is_cached_dispatch ? dispatch<false, num_rdma_ranks, true, kNumTMABytesPerWarp, kNumDispatchRDMASenderWarps> : \
dispatch<false, num_rdma_ranks, false, kNumTMABytesPerWarp, kNumDispatchRDMASenderWarps>); \
SET_SHARED_MEMORY_FOR_TMA(dispatch_func); \
LAUNCH_KERNEL(&cfg, dispatch_func, \
reinterpret_cast<int4*>(recv_x), recv_x_scales, recv_topk_idx, recv_topk_weights, reinterpret_cast<SourceMeta*>(recv_src_meta), \
reinterpret_cast<const int4*>(x), x_scales, topk_idx, topk_weights, \
......@@ -1117,7 +1145,7 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
// Get clean meta
auto rdma_clean_meta = get_rdma_clean_meta(hidden_int4, num_scales, num_topk_idx, num_topk_weights, num_rdma_ranks, num_max_rdma_chunked_recv_tokens, num_channels);
auto nvl_clean_meta = get_nvl_clean_meta(hidden_int4, num_scales, num_topk_idx, num_topk_weights, num_rdma_ranks, NUM_MAX_NVL_PEERS, num_max_nvl_chunked_recv_tokens, num_channels);
auto nvl_clean_meta = get_nvl_clean_meta(hidden_int4, num_scales, num_topk_idx, num_topk_weights, num_rdma_ranks, NUM_MAX_NVL_PEERS, num_max_nvl_chunked_recv_tokens, num_channels, is_cached_dispatch);
EP_HOST_ASSERT((rdma_clean_meta.first + rdma_clean_meta.second) * sizeof(int) <= num_rdma_bytes);
EP_HOST_ASSERT((nvl_clean_meta.first + nvl_clean_meta.second) * sizeof(int) <= num_nvl_bytes);
EP_HOST_ASSERT(num_rdma_bytes < std::numeric_limits<int>::max());
......
......@@ -55,7 +55,7 @@ if __name__ == '__main__':
os.environ['DISABLE_AGGRESSIVE_PTX_INSTRS'] = '1'
# Disable aggressive PTX instructions
if int(os.getenv('DISABLE_AGGRESSIVE_PTX_INSTRS', '0')):
if int(os.getenv('DISABLE_AGGRESSIVE_PTX_INSTRS', '1')):
cxx_flags.append('-DDISABLE_AGGRESSIVE_PTX_INSTRS')
nvcc_flags.append('-DDISABLE_AGGRESSIVE_PTX_INSTRS')
......
......@@ -234,7 +234,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_sms = 24
num_qps_per_rank = max(num_sms, ll_num_experts // num_ranks if args.test_ll_compatibility else 0)
buffer = deep_ep.Buffer(group, int(1e9), int(1e9), low_latency_mode=args.test_ll_compatibility,
buffer = deep_ep.Buffer(group, int(2e9), int(1e9), low_latency_mode=args.test_ll_compatibility,
num_qps_per_rank=num_qps_per_rank)
assert num_local_ranks == 8 and num_ranks > 8
torch.manual_seed(rank)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment