Commit 1b00b9d8 authored by lishen's avatar lishen
Browse files

Merge branch 'updates' into 'main'

Updates

See merge request dcutoolkit/deeplearing/DeepEP!10
parents 7e8acdf7 4f828c59
......@@ -150,6 +150,8 @@ struct LowLatencyLayout {
size_t num_bytes_per_dispatch_msg =
sizeof(int4) +
std::max(hidden * sizeof(hip_bfloat16), hidden + num_scales * sizeof(float));
// Must equal num_bytes_per_slot in internode_ll::combine
size_t num_bytes_per_combine_msg = hidden * sizeof(hip_bfloat16);
// Send buffer
......@@ -176,7 +178,8 @@ struct LowLatencyLayout {
size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int64_t);
size_t combine_recv_flag_buffer_bytes = dispatch_recv_count_buffer_bytes;
size_t signaling_buffer_bytes = std::max(dispatch_recv_count_buffer_bytes, combine_recv_flag_buffer_bytes);
total_bytes += signaling_buffer_bytes * 2;
size_t signaling_buffer_bytes_aligned = ALIGN<size_t>(signaling_buffer_bytes, 128);
total_bytes += signaling_buffer_bytes_aligned * 2;
// Assign pointers
// NOTES: we still leave some space for distinguishing dispatch/combine buffer,
......@@ -185,15 +188,15 @@ struct LowLatencyLayout {
buffers[i] = {
static_cast<int>(signaling_buffer_bytes / sizeof(int64_t)),
// dispatch:send_buffer + recv_buffer + recv_count
advance(rdma_buffer, send_buffer_bytes * i),
advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * i),
advance<int64_t*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i),
advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * i),
advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * 2 + recv_buffer_bytes * i),
advance<int64_t*>(rdma_buffer, signaling_buffer_bytes_aligned * i),
// combine:send_buffer + recv_buffer + recv_count
advance(rdma_buffer, send_buffer_bytes * i),
advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * i),
advance<int64_t*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i),
advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * i),
advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * 2 + recv_buffer_bytes * i),
advance<int64_t*>(rdma_buffer, signaling_buffer_bytes_aligned * i),
// combine_rdma_send_buffer_data_start
advance(rdma_buffer, send_buffer_bytes * i + sizeof(int4)),
advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * i),
//
num_bytes_per_combine_msg
};
......
......@@ -1397,6 +1397,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
const torch::Tensor& src_info, const torch::Tensor& layout_range,
const std::optional<torch::Tensor>& combine_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool zero_copy, bool async, bool return_recv_hook,
const std::optional<torch::Tensor>& out) {
......@@ -1418,6 +1419,13 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
EP_HOST_ASSERT(layout_range.dim() == 2 and layout_range.is_contiguous());
EP_HOST_ASSERT(layout_range.scalar_type() == torch::kInt64);
EP_HOST_ASSERT(layout_range.size(0) == num_experts / num_ranks and layout_range.size(1) == num_ranks);
if (combine_wait_recv_cost_stats.has_value()) {
EP_HOST_ASSERT(combine_wait_recv_cost_stats->scalar_type() == torch::kInt64);
EP_HOST_ASSERT(combine_wait_recv_cost_stats->dim() == 1 and combine_wait_recv_cost_stats->is_contiguous());
EP_HOST_ASSERT(combine_wait_recv_cost_stats->size(0) == num_ranks);
}
auto hidden = static_cast<int>(x.size(2));
auto num_local_experts = num_experts / num_ranks, num_topk = static_cast<int>(topk_weights.size(1));
auto num_combined_tokens = static_cast<int>(topk_weights.size(0));
......@@ -1456,6 +1464,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
x.data_ptr(), topk_idx.data_ptr<int64_t>(), topk_weights.data_ptr<float>(),
src_info.data_ptr<int>(), layout_range.data_ptr<int64_t>(),
global_atomic_counter.data_ptr<int>(),
combine_wait_recv_cost_stats.has_value() ? combine_wait_recv_cost_stats->data_ptr<int64_t>() : nullptr,
next_clean_meta.first, next_clean_meta.second,
num_combined_tokens, hidden, num_max_dispatch_tokens_per_rank,
num_topk, num_experts, rank, num_ranks,
......
......@@ -183,6 +183,7 @@ public:
std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
const torch::Tensor& src_info, const torch::Tensor& layout_range,
const std::optional<torch::Tensor>& combine_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool zero_copy, bool async, bool return_recv_hook,
const std::optional<torch::Tensor>& out = std::nullopt);
......
......@@ -155,6 +155,7 @@ void combine(void* combined_x,
const void* x, const int64_t* topk_idx, const float* topk_weights,
const int* src_info, const int64_t* layout_range,
int* global_atomic_counter,
int64_t* combine_wait_recv_cost_stats,
int64_t* next_clean, int num_next_clean_int,
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
......
......@@ -549,6 +549,7 @@ combine(void* combined_x,
const void* x, const int64_t* topk_idx, const float* topk_weights,
const int* src_info, const int64_t* layout_range,
int* global_atomic_counter,
int64_t* combine_wait_recv_cost_stats,
int64_t* next_clean, int num_next_clean_int,
int* atomic_clean_flag,
int num_combined_tokens, int hidden, int num_topk,
......@@ -572,7 +573,7 @@ combine(void* combined_x,
// Message package
EP_STATIC_ASSERT(kHidden % FP8_QUANTIZATION_NUM_PER_CHANNEL == 0, "Invalid hidden");
constexpr size_t num_bytes_per_slot = sizeof(int4) + kHidden * sizeof(hip_bfloat16);
constexpr size_t num_bytes_per_slot = kHidden * sizeof(hip_bfloat16);
EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization");
// 16 is the max possible number of warps in AMD GPUs
......@@ -627,12 +628,12 @@ combine(void* combined_x,
for (int token_idx = offset + sub_warp_id; token_idx < offset + num_tokens_to_send; token_idx += num_warps_per_group) {
const auto x_int4 = local_x + token_idx * hidden_bf16_int4;
const auto rdma_send_type_row = reinterpret_cast<int*>(rdma_send_x_vec + token_idx * num_bytes_per_slot);
const auto rdma_send_x_vec_row = reinterpret_cast<uint8_t*>(rdma_send_type_row + 4);
const auto rdma_send_x_vec_row = reinterpret_cast<uint8_t*>(rdma_send_type_row);
// Copy directly to local rank, or copy to buffer and issue RDMA
const auto src_idx = __ldg(local_src_info + token_idx);
const auto buf_ptr = reinterpret_cast<int64_t>(rdma_send_x_vec_row);
const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot + sizeof(int4);
const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot;
if (dst_rank == rank) {
const auto dst_int4_ptr = reinterpret_cast<int4*>(dst_ptr);
UNROLLED_WARP_COPY_LL(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global);
......@@ -724,8 +725,23 @@ LOW_LATENCY_COMBINE_RECV:
// Wait all ranks to arrive and notify PCIe usage
if (responsible_expert_idx < num_experts) {
EP_DEVICE_ASSERT(num_warps_per_group > 1);
if (sub_warp_id == 0 and lane_id == 0){
while (ld_acquire_global(reinterpret_cast<int*>(rdma_recv_flag + responsible_expert_idx)) == 0);
if (sub_warp_id == 0 and lane_id == 0) {
const auto src_rank = responsible_expert_idx / num_local_experts;
auto start_time = wall_clock64();
uint64_t wait_recv_cost = 0;
while (ld_acquire_global(reinterpret_cast<int*>(rdma_recv_flag + responsible_expert_idx)) == 0 // recv not ready
&& (wait_recv_cost = wall_clock64() - start_time) <= NUM_TIMEOUT_CYCLES // not timeout
);
// Only warn on timeout; the rank is not masked here
if (wait_recv_cost > NUM_TIMEOUT_CYCLES) {
printf("Warning: DeepEP timeout for combine receive, rank %d, local_expert_idx %d, src_rank %d\n",
rank, responsible_expert_idx % num_local_experts, src_rank);
}
if (combine_wait_recv_cost_stats != nullptr) {
atomicAdd(reinterpret_cast<unsigned long long*>(combine_wait_recv_cost_stats + src_rank), wait_recv_cost);
}
}
}
grid_barrier(global_atomic_counter, num_sms);
......@@ -750,7 +766,7 @@ LOW_LATENCY_COMBINE_RECV:
// Read from sources
auto rdma_buffer_type = reinterpret_cast<const int*>(reinterpret_cast<uint8_t*>(rdma_recv_x) +
(reg_topk_idx[i] * num_max_dispatch_tokens_per_rank + token_idx) * num_bytes_per_slot);
auto rdma_buffer_row = reinterpret_cast<const uint8_t*>(rdma_buffer_type + 4);
auto rdma_buffer_row = reinterpret_cast<const uint8_t*>(rdma_buffer_type);
// Reduce
auto x_vec = ld_nc_global(reinterpret_cast<const int4*>(rdma_buffer_row) + thread_id);
......@@ -776,6 +792,7 @@ void combine(void* combined_x,
const void* x, const int64_t* topk_idx, const float* topk_weights,
const int* src_info, const int64_t* layout_range,
int* global_atomic_counter,
int64_t* combine_wait_recv_cost_stats,
int64_t* next_clean, int num_next_clean_int,
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
......@@ -803,6 +820,7 @@ LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, combine_func, \
rdma_recv_x, rdma_recv_flag, rdma_send_x, \
x, topk_idx, topk_weights, src_info, layout_range, \
global_atomic_counter, \
combine_wait_recv_cost_stats, \
next_clean, num_next_clean_int, \
atomic_clean_flag, \
num_combined_tokens, hidden, num_topk, \
......
......@@ -901,7 +901,8 @@ class Buffer:
# noinspection PyTypeChecker
def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor,
handle: tuple, zero_copy: bool = False, async_finish: bool = False,
return_recv_hook: bool = False, out: Optional[torch.Tensor] = None) -> \
return_recv_hook: bool = False, out: Optional[torch.Tensor] = None,
combine_wait_recv_cost_stats: Optional[torch.Tensor] = None) -> \
Tuple[torch.Tensor, EventOverlap, Callable]:
"""
A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA.
......@@ -927,6 +928,9 @@ class Buffer:
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
If you do not set this flag, the kernel will ensure the data's arrival.
out: the in-place output tensor, if set, the kernel will write the result to this tensor and return it directly.
combine_wait_recv_cost_stats: an optional tensor that accumulates the time (in `wall_clock64` ticks) spent waiting to receive data from each source rank,
which should be 1-D and contiguous with shape `[num_ranks]` and be typed as `torch.int64`.
This is useful for detecting and precisely localizing slow anomalies.
Returns:
combined_x: the reduced token tensor, with shape `[num_combined_tokens, num_topk]` and type `torch.bfloat16`.
......@@ -935,6 +939,7 @@ class Buffer:
"""
src_info, layout_range, num_max_dispatch_tokens_per_rank, hidden, num_experts = handle
combined_x, event, hook = self.runtime.low_latency_combine(x, topk_idx, topk_weights, src_info, layout_range,
combine_wait_recv_cost_stats,
num_max_dispatch_tokens_per_rank, num_experts,
zero_copy, async_finish, return_recv_hook, out)
tensors_to_record = (x, topk_idx, topk_weights, src_info, layout_range, combined_x)
......
......@@ -140,7 +140,7 @@ def test_main(num_tokens: int,
topk_weights,
handle,
async_finish=not return_recv_hook,
# zero_copy=zero_copy,
zero_copy=zero_copy,
return_recv_hook=return_recv_hook,
out=out)
hook() if return_recv_hook else event.current_stream_wait()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment