Unverified commit 146b013d, authored by Shangyan Zhou and committed by GitHub
Browse files

Optimize `cached_notify` using TMA (#306)

* Fix RDMA head movement

* Optimize `cached_notify` by using TMA.

* Fix

* Small fix
parent 3073a2c6
...@@ -1027,7 +1027,7 @@ void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* ...@@ -1027,7 +1027,7 @@ void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float*
#undef DISPATCH_LAUNCH_CASE #undef DISPATCH_LAUNCH_CASE
} }
template <bool kLowLatencyMode> template <bool kLowLatencyMode, int kNumTMABytesPerWarp>
__global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean,
const int nvl_clean_offset, const int nvl_num_int_clean, const int nvl_clean_offset, const int nvl_num_int_clean,
int* combined_rdma_head, int num_combined_tokens, int num_channels, int* combined_rdma_head, int num_combined_tokens, int num_channels,
...@@ -1102,7 +1102,24 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in ...@@ -1102,7 +1102,24 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
EP_DEVICE_ASSERT(rdma_channel_prefix_matrix != nullptr and rdma_rank_prefix_sum != nullptr); EP_DEVICE_ASSERT(rdma_channel_prefix_matrix != nullptr and rdma_rank_prefix_sum != nullptr);
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= 32, "Too many NVL peers"); EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= 32, "Too many NVL peers");
if (lane_id < NUM_MAX_NVL_PEERS and warp_id < num_channels) { if (warp_id < num_channels) {
constexpr int tma_batch_size = kNumTMABytesPerWarp - sizeof(uint64_t);
constexpr int num_bytes_per_token = sizeof(int) * NUM_MAX_NVL_PEERS;
constexpr int num_tokens_per_batch = tma_batch_size / num_bytes_per_token;
EP_STATIC_ASSERT(num_bytes_per_token % 16 == 0, "num_bytes_per_token should be divisible by 16");
// TMA stuffs
extern __shared__ __align__(1024) uint8_t smem_tma_buffer[];
auto tma_buffer = smem_tma_buffer + warp_id * kNumTMABytesPerWarp;
auto tma_mbarrier = reinterpret_cast<uint64_t*>(tma_buffer + tma_batch_size);
uint32_t tma_phase = 0;
if (lane_id == 0) {
mbarrier_init(tma_mbarrier, 1);
fence_view_async_shared();
fence_barrier_init();
}
__syncwarp();
for (int dst_rdma_rank = sm_id - 2; dst_rdma_rank < num_rdma_ranks; dst_rdma_rank += num_channels * 2 - 2) { for (int dst_rdma_rank = sm_id - 2; dst_rdma_rank < num_rdma_ranks; dst_rdma_rank += num_channels * 2 - 2) {
// Iterate in reverse order // Iterate in reverse order
int token_start_idx = warp_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_id - 1]; int token_start_idx = warp_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_id - 1];
...@@ -1112,14 +1129,33 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in ...@@ -1112,14 +1129,33 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
// NOTES: `1 << 25` is a heuristic large number // NOTES: `1 << 25` is a heuristic large number
int last_head = 1 << 25; int last_head = 1 << 25;
#pragma unroll for (int batch_end_idx = token_end_idx; batch_end_idx > token_start_idx; batch_end_idx -= num_tokens_per_batch) {
for (int token_idx = token_end_idx - 1; token_idx >= token_start_idx; -- token_idx) { auto batch_start_idx = max(token_start_idx, batch_end_idx - num_tokens_per_batch);
auto current_head = __ldg(combined_nvl_head + token_idx * NUM_MAX_NVL_PEERS + lane_id);
if (current_head < 0) { if (lane_id == 0) {
combined_nvl_head[token_idx * NUM_MAX_NVL_PEERS + lane_id] = -last_head - 1; tma_load_1d(tma_buffer, combined_nvl_head + batch_start_idx * NUM_MAX_NVL_PEERS, tma_mbarrier, (batch_end_idx - batch_start_idx) * num_bytes_per_token);
} else { mbarrier_arrive_and_expect_tx(tma_mbarrier, (batch_end_idx - batch_start_idx) * num_bytes_per_token);
last_head = current_head;
} }
mbarrier_wait(tma_mbarrier, tma_phase);
__syncwarp();
for (int token_idx = batch_end_idx - 1; token_idx >= batch_start_idx; -- token_idx) {
if (lane_id < NUM_MAX_NVL_PEERS) {
auto current_head = reinterpret_cast<int*>(tma_buffer)[(token_idx - batch_start_idx) * NUM_MAX_NVL_PEERS + lane_id];
if (current_head < 0) {
reinterpret_cast<int*>(tma_buffer)[(token_idx - batch_start_idx) * NUM_MAX_NVL_PEERS + lane_id] = -last_head - 1;
} else {
last_head = current_head;
}
}
}
tma_store_fence();
__syncwarp();
if (lane_id == 0)
tma_store_1d(tma_buffer, combined_nvl_head + batch_start_idx * NUM_MAX_NVL_PEERS, (batch_end_idx - batch_start_idx) * num_bytes_per_token);
tma_store_wait();
__syncwarp();
} }
} }
} }
...@@ -1135,7 +1171,10 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to ...@@ -1135,7 +1171,10 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
int64_t num_rdma_bytes, int64_t num_nvl_bytes, int64_t num_rdma_bytes, int64_t num_nvl_bytes,
bool is_cached_dispatch, bool low_latency_mode) { bool is_cached_dispatch, bool low_latency_mode) {
const int num_threads = std::max(128, 32 * num_channels); const int num_threads = std::max(128, 32 * num_channels);
const int num_warps = num_threads / 32;
const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
const int kNumTMABytesPerWarp = 8192;
const int smem_size = kNumTMABytesPerWarp * num_warps;
// Get clean meta // Get clean meta
auto rdma_clean_meta = get_rdma_clean_meta(hidden_int4, num_scales, num_topk_idx, num_topk_weights, num_rdma_ranks, num_max_rdma_chunked_recv_tokens, num_channels); auto rdma_clean_meta = get_rdma_clean_meta(hidden_int4, num_scales, num_topk_idx, num_topk_weights, num_rdma_ranks, num_max_rdma_chunked_recv_tokens, num_channels);
...@@ -1147,8 +1186,9 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to ...@@ -1147,8 +1186,9 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
EP_HOST_ASSERT(num_channels * 2 > 3); EP_HOST_ASSERT(num_channels * 2 > 3);
// Launch kernel // Launch kernel
auto cached_notify_func = low_latency_mode ? cached_notify<true> : cached_notify<false>; auto cached_notify_func = low_latency_mode ? cached_notify<true, kNumTMABytesPerWarp> : cached_notify<false, kNumTMABytesPerWarp>;
SETUP_LAUNCH_CONFIG(num_channels * 2, num_threads, stream); SETUP_LAUNCH_CONFIG(num_channels * 2, num_threads, stream);
SET_SHARED_MEMORY_FOR_TMA(cached_notify_func);
LAUNCH_KERNEL(&cfg, cached_notify_func, LAUNCH_KERNEL(&cfg, cached_notify_func,
rdma_clean_meta.first, rdma_clean_meta.second, rdma_clean_meta.first, rdma_clean_meta.second,
nvl_clean_meta.first, nvl_clean_meta.second, nvl_clean_meta.first, nvl_clean_meta.second,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment