Unverified Commit 2012e310 authored by Chenggang Zhao's avatar Chenggang Zhao Committed by GitHub
Browse files

Canonicalize TMA usages (#410)

* Remove redundant TMA flushes

* Less barrier initialization overhead

* Simplify `elect_one_sync`

* Use `elect_one_sync` instead of lanes

* Minor fix

* Polish testing prints

* Refactor for internode kernels

* Better performance
parent 9af0e0d0
......@@ -275,7 +275,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
per_nvl_rank_count[i] = warp_reduce_sum(per_nvl_rank_count[i]);
// Write into channel matrix
if (lane_id == 0) {
if (elect_one_sync()) {
#pragma unroll
for (int i = 0; i < NUM_MAX_NVL_PEERS; ++ i)
gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + i) * num_channels + channel_id] = per_nvl_rank_count[i];
......@@ -446,9 +446,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
auto tma_buffer = smem_tma_buffer + target_rank * kNumTMABytesPerWarp;
auto tma_mbarrier = reinterpret_cast<uint64_t*>(tma_buffer + num_bytes_per_token);
uint32_t tma_phase = 0;
if ((warp_role == WarpRole::kRDMAAndNVLForwarder or warp_role == WarpRole::kNVLReceivers) and lane_id == 0) {
if ((warp_role == WarpRole::kRDMAAndNVLForwarder or warp_role == WarpRole::kNVLReceivers) and elect_one_sync()) {
mbarrier_init(tma_mbarrier, 1);
fence_view_async_shared();
fence_barrier_init();
EP_DEVICE_ASSERT(num_bytes_per_token + sizeof(uint64_t) <= kNumTMABytesPerWarp);
}
......@@ -750,7 +749,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
cached_nvl_channel_head = __shfl_sync(0xffffffffu, ld_volatile_global(nvl_channel_head.buffer()), 0);
// Timeout check
if (lane_id == 0 and clock64() - start_time > NUM_TIMEOUT_CYCLES) {
if (elect_one_sync() and clock64() - start_time > NUM_TIMEOUT_CYCLES) {
printf("DeepEP dispatch forwarder timeout (NVL check), channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, head: %d, tail: %d\n",
channel_id, rdma_rank, nvl_rank, dst_nvl_rank, ld_volatile_global(nvl_channel_head.buffer()), cached_nvl_channel_tail);
trap();
......@@ -799,13 +798,13 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
auto dst_shifted = nvl_channel_x.buffer() + dst_slot_idx * num_bytes_per_token;
// Copy data
if (lane_id == 0) {
if (elect_one_sync()) {
tma_load_1d(tma_buffer, shifted, tma_mbarrier, num_bytes_per_token, false);
mbarrier_arrive_and_expect_tx(tma_mbarrier, num_bytes_per_token);
}
__syncwarp();
mbarrier_wait(tma_mbarrier, tma_phase);
if (lane_id == 0)
if (elect_one_sync())
tma_store_1d(tma_buffer, dst_shifted, num_bytes_per_token);
__syncwarp();
......@@ -814,7 +813,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
src_rdma_tail = i + 1;
// Wait TMA to be finished
tma_store_wait();
tma_store_wait<0>();
__syncwarp();
}
......@@ -824,13 +823,13 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
// Move tail index
__syncwarp();
if (lane_id == 0)
if (elect_one_sync())
st_release_sys_global(nvl_channel_tail.buffer(), cached_nvl_channel_tail);
}
// Retired
__syncwarp();
if (lane_id == 0)
if (elect_one_sync())
forward_channel_retired[dst_nvl_rank] = true;
} else if (warp_role == WarpRole::kForwarderCoordinator) {
// Extra warps for forwarder coordinator should exit directly
......@@ -917,7 +916,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
cached_channel_tail_idx = __shfl_sync(0xffffffff, ld_acquire_sys_global(nvl_channel_tail.buffer()), 0);
// Timeout check
if (lane_id == 0 and clock64() - start_time > NUM_TIMEOUT_CYCLES) {
if (elect_one_sync() and clock64() - start_time > NUM_TIMEOUT_CYCLES) {
printf("DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src NVL: %d, head: %d, tail: %d\n",
channel_id, rdma_rank, nvl_rank, src_nvl_rank, cached_channel_head_idx, cached_channel_tail_idx);
trap();
......@@ -937,13 +936,13 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
auto tma_load_bytes = hidden_bytes + (scale_aligned ? scale_bytes : 0);
// Copy data
if (lane_id == 0) {
if (elect_one_sync()) {
tma_load_1d(tma_buffer, shifted, tma_mbarrier, tma_load_bytes);
mbarrier_arrive_and_expect_tx(tma_mbarrier, tma_load_bytes);
}
__syncwarp();
mbarrier_wait(tma_mbarrier, tma_phase);
if (lane_id == 0) {
if (elect_one_sync()) {
tma_store_1d(tma_buffer, recv_x + recv_token_idx * hidden_int4, hidden_bytes, false);
if (scale_aligned)
tma_store_1d(tma_buffer + hidden_bytes, recv_x_scales + recv_token_idx * num_scales, scale_bytes, false);
......@@ -962,7 +961,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
shifted += scale_bytes;
// Copy source meta
if (lane_id == 0 and not kCachedMode)
if (not kCachedMode and elect_one_sync())
st_na_global(recv_src_meta + recv_token_idx, meta);
shifted += sizeof(SourceMeta);
......@@ -981,12 +980,12 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
}
// Wait TMA to be finished
tma_store_wait();
tma_store_wait<0>();
__syncwarp();
}
// Move queue
if (lane_id == 0)
if (elect_one_sync())
st_relaxed_sys_global(nvl_channel_head.buffer(), cached_channel_head_idx);
}
}
......@@ -1136,9 +1135,8 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
auto tma_buffer = smem_tma_buffer + warp_id * kNumTMABytesPerWarp;
auto tma_mbarrier = reinterpret_cast<uint64_t*>(tma_buffer + tma_batch_size);
uint32_t tma_phase = 0;
if (lane_id == 0) {
if (elect_one_sync()) {
mbarrier_init(tma_mbarrier, 1);
fence_view_async_shared();
fence_barrier_init();
}
__syncwarp();
......@@ -1155,7 +1153,7 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
for (int batch_end_idx = token_end_idx; batch_end_idx > token_start_idx; batch_end_idx -= num_tokens_per_batch) {
auto batch_start_idx = max(token_start_idx, batch_end_idx - num_tokens_per_batch);
if (lane_id == 0) {
if (elect_one_sync()) {
tma_load_1d(tma_buffer, combined_nvl_head + batch_start_idx * NUM_MAX_NVL_PEERS, tma_mbarrier, (batch_end_idx - batch_start_idx) * num_bytes_per_token);
mbarrier_arrive_and_expect_tx(tma_mbarrier, (batch_end_idx - batch_start_idx) * num_bytes_per_token);
}
......@@ -1175,9 +1173,9 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
tma_store_fence();
__syncwarp();
if (lane_id == 0)
if (elect_one_sync())
tma_store_1d(tma_buffer, combined_nvl_head + batch_start_idx * NUM_MAX_NVL_PEERS, (batch_end_idx - batch_start_idx) * num_bytes_per_token);
tma_store_wait();
tma_store_wait<0>();
__syncwarp();
}
}
......@@ -1222,7 +1220,9 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
is_cached_dispatch, cpu_rdma_team);
}
template <int kNumRanks, bool kMaybeWithBias, typename dtype_t, int kMaxNumRanks, bool kUseTMA, int kNumStages, int kNumTMALoadBytes = 0, typename GetAddrFn, typename ReceiveTWFn>
template <int kNumRanks, bool kMaybeWithBias, typename dtype_t, int kMaxNumRanks,
bool kUseTMA, int kNumStages, int kNumTMALoadBytes = 0,
typename GetAddrFn, typename ReceiveTWFn>
__device__ int combine_token(bool is_token_in_rank, int head_idx,
int lane_id, int hidden_int4, int num_topk,
int4* combined_row, float* combined_topk_weights,
......@@ -1281,7 +1281,10 @@ __device__ int combine_token(bool is_token_in_rank, int head_idx,
values[k] += static_cast<float>(recv_value_dtypes[k]);
}
// Wait shared memory to be released
tma_store_wait<kNumStages - 1>();
// Copy into shared and issue TMA
auto out_dtypes = reinterpret_cast<dtype_t*>(tma_store_buffer(stage_idx) + lane_id);
#pragma unroll
for (int j = 0; j < kDtypePerInt4; ++ j)
......@@ -1289,13 +1292,13 @@ __device__ int combine_token(bool is_token_in_rank, int head_idx,
tma_store_fence();
__syncwarp();
if (lane_id == 0)
tma_store_1d(tma_store_buffer(stage_idx), combined_row + shifted + lane_id, kNumTMALoadBytes);
if (elect_one_sync())
tma_store_1d(tma_store_buffer(stage_idx), combined_row + shifted, kNumTMALoadBytes);
__syncwarp();
}
// Flush all writes
tma_store_wait();
tma_store_wait<0>();
} else {
#pragma unroll
for (int i = lane_id; i < hidden_int4; i += 32) {
......@@ -1441,9 +1444,8 @@ combine(int4* combined_x, float* combined_topk_weights,
auto tma_buffer = smem_tma_buffer + dst_nvl_rank * kNumTMABytesPerSenderWarp;
auto tma_mbarrier = reinterpret_cast<uint64_t*>(tma_buffer + num_bytes_per_token);
uint32_t tma_phase = 0;
if (lane_id == 0) {
if (elect_one_sync()) {
mbarrier_init(tma_mbarrier, 1);
fence_view_async_shared();
fence_barrier_init();
EP_DEVICE_ASSERT(num_bytes_per_token + sizeof(uint64_t) <= kNumTMABytesPerSenderWarp);
}
......@@ -1514,8 +1516,8 @@ combine(int4* combined_x, float* combined_topk_weights,
// Load data
auto shifted_x_buffers = nvl_channel_x.buffer() + dst_slot_idx * num_bytes_per_token;
auto shifted_x = x + token_idx * hidden_int4;
if (lane_id == 0) {
tma_store_wait();
if (elect_one_sync()) {
tma_store_wait<0>();
tma_load_1d(tma_buffer, shifted_x, tma_mbarrier, hidden_bytes);
mbarrier_arrive_and_expect_tx(tma_mbarrier, hidden_bytes);
}
......@@ -1533,14 +1535,14 @@ combine(int4* combined_x, float* combined_topk_weights,
// Issue TMA store
tma_store_fence();
__syncwarp();
if (lane_id == 0)
if (elect_one_sync())
tma_store_1d(tma_buffer, shifted_x_buffers, num_bytes_per_token, false);
}
lane_id == current_rdma_idx ? (token_start_idx = static_cast<int>(token_idx)) : 0;
}
// Move queue tail
tma_store_wait();
tma_store_wait<0>();
__syncwarp();
if (lane_id < kNumRDMARanks and is_lane_ready)
st_release_sys_global(nvl_channel_tail.buffer() + lane_id, cached_channel_tail_idx);
......@@ -1597,7 +1599,6 @@ combine(int4* combined_x, float* combined_topk_weights,
uint32_t tma_phase[kNumStages] = {0};
if (lane_id < kNumStages) {
mbarrier_init(tma_mbarrier(lane_id), 32);
fence_view_async_shared();
fence_barrier_init();
}
__syncwarp();
......@@ -1700,7 +1701,7 @@ combine(int4* combined_x, float* combined_topk_weights,
// Write new RDMA tail
__syncwarp();
if (lane_id == 0) {
if (elect_one_sync()) {
nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, dst_rdma_rank == rdma_rank);
}
......@@ -1709,7 +1710,7 @@ combine(int4* combined_x, float* combined_topk_weights,
// Retired
__syncwarp();
if (lane_id == 0)
if (elect_one_sync())
forwarder_retired[warp_id] = true;
} else if (warp_role == WarpRole::kRDMAReceiver) {
// Receive from RDMA ranks and write to the output tensor
......@@ -1749,8 +1750,12 @@ combine(int4* combined_x, float* combined_topk_weights,
__syncwarp();
// Combine current token
auto get_addr_fn = [&](int src_rdma_rank, int slot_idx, int hidden_int4_idx) -> int4* { return reinterpret_cast<int4*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_token) + hidden_int4_idx;};
auto recv_tw_fn = [&](int src_rdma_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(reinterpret_cast<const float*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_token + hidden_bytes + sizeof(SourceMeta)) + topk_idx);};
auto get_addr_fn = [&](int src_rdma_rank, int slot_idx, int hidden_int4_idx) -> int4* {
return reinterpret_cast<int4*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_token) + hidden_int4_idx;
};
auto recv_tw_fn = [&](int src_rdma_rank, int slot_idx, int topk_idx) -> float {
return ld_nc_global(reinterpret_cast<const float*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_token + hidden_bytes + sizeof(SourceMeta)) + topk_idx);
};
uint32_t dummy_tma_phases[2];
combine_token<kNumRDMARanks, true, dtype_t, kNumTopkRDMARanks, false, 2>(expected_head >= 0,
expected_head, lane_id,
......@@ -1765,7 +1770,7 @@ combine(int4* combined_x, float* combined_topk_weights,
// Retired
__syncwarp();
if (lane_id == 0)
if (elect_one_sync())
rdma_receiver_retired[warp_id] = true;
} else {
// Coordinator
......
......@@ -645,7 +645,6 @@ combine(void* combined_x,
// Initialize m-barriers
if (lane_id < kNumStages) {
mbarrier_init(full_barriers[lane_id], 1);
fence_view_async_shared();
fence_barrier_init();
}
__syncwarp();
......@@ -678,7 +677,7 @@ combine(void* combined_x,
const auto cpy_dst_int4_ptr = dst_p2p_ptr == 0 ? reinterpret_cast<int4*>(buf_ptr) : reinterpret_cast<int4*>(dst_p2p_ptr);
// Prefetch
if (elect_one_sync(lane_id))
if (elect_one_sync())
tma_load_and_arrive(0, cpy_src_int4_ptr, get_num_tma_bytes(0));
__syncwarp();
......@@ -688,7 +687,7 @@ combine(void* combined_x,
// Load the next iteration
const int& stage_idx = iter_idx % kNumStages;
const int& next_stage_idx = (iter_idx + 1) % kNumStages;
if (iter_idx + 1 < kNumIters and elect_one_sync(lane_id)) {
if (iter_idx + 1 < kNumIters and elect_one_sync()) {
tma_store_wait<kNumStages - kNumPrefetch - 1>();
const auto& offset_int4 = i + 32 * kNumSendUnrolls;
tma_load_and_arrive(next_stage_idx, cpy_src_int4_ptr + offset_int4, get_num_tma_bytes(offset_int4));
......@@ -706,12 +705,12 @@ combine(void* combined_x,
// NOTES: only the leader lane will write the result
(i % kNumInt4PerDivision == 0) ? meta_buffers + i / kNumInt4PerDivision : nullptr,
lane_id);
if (elect_one_sync(lane_id))
if (elect_one_sync())
tma_store_1d(tma_buffers[stage_idx], reinterpret_cast<uint8_t*>(cpy_dst_int4_ptr) + tma_offset_bytes, num_tma_bytes);
tma_offset_bytes += num_tma_bytes;
} else {
// BF16 original values
if (elect_one_sync(lane_id))
if (elect_one_sync())
tma_store_1d(tma_buffers[stage_idx], cpy_dst_int4_ptr + i, get_num_tma_bytes(i));
}
__syncwarp();
......@@ -720,12 +719,12 @@ combine(void* combined_x,
// Store metadata (min/max values) for LogFMT
if constexpr (kUseLogFMT) {
num_send_bytes = tma_offset_bytes;
if (elect_one_sync(lane_id))
if (elect_one_sync())
tma_store_1d(meta_buffers, cpy_dst_int4_ptr, kNumMetaBytes);
}
// Flush all stores
tma_store_wait();
tma_store_wait<0>();
__syncwarp();
}
......@@ -754,7 +753,6 @@ combine(void* combined_x,
// Destroy m-barriers
if (lane_id < kNumStages) {
mbarrier_inval(full_barriers[lane_id]);
fence_view_async_shared();
fence_barrier_init();
}
__syncwarp();
......@@ -842,7 +840,7 @@ combine(void* combined_x,
buffer, reinterpret_cast<float2*>(log_amax_buffers[stage_idx]),
reinterpret_cast<float2*>(log_amin_buffers[stage_idx]), cast_info_buffers[stage_idx], lane_id);
}
if (elect_one_sync(lane_id)) {
if (elect_one_sync()) {
int num_casted = 0;
if constexpr (kUseLogFMT) {
const auto& info = cast_info_buffers[stage_idx][num_decode_warps - 1];
......@@ -891,7 +889,7 @@ combine(void* combined_x,
);
}
if (elect_one_sync(lane_id))
if (elect_one_sync())
mbarrier_arrive(empty_barriers[stage_idx]);
stage_idx = (stage_idx + 1) % kNumStages;
}
......@@ -903,7 +901,7 @@ combine(void* combined_x,
tma_st_buffers[decode_warp_idx][kNumRecvUnrolls * 4 * lane_id + k] = *reinterpret_cast<uint32_t*>(&combined_pack);
}
tma_store_fence();
if (elect_one_sync(lane_id)) {
if (elect_one_sync()) {
tma_store_1d(tma_st_buffers[decode_warp_idx],
static_cast<int4*>(combined_x) + token_idx * hidden_bf16_int4 + decode_warp_idx * kNumRecvUnrolls * 32,
kNumBF16PerWarpBytes);
......@@ -911,9 +909,6 @@ combine(void* combined_x,
__syncwarp();
}
}
// Flush all stores
tma_store_wait<0>();
}
}
......
......@@ -89,7 +89,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped,
for (int64_t i = token_start_idx + lane_id; i < token_end_idx; i += 32)
count += is_token_in_rank[i * kNumRanks + dst_rank];
count = warp_reduce_sum(count);
if (lane_id == 0)
if (elect_one_sync())
channel_prefix_matrix[dst_rank * num_channels + channel_id] = count;
}
__syncthreads();
......@@ -228,9 +228,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
auto tma_buffer = smem_buffer + (thread_id / 32) * kNumTMABytesPerWarp;
auto tma_mbarrier = reinterpret_cast<uint64_t*>(tma_buffer + half_hidden_bytes);
uint32_t tma_phase = 0;
if (lane_id == 0) {
if (elect_one_sync()) {
mbarrier_init(tma_mbarrier, 1);
fence_view_async_shared();
fence_barrier_init();
EP_DEVICE_ASSERT(hidden_int4 % 2 == 0 and half_hidden_bytes + sizeof(uint64_t) <= kNumTMABytesPerWarp);
}
......@@ -248,7 +247,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
// Send offset by `-value - 1`, e.g. 0 -> -1, 1 -> -2
// NOTES: this is for distinguishing zero tokens
if (lane_id == 0 and send_warp_id_in_rank == 0) {
if (send_warp_id_in_rank == 0 and elect_one_sync()) {
int value = responsible_channel > 0 ? channel_prefix_matrix[responsible_rank * num_channels + responsible_channel - 1] : 0;
st_relaxed_sys_global(channel_start_offset.buffer(), -value - 1);
value = channel_prefix_matrix[responsible_rank * num_channels + responsible_channel];
......@@ -266,7 +265,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
// Check destination queue emptiness, or wait a buffer to be released (rare cases)
// NOTES: the head index received by different warps may not be the same
auto start_time = clock64();
while (lane_id == 0) {
if (elect_one_sync()) {
while (true) {
// NOTES: we only consider the worst case, because counting the real numbers are time-consuming
int num_used_slots = cached_channel_tail_idx - ld_volatile_global(channel_head_idx.buffer());
if (num_recv_buffer_tokens - num_used_slots >= num_max_send_tokens)
......@@ -278,12 +278,13 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
trap();
}
}
}
__syncwarp();
int chunk_token_idx = 0;
while (chunk_token_idx < num_max_send_tokens and token_idx < token_end_idx) {
// NOTES: for the same token, the warp assigned to save `send_head` may be different from the warp assigned to send the following data
if (lane_id == 0 and token_idx % num_send_warps_per_rank == send_warp_id_in_rank)
if (token_idx % num_send_warps_per_rank == send_warp_id_in_rank and elect_one_sync())
send_head[token_idx * kNumRanks + responsible_rank] = is_token_in_rank[token_idx * kNumRanks + responsible_rank] ? cached_channel_tail_idx : -1;
// Skip if not selected
......@@ -301,7 +302,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
UNROLLED_WARP_COPY(5, lane_id, hidden_int4, shifted_channel_x_buffers, shifted_x, __ldg, st_na_global);
// Copy source index
if (lane_id == 0)
if (elect_one_sync())
channel_src_idx_buffers[dst_slot_idx] = static_cast<int>(token_idx);
// Copy `topk_idx` and `topk_weights` with transformed index
......@@ -333,7 +334,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
// Move tail index
// NOTES: here all warps should share the same new tail
asm volatile("bar.sync %0, %1;" :: "r"(responsible_rank), "r"(num_threads_per_rank));
if (send_warp_id_in_rank == 0 and lane_id == 0)
if (send_warp_id_in_rank == 0 and elect_one_sync())
st_release_sys_global(channel_tail_idx.buffer(), cached_channel_tail_idx);
}
} else {
......@@ -352,9 +353,9 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
// Receive channel offset
int total_offset, num_tokens_to_recv;
while (lane_id == 0 and (total_offset = ld_volatile_global(channel_start_offset.buffer())) == 0);
while (lane_id == 0 and (num_tokens_to_recv = ld_volatile_global(channel_end_offset.buffer())) == 0);
if (lane_id == 0) {
if (elect_one_sync()) {
while ((total_offset = ld_volatile_global(channel_start_offset.buffer())) == 0);
while ((num_tokens_to_recv = ld_volatile_global(channel_end_offset.buffer())) == 0);
total_offset = -total_offset - 1, num_tokens_to_recv = -num_tokens_to_recv - 1;
if (recv_warp_id_in_rank == 0)
recv_channel_offset[responsible_rank * num_channels + responsible_channel] = total_offset;
......@@ -398,14 +399,16 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
auto shifted_buffer_x_int4 = channel_x_buffers.buffer() + token_idx_in_buffer * hidden_int4;
auto shifted_recv_x_int4 = recv_x + static_cast<int64_t>(total_offset + chunk_idx) * hidden_int4;
#ifndef DISABLE_SM90_FEATURES
if (elect_one_sync()) {
#pragma unroll
for (int i = 0; i < 2; ++ i) if (lane_id == 0) {
tma_store_wait();
for (int i = 0; i < 2; ++ i) {
tma_store_wait<0>();
tma_load_1d(tma_buffer, shifted_buffer_x_int4 + i * half_hidden_int4, tma_mbarrier, half_hidden_bytes);
mbarrier_arrive_and_expect_tx(tma_mbarrier, half_hidden_bytes);
mbarrier_wait(tma_mbarrier, tma_phase);
tma_store_1d(tma_buffer, shifted_recv_x_int4 + i * half_hidden_int4, half_hidden_bytes, false);
}
}
__syncwarp();
#else
UNROLLED_WARP_COPY(5, lane_id, hidden_int4, shifted_recv_x_int4, shifted_buffer_x_int4,
......@@ -442,21 +445,14 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
cached_channel_head_idx += num_recv_tokens;
total_offset += num_recv_tokens;
asm volatile("bar.sync %0, %1;" :: "r"(responsible_rank), "r"(num_threads_per_rank));
if (recv_warp_id_in_rank == num_recv_warps_per_rank - 1 and lane_id == 0)
if (recv_warp_id_in_rank == num_recv_warps_per_rank - 1 and elect_one_sync())
st_relaxed_sys_global(channel_head_idx.buffer(), cached_channel_head_idx);
// Exit
num_tokens_to_recv -= num_recv_tokens;
}
// Make TMA store visible to the next kernel
#ifndef DISABLE_SM90_FEATURES
if (lane_id == 0)
tma_store_wait();
#endif
}
// Clean unused `recv_topk_idx` as -1
if (num_worst_tokens > 0) {
auto rank_prefix_matrix = static_cast<int*>(buffer_ptrs[rank]);
......@@ -647,7 +643,8 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
// Check destination queue emptiness, or wait a buffer to be released (rare cases)
auto start_time = clock64();
int num_round_tokens = min(num_max_send_tokens, token_end_idx - static_cast<int>(token_idx));
while (lane_id == 0) {
if (elect_one_sync()) {
while (true) {
// NOTES: we only consider the worst case, because counting the real numbers are time-consuming
int num_used_slots = current_channel_tail_idx - ld_volatile_global(channel_head_idx.buffer());
if (num_recv_buffer_tokens - num_used_slots >= num_round_tokens)
......@@ -659,6 +656,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
trap();
}
}
}
__syncwarp();
// Send by chunk
......@@ -673,7 +671,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
UNROLLED_WARP_COPY(4, lane_id, hidden_int4, shifted_x_buffers, shifted_x, ld_nc_global, st_na_global);
// Send source index
if (lane_id == 0)
if (elect_one_sync())
channel_src_idx_buffers[dst_slot_idx] = __ldg(src_idx + token_idx + i);
// Send `topk_weights`
......@@ -685,7 +683,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
// Move tail index
asm volatile("bar.sync %0, %1;" :: "r"(send_rank_id), "r"(num_threads_per_rank));
if (lane_id == 0 and send_warp_id_in_rank == 0)
if (send_warp_id_in_rank == 0 and elect_one_sync())
st_release_sys_global(channel_tail_idx.buffer(), current_channel_tail_idx);
}
} else {
......@@ -793,8 +791,8 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
// Wait shared memory release
#ifndef DISABLE_SM90_FEATURES
if (lane_id == 0)
tma_store_wait();
if (elect_one_sync())
tma_store_wait<0>();
__syncwarp();
#endif
......@@ -840,7 +838,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
#ifndef DISABLE_SM90_FEATURES
// Wait TMA arrival
if (lane_id == 0)
if (elect_one_sync())
tma_store_wait<kNumStages - 1>();
__syncwarp();
......@@ -851,7 +849,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
// Issue TMA
tma_store_fence();
__syncwarp();
if (lane_id == 0) {
if (elect_one_sync()) {
auto tma_bytes = min(32, hidden_int4 - i) * static_cast<int>(sizeof(int4));
tma_store_1d(reinterpret_cast<int4*>(tma_buffer) + tma_stage_idx * 32,
recv_int4 + token_idx * hidden_int4 + i, tma_bytes, false);
......@@ -878,14 +876,8 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
// Retired
__syncwarp();
if (lane_id == 0)
if (elect_one_sync())
warp_retired[recv_warp_id] = true;
// Make TMA store visible to the next kernel
#ifndef DISABLE_SM90_FEATURES
if (lane_id == 0)
tma_store_wait();
#endif
}
}
}
......
......@@ -304,27 +304,32 @@ __device__ __forceinline__ float exp2f_approx(const float &x) {
return ret;
}
// TMA PTX instructions
#ifndef DISABLE_SM90_FEATURES
__forceinline__ __device__ int get_lane_id() {
int lane_id;
asm("mov.s32 %0, %laneid;" : "=r"(lane_id));
return lane_id;
}
__device__ __forceinline__ uint32_t elect_one_sync(int lane_id) {
__device__ __forceinline__ uint32_t elect_one_sync() {
#ifndef DISABLE_SM90_FEATURES
uint32_t pred = 0;
asm volatile(
"{\n"
".reg .b32 %%rx;\n"
".reg .pred %%px;\n"
" elect.sync %%rx|%%px, %2;\n"
"@%%px mov.s32 %1, 1;\n"
" mov.s32 %0, %%rx;\n"
" elect.sync %%rx|%%px, %1;\n"
"@%%px mov.s32 %0, 1;\n"
"}\n"
: "+r"(lane_id), "+r"(pred)
: "+r"(pred)
: "r"(0xffffffff));
return pred;
#else
return get_lane_id() == 0;
#endif
}
__device__ __forceinline__ void fence_view_async_shared() {
asm volatile("fence.proxy.async.shared::cta; \n" :: );
}
// TMA PTX instructions
#ifndef DISABLE_SM90_FEATURES
__device__ __forceinline__ void fence_barrier_init() {
asm volatile("fence.mbarrier_init.release.cluster; \n" :: );
......@@ -390,7 +395,7 @@ __device__ __forceinline__ void tma_store_1d(const void* smem_ptr, const void* g
asm volatile("cp.async.bulk.commit_group;");
}
template <int N = 0>
template <int N>
__device__ __forceinline__ void tma_store_wait() {
asm volatile("cp.async.bulk.wait_group.read %0;" :: "n"(N) : "memory");
}
......@@ -441,12 +446,6 @@ __device__ __forceinline__ dtype_t broadcast(dtype_t& ptr, int src_lane_idx) {
return *reinterpret_cast<dtype_t*>(recv_int_values);
}
__forceinline__ __device__ int get_lane_id() {
int lane_id;
asm("mov.s32 %0, %laneid;" : "=r"(lane_id));
return lane_id;
}
constexpr float kFP8Margin = 1e-4;
constexpr float kFinfoAmaxE4M3 = 448.0f;
constexpr float kFinfoAmaxInvE4M3 = 1 / 448.0f;
......
......@@ -204,9 +204,13 @@ def test_main(args: argparse.Namespace, num_sms: int,
if t < best_time:
best_time, best_results = t, (num_sms, nvl_chunk_size, rdma_chunk_size, notify_t)
if local_rank == 0:
print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}, transmit: {t * 1e6:.2f} us, notify: {notify_t * 1e6:.2f} us, BW: {rdma_send_bytes / 1e9 / t:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ', flush=True)
print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: '
f'{notify_t * 1e6:.0f} + {t * 1e6:.0f} us, '
f'{rdma_send_bytes / 1e9 / t:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ', flush=True)
if local_rank == 0:
print(f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}, transmit: {best_time * 1e6:.2f} us, notify: {best_results[3] * 1e6:.2f} us, BW: {rdma_send_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)', flush=True)
print(f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: '
f'{best_results[3] * 1e6:.0f} + {best_time * 1e6:.0f} us, '
f'{rdma_send_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)', flush=True)
print('', flush=True)
if isinstance(current_x, tuple):
......@@ -230,12 +234,17 @@ def test_main(args: argparse.Namespace, num_sms: int,
tune_args = {'x': recv_x, 'handle': handle, 'config': config}
t, notify_t = bench_kineto(lambda: buffer.combine(**tune_args), ('combine', 'notify'), suppress_kineto_output=True)
if local_rank == 0:
print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}, transmit: {t * 1e6:.2f} us, notify: {notify_t * 1e6:.2f} us, BW: {combine_bf16_rdma_recv_bytes / 1e9 / t:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ', flush=True)
print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: '
f'{notify_t * 1e6:.0f} + {t * 1e6:.0f} us, '
f'{combine_bf16_rdma_recv_bytes / 1e9 / t:.2f} GB/s (RDMA), '
f'{combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ', flush=True)
if t < best_time:
best_time, best_results = t, (num_sms, nvl_chunk_size, rdma_chunk_size, notify_t)
if local_rank == 0:
print(f'[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}, transmit: {best_time * 1e6:.2f} us, notify: {best_results[3] * 1e6:.2f} us, BW: {combine_bf16_rdma_recv_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL)', flush=True)
print(f'[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}, '
f'{best_results[3] * 1e6:.2f} + {best_time * 1e6:.2f} us, '
f'{combine_bf16_rdma_recv_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL)', flush=True)
print('', flush=True)
return hash_value
......
......@@ -189,7 +189,7 @@ def test_main(args: argparse.Namespace, num_sms: int, local_rank: int, num_ranks
best_time, best_results = t, (num_sms, nvl_chunk_size)
if local_rank == 0:
print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size if nvl_chunk_size else "default"}: '
f'{nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL), avg_t: {t * 1e6:.2f} us', flush=True)
f'{nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL), {t * 1e6:.2f} us', flush=True)
if local_rank == 0:
print(f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL), t: {best_time * 1e6:.2f} us', flush=True)
print('', flush=True)
......@@ -220,7 +220,7 @@ def test_main(args: argparse.Namespace, num_sms: int, local_rank: int, num_ranks
t = bench(lambda: buffer.combine(**tune_args))[0]
if local_rank == 0:
print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size if nvl_chunk_size else "default"}: '
f'{combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL), avg_t: {t * 1e6:.2f} us', flush=True)
f'{combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL), {t * 1e6:.2f} us', flush=True)
if t < best_time and nvl_chunk_size > 0:
best_time, best_results = t, (num_sms, nvl_chunk_size)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment