Commit 6cc3497d authored by Chenggang Zhao
Browse files

Remove all raw tensors for better P2P overlapping

parent f6030640
...@@ -93,7 +93,6 @@ struct LowLatencyBuffer { ...@@ -93,7 +93,6 @@ struct LowLatencyBuffer {
void* dispatch_rdma_send_buffer = nullptr; void* dispatch_rdma_send_buffer = nullptr;
void* dispatch_rdma_recv_data_buffer = nullptr; void* dispatch_rdma_recv_data_buffer = nullptr;
int* dispatch_rdma_recv_count_buffer = nullptr; int* dispatch_rdma_recv_count_buffer = nullptr;
int* dispatch_rdma_atomic_token_counter = nullptr;
void* combine_rdma_send_buffer = nullptr; void* combine_rdma_send_buffer = nullptr;
void* combine_rdma_recv_data_buffer = nullptr; void* combine_rdma_recv_data_buffer = nullptr;
...@@ -145,10 +144,8 @@ struct LowLatencyLayout { ...@@ -145,10 +144,8 @@ struct LowLatencyLayout {
// Symmetric signaling buffers // Symmetric signaling buffers
size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int); size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int);
size_t dispatch_recv_atomic_token_counter_bytes = num_local_experts * sizeof(int);
size_t combine_recv_flag_buffer_bytes = dispatch_recv_count_buffer_bytes; size_t combine_recv_flag_buffer_bytes = dispatch_recv_count_buffer_bytes;
size_t signaling_buffer_bytes = std::max(dispatch_recv_count_buffer_bytes + dispatch_recv_atomic_token_counter_bytes, size_t signaling_buffer_bytes = std::max(dispatch_recv_count_buffer_bytes, combine_recv_flag_buffer_bytes);
combine_recv_flag_buffer_bytes);
total_bytes += signaling_buffer_bytes * 2; total_bytes += signaling_buffer_bytes * 2;
// Assign pointers // Assign pointers
...@@ -160,7 +157,6 @@ struct LowLatencyLayout { ...@@ -160,7 +157,6 @@ struct LowLatencyLayout {
advance(rdma_buffer, send_buffer_bytes * i), advance(rdma_buffer, send_buffer_bytes * i),
advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * i), advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * i),
advance<int*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i), advance<int*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i),
advance<int*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i + dispatch_recv_count_buffer_bytes),
advance(rdma_buffer, send_buffer_bytes * i), advance(rdma_buffer, send_buffer_bytes * i),
advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * i), advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * i),
advance<int*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i) advance<int*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i)
......
...@@ -1048,8 +1048,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i ...@@ -1048,8 +1048,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, x.options().dtype(torch::kFloat8_e4m3fn)); auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, x.options().dtype(torch::kFloat8_e4m3fn));
auto packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA)); auto packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA));
auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA)); auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA));
auto packed_recv_count = torch::from_blob(buffer.dispatch_rdma_atomic_token_counter, auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA));
{num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA));
// Allocate column-majored scales // Allocate column-majored scales
EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4"); EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4");
...@@ -1061,6 +1060,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i ...@@ -1061,6 +1060,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
auto launcher = [=](int phases) { auto launcher = [=](int phases) {
internode_ll::dispatch(packed_recv_x.data_ptr(), packed_recv_x_scales.data_ptr<float>(), internode_ll::dispatch(packed_recv_x.data_ptr(), packed_recv_x_scales.data_ptr<float>(),
packed_recv_src_info.data_ptr<int>(), packed_recv_layout_range.data_ptr<int64_t>(), packed_recv_src_info.data_ptr<int>(), packed_recv_layout_range.data_ptr<int64_t>(),
packed_recv_count.data_ptr<int>(),
buffer.dispatch_rdma_recv_data_buffer, buffer.dispatch_rdma_recv_count_buffer, buffer.dispatch_rdma_recv_data_buffer, buffer.dispatch_rdma_recv_count_buffer,
buffer.dispatch_rdma_send_buffer, buffer.dispatch_rdma_send_buffer,
x.data_ptr(), topk_idx.data_ptr<int64_t>(), x.data_ptr(), topk_idx.data_ptr<int64_t>(),
......
...@@ -132,6 +132,7 @@ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0, ...@@ -132,6 +132,7 @@ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
void dispatch(void* packed_recv_x, float* packed_recv_x_scales, void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
int* packed_recv_src_info, int64_t* packed_recv_layout_range, int* packed_recv_src_info, int64_t* packed_recv_layout_range,
int* packed_recv_count,
void* rdma_recv_x, int* rdma_recv_count, void* rdma_x, void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
const void* x, const int64_t* topk_idx, const void* x, const int64_t* topk_idx,
int* next_clean, int num_next_clean_int, int* next_clean, int num_next_clean_int,
......
...@@ -40,9 +40,10 @@ template <int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden> ...@@ -40,9 +40,10 @@ template <int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden>
__global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void __global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void
dispatch(void* packed_recv_x, float* packed_recv_x_scales, dispatch(void* packed_recv_x, float* packed_recv_x_scales,
int* packed_recv_src_info, int64_t* packed_recv_layout_range, int* packed_recv_src_info, int64_t* packed_recv_layout_range,
int* packed_recv_count,
void* rdma_recv_x, int* rdma_recv_count, void* rdma_x, void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
const void* x, const int64_t* topk_idx, const void* x, const int64_t* topk_idx,
int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert, int* atomic_counter_per_local_expert, int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert,
int* next_clean, int num_next_clean_int, int* next_clean, int num_next_clean_int,
int num_tokens, int num_max_dispatch_tokens_per_rank, int num_tokens, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks, int num_topk, int num_experts, int rank, int num_ranks,
...@@ -215,6 +216,10 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, ...@@ -215,6 +216,10 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
// Clean workspace for next use // Clean workspace for next use
atomic_counter_per_expert[responsible_expert_idx] = 0; atomic_counter_per_expert[responsible_expert_idx] = 0;
atomic_finish_counter_per_expert[responsible_expert_idx] = 0; atomic_finish_counter_per_expert[responsible_expert_idx] = 0;
// Clean `packed_recv_count`
if (dst_rank == 0)
packed_recv_count[dst_expert_local_idx] = 0;
} }
__syncwarp(); __syncwarp();
...@@ -223,6 +228,10 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, ...@@ -223,6 +228,10 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
if ((phases & LOW_LATENCY_RECV_PHASE) == 0) if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
return; return;
// For send-and-recv kernels, we need a grid sync for making `packed_recv_count` visible
if (phases & LOW_LATENCY_SEND_PHASE)
cg::this_grid().sync();
// Receiving and packing // Receiving and packing
if (responsible_expert_idx < num_experts) { if (responsible_expert_idx < num_experts) {
const auto src_rank = responsible_expert_idx / num_local_experts; const auto src_rank = responsible_expert_idx / num_local_experts;
...@@ -252,7 +261,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, ...@@ -252,7 +261,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
while ((num_recv_tokens = ld_acquire_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank)) == 0); while ((num_recv_tokens = ld_acquire_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank)) == 0);
} }
num_recv_tokens = -num_recv_tokens - 1; num_recv_tokens = -num_recv_tokens - 1;
recv_token_begin_idx = atomicAdd(atomic_counter_per_local_expert + local_expert_idx, num_recv_tokens); recv_token_begin_idx = atomicAdd(packed_recv_count + local_expert_idx, num_recv_tokens);
shared_num_recv_tokens[warp_group_id] = num_recv_tokens; shared_num_recv_tokens[warp_group_id] = num_recv_tokens;
shared_recv_token_begin_idx[warp_group_id] = recv_token_begin_idx; shared_recv_token_begin_idx[warp_group_id] = recv_token_begin_idx;
recv_range[src_rank] = pack2<int, int64_t>(num_recv_tokens, recv_token_begin_idx); recv_range[src_rank] = pack2<int, int64_t>(num_recv_tokens, recv_token_begin_idx);
...@@ -290,6 +299,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, ...@@ -290,6 +299,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
void dispatch(void* packed_recv_x, float* packed_recv_x_scales, void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
int* packed_recv_src_info, int64_t* packed_recv_layout_range, int* packed_recv_src_info, int64_t* packed_recv_layout_range,
int* packed_recv_count,
void* rdma_recv_x, int* rdma_recv_count, void* rdma_x, void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
const void* x, const int64_t* topk_idx, const void* x, const int64_t* topk_idx,
int* next_clean, int num_next_clean_int, int* next_clean, int num_next_clean_int,
...@@ -311,17 +321,14 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales, ...@@ -311,17 +321,14 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
auto atomic_finish_counter_per_expert = atomic_counter_per_expert + num_experts; auto atomic_finish_counter_per_expert = atomic_counter_per_expert + num_experts;
EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES); EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES);
// Use the last part `rdma_recv_count` as `atomic_counter_per_local_expert`
// NOTES: this part will be cleaned in `combine`
auto atomic_counter_per_local_expert = rdma_recv_count + num_ranks * (num_experts / num_ranks);
#define DISPATCH_LAUNCH_CASE(hidden) \ #define DISPATCH_LAUNCH_CASE(hidden) \
LAUNCH_KERNEL(&cfg, dispatch<kNumWarpGroups, kNumWarpsPerGroup, hidden>, \ LAUNCH_KERNEL(&cfg, dispatch<kNumWarpGroups, kNumWarpsPerGroup, hidden>, \
packed_recv_x, packed_recv_x_scales, \ packed_recv_x, packed_recv_x_scales, \
packed_recv_src_info, packed_recv_layout_range, \ packed_recv_src_info, packed_recv_layout_range, \
packed_recv_count, \
rdma_recv_x, rdma_recv_count, rdma_x, \ rdma_recv_x, rdma_recv_count, rdma_x, \
x, topk_idx, \ x, topk_idx, \
atomic_counter_per_expert, atomic_finish_counter_per_expert, atomic_counter_per_local_expert, \ atomic_counter_per_expert, atomic_finish_counter_per_expert, \
next_clean, num_next_clean_int, \ next_clean, num_next_clean_int, \
num_tokens, num_max_dispatch_tokens_per_rank, \ num_tokens, num_max_dispatch_tokens_per_rank, \
num_topk, num_experts, rank, num_ranks, phases); break num_topk, num_experts, rank, num_ranks, phases); break
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment