You need to sign in or sign up before continuing.
Commit f4b3020e authored by lishen
Browse files

支持zero_copy正确性

parent 7e8acdf7
...@@ -150,6 +150,8 @@ struct LowLatencyLayout { ...@@ -150,6 +150,8 @@ struct LowLatencyLayout {
size_t num_bytes_per_dispatch_msg = size_t num_bytes_per_dispatch_msg =
sizeof(int4) + sizeof(int4) +
std::max(hidden * sizeof(hip_bfloat16), hidden + num_scales * sizeof(float)); std::max(hidden * sizeof(hip_bfloat16), hidden + num_scales * sizeof(float));
// 与internode_ll::combine 中的 num_bytes_per_slot 相等
size_t num_bytes_per_combine_msg = hidden * sizeof(hip_bfloat16); size_t num_bytes_per_combine_msg = hidden * sizeof(hip_bfloat16);
// Send buffer // Send buffer
...@@ -176,7 +178,8 @@ struct LowLatencyLayout { ...@@ -176,7 +178,8 @@ struct LowLatencyLayout {
size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int64_t); size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int64_t);
size_t combine_recv_flag_buffer_bytes = dispatch_recv_count_buffer_bytes; size_t combine_recv_flag_buffer_bytes = dispatch_recv_count_buffer_bytes;
size_t signaling_buffer_bytes = std::max(dispatch_recv_count_buffer_bytes, combine_recv_flag_buffer_bytes); size_t signaling_buffer_bytes = std::max(dispatch_recv_count_buffer_bytes, combine_recv_flag_buffer_bytes);
total_bytes += signaling_buffer_bytes * 2; size_t signaling_buffer_bytes_aligned = ALIGN<size_t>(signaling_buffer_bytes, 128);
total_bytes += signaling_buffer_bytes_aligned * 2;
// Assign pointers // Assign pointers
// NOTES: we still leave some space for distinguishing dispatch/combine buffer, // NOTES: we still leave some space for distinguishing dispatch/combine buffer,
...@@ -185,15 +188,15 @@ struct LowLatencyLayout { ...@@ -185,15 +188,15 @@ struct LowLatencyLayout {
buffers[i] = { buffers[i] = {
static_cast<int>(signaling_buffer_bytes / sizeof(int64_t)), static_cast<int>(signaling_buffer_bytes / sizeof(int64_t)),
// dispatch:send_buffer + recv_buffer + recv_count // dispatch:send_buffer + recv_buffer + recv_count
advance(rdma_buffer, send_buffer_bytes * i), advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * i),
advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * i), advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * 2 + recv_buffer_bytes * i),
advance<int64_t*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i), advance<int64_t*>(rdma_buffer, signaling_buffer_bytes_aligned * i),
// combine:send_buffer + recv_buffer + recv_count // combine:send_buffer + recv_buffer + recv_count
advance(rdma_buffer, send_buffer_bytes * i), advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * i),
advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * i), advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * 2 + recv_buffer_bytes * i),
advance<int64_t*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i), advance<int64_t*>(rdma_buffer, signaling_buffer_bytes_aligned * i),
// combine_rdma_send_buffer_data_start // combine_rdma_send_buffer_data_start
advance(rdma_buffer, send_buffer_bytes * i + sizeof(int4)), advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * i),
// //
num_bytes_per_combine_msg num_bytes_per_combine_msg
}; };
......
...@@ -572,7 +572,7 @@ combine(void* combined_x, ...@@ -572,7 +572,7 @@ combine(void* combined_x,
// Message package // Message package
EP_STATIC_ASSERT(kHidden % FP8_QUANTIZATION_NUM_PER_CHANNEL == 0, "Invalid hidden"); EP_STATIC_ASSERT(kHidden % FP8_QUANTIZATION_NUM_PER_CHANNEL == 0, "Invalid hidden");
constexpr size_t num_bytes_per_slot = sizeof(int4) + kHidden * sizeof(hip_bfloat16); constexpr size_t num_bytes_per_slot = kHidden * sizeof(hip_bfloat16);
EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization"); EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization");
// 16 is the max possible number of warps in AMD GPUs // 16 is the max possible number of warps in AMD GPUs
...@@ -627,12 +627,12 @@ combine(void* combined_x, ...@@ -627,12 +627,12 @@ combine(void* combined_x,
for (int token_idx = offset + sub_warp_id; token_idx < offset + num_tokens_to_send; token_idx += num_warps_per_group) { for (int token_idx = offset + sub_warp_id; token_idx < offset + num_tokens_to_send; token_idx += num_warps_per_group) {
const auto x_int4 = local_x + token_idx * hidden_bf16_int4; const auto x_int4 = local_x + token_idx * hidden_bf16_int4;
const auto rdma_send_type_row = reinterpret_cast<int*>(rdma_send_x_vec + token_idx * num_bytes_per_slot); const auto rdma_send_type_row = reinterpret_cast<int*>(rdma_send_x_vec + token_idx * num_bytes_per_slot);
const auto rdma_send_x_vec_row = reinterpret_cast<uint8_t*>(rdma_send_type_row + 4); const auto rdma_send_x_vec_row = reinterpret_cast<uint8_t*>(rdma_send_type_row);
// Copy directly to local rank, or copy to buffer and issue RDMA // Copy directly to local rank, or copy to buffer and issue RDMA
const auto src_idx = __ldg(local_src_info + token_idx); const auto src_idx = __ldg(local_src_info + token_idx);
const auto buf_ptr = reinterpret_cast<int64_t>(rdma_send_x_vec_row); const auto buf_ptr = reinterpret_cast<int64_t>(rdma_send_x_vec_row);
const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot + sizeof(int4); const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot;
if (dst_rank == rank) { if (dst_rank == rank) {
const auto dst_int4_ptr = reinterpret_cast<int4*>(dst_ptr); const auto dst_int4_ptr = reinterpret_cast<int4*>(dst_ptr);
UNROLLED_WARP_COPY_LL(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global); UNROLLED_WARP_COPY_LL(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global);
...@@ -750,7 +750,7 @@ LOW_LATENCY_COMBINE_RECV: ...@@ -750,7 +750,7 @@ LOW_LATENCY_COMBINE_RECV:
// Read from sources // Read from sources
auto rdma_buffer_type = reinterpret_cast<const int*>(reinterpret_cast<uint8_t*>(rdma_recv_x) + auto rdma_buffer_type = reinterpret_cast<const int*>(reinterpret_cast<uint8_t*>(rdma_recv_x) +
(reg_topk_idx[i] * num_max_dispatch_tokens_per_rank + token_idx) * num_bytes_per_slot); (reg_topk_idx[i] * num_max_dispatch_tokens_per_rank + token_idx) * num_bytes_per_slot);
auto rdma_buffer_row = reinterpret_cast<const uint8_t*>(rdma_buffer_type + 4); auto rdma_buffer_row = reinterpret_cast<const uint8_t*>(rdma_buffer_type);
// Reduce // Reduce
auto x_vec = ld_nc_global(reinterpret_cast<const int4*>(rdma_buffer_row) + thread_id); auto x_vec = ld_nc_global(reinterpret_cast<const int4*>(rdma_buffer_row) + thread_id);
......
...@@ -140,7 +140,7 @@ def test_main(num_tokens: int, ...@@ -140,7 +140,7 @@ def test_main(num_tokens: int,
topk_weights, topk_weights,
handle, handle,
async_finish=not return_recv_hook, async_finish=not return_recv_hook,
# zero_copy=zero_copy, zero_copy=zero_copy,
return_recv_hook=return_recv_hook, return_recv_hook=return_recv_hook,
out=out) out=out)
hook() if return_recv_hook else event.current_stream_wait() hook() if return_recv_hook else event.current_stream_wait()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment