"examples/python-examples/input.prmtop" did not exist on "534fd40416011ebbbdf49deebecd32bcddaaffae"
Commit 4f828c59 authored by lishen's avatar lishen
Browse files

支持combine_wait_recv_cost记录

parent f4b3020e
...@@ -1397,6 +1397,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i ...@@ -1397,6 +1397,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>> std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights, Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
const torch::Tensor& src_info, const torch::Tensor& layout_range, const torch::Tensor& src_info, const torch::Tensor& layout_range,
const std::optional<torch::Tensor>& combine_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts, int num_max_dispatch_tokens_per_rank, int num_experts,
bool zero_copy, bool async, bool return_recv_hook, bool zero_copy, bool async, bool return_recv_hook,
const std::optional<torch::Tensor>& out) { const std::optional<torch::Tensor>& out) {
...@@ -1418,6 +1419,13 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id ...@@ -1418,6 +1419,13 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
EP_HOST_ASSERT(layout_range.dim() == 2 and layout_range.is_contiguous()); EP_HOST_ASSERT(layout_range.dim() == 2 and layout_range.is_contiguous());
EP_HOST_ASSERT(layout_range.scalar_type() == torch::kInt64); EP_HOST_ASSERT(layout_range.scalar_type() == torch::kInt64);
EP_HOST_ASSERT(layout_range.size(0) == num_experts / num_ranks and layout_range.size(1) == num_ranks); EP_HOST_ASSERT(layout_range.size(0) == num_experts / num_ranks and layout_range.size(1) == num_ranks);
if (combine_wait_recv_cost_stats.has_value()) {
EP_HOST_ASSERT(combine_wait_recv_cost_stats->scalar_type() == torch::kInt64);
EP_HOST_ASSERT(combine_wait_recv_cost_stats->dim() == 1 and combine_wait_recv_cost_stats->is_contiguous());
EP_HOST_ASSERT(combine_wait_recv_cost_stats->size(0) == num_ranks);
}
auto hidden = static_cast<int>(x.size(2)); auto hidden = static_cast<int>(x.size(2));
auto num_local_experts = num_experts / num_ranks, num_topk = static_cast<int>(topk_weights.size(1)); auto num_local_experts = num_experts / num_ranks, num_topk = static_cast<int>(topk_weights.size(1));
auto num_combined_tokens = static_cast<int>(topk_weights.size(0)); auto num_combined_tokens = static_cast<int>(topk_weights.size(0));
...@@ -1456,6 +1464,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id ...@@ -1456,6 +1464,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
x.data_ptr(), topk_idx.data_ptr<int64_t>(), topk_weights.data_ptr<float>(), x.data_ptr(), topk_idx.data_ptr<int64_t>(), topk_weights.data_ptr<float>(),
src_info.data_ptr<int>(), layout_range.data_ptr<int64_t>(), src_info.data_ptr<int>(), layout_range.data_ptr<int64_t>(),
global_atomic_counter.data_ptr<int>(), global_atomic_counter.data_ptr<int>(),
combine_wait_recv_cost_stats.has_value() ? combine_wait_recv_cost_stats->data_ptr<int64_t>() : nullptr,
next_clean_meta.first, next_clean_meta.second, next_clean_meta.first, next_clean_meta.second,
num_combined_tokens, hidden, num_max_dispatch_tokens_per_rank, num_combined_tokens, hidden, num_max_dispatch_tokens_per_rank,
num_topk, num_experts, rank, num_ranks, num_topk, num_experts, rank, num_ranks,
......
...@@ -183,6 +183,7 @@ public: ...@@ -183,6 +183,7 @@ public:
std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>> std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights, low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
const torch::Tensor& src_info, const torch::Tensor& layout_range, const torch::Tensor& src_info, const torch::Tensor& layout_range,
const std::optional<torch::Tensor>& combine_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts, int num_max_dispatch_tokens_per_rank, int num_experts,
bool zero_copy, bool async, bool return_recv_hook, bool zero_copy, bool async, bool return_recv_hook,
const std::optional<torch::Tensor>& out = std::nullopt); const std::optional<torch::Tensor>& out = std::nullopt);
......
...@@ -155,6 +155,7 @@ void combine(void* combined_x, ...@@ -155,6 +155,7 @@ void combine(void* combined_x,
const void* x, const int64_t* topk_idx, const float* topk_weights, const void* x, const int64_t* topk_idx, const float* topk_weights,
const int* src_info, const int64_t* layout_range, const int* src_info, const int64_t* layout_range,
int* global_atomic_counter, int* global_atomic_counter,
int64_t* combine_wait_recv_cost_stats,
int64_t* next_clean, int num_next_clean_int, int64_t* next_clean, int num_next_clean_int,
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank, int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks, int num_topk, int num_experts, int rank, int num_ranks,
......
...@@ -549,6 +549,7 @@ combine(void* combined_x, ...@@ -549,6 +549,7 @@ combine(void* combined_x,
const void* x, const int64_t* topk_idx, const float* topk_weights, const void* x, const int64_t* topk_idx, const float* topk_weights,
const int* src_info, const int64_t* layout_range, const int* src_info, const int64_t* layout_range,
int* global_atomic_counter, int* global_atomic_counter,
int64_t* combine_wait_recv_cost_stats,
int64_t* next_clean, int num_next_clean_int, int64_t* next_clean, int num_next_clean_int,
int* atomic_clean_flag, int* atomic_clean_flag,
int num_combined_tokens, int hidden, int num_topk, int num_combined_tokens, int hidden, int num_topk,
...@@ -724,8 +725,23 @@ LOW_LATENCY_COMBINE_RECV: ...@@ -724,8 +725,23 @@ LOW_LATENCY_COMBINE_RECV:
// Wait all ranks to arrive and notify PCIe usage // Wait all ranks to arrive and notify PCIe usage
if (responsible_expert_idx < num_experts) { if (responsible_expert_idx < num_experts) {
EP_DEVICE_ASSERT(num_warps_per_group > 1); EP_DEVICE_ASSERT(num_warps_per_group > 1);
if (sub_warp_id == 0 and lane_id == 0){ if (sub_warp_id == 0 and lane_id == 0) {
while (ld_acquire_global(reinterpret_cast<int*>(rdma_recv_flag + responsible_expert_idx)) == 0); const auto src_rank = responsible_expert_idx / num_local_experts;
auto start_time = wall_clock64();
uint64_t wait_recv_cost = 0;
while (ld_acquire_global(reinterpret_cast<int*>(rdma_recv_flag + responsible_expert_idx)) == 0 // recv not ready
&& (wait_recv_cost = wall_clock64() - start_time) <= NUM_TIMEOUT_CYCLES // not timeout
);
// Warn if the wait timed out (only a warning is printed here; no rank is masked in this branch)
if (wait_recv_cost > NUM_TIMEOUT_CYCLES) {
printf("Warning: DeepEP timeout for combine receive, rank %d, local_expert_idx %d, src_rank %d\n",
rank, responsible_expert_idx % num_local_experts, src_rank);
}
if (combine_wait_recv_cost_stats != nullptr) {
atomicAdd(reinterpret_cast<unsigned long long*>(combine_wait_recv_cost_stats + src_rank), wait_recv_cost);
}
} }
} }
grid_barrier(global_atomic_counter, num_sms); grid_barrier(global_atomic_counter, num_sms);
...@@ -776,6 +792,7 @@ void combine(void* combined_x, ...@@ -776,6 +792,7 @@ void combine(void* combined_x,
const void* x, const int64_t* topk_idx, const float* topk_weights, const void* x, const int64_t* topk_idx, const float* topk_weights,
const int* src_info, const int64_t* layout_range, const int* src_info, const int64_t* layout_range,
int* global_atomic_counter, int* global_atomic_counter,
int64_t* combine_wait_recv_cost_stats,
int64_t* next_clean, int num_next_clean_int, int64_t* next_clean, int num_next_clean_int,
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank, int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks, int num_topk, int num_experts, int rank, int num_ranks,
...@@ -803,6 +820,7 @@ LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, combine_func, \ ...@@ -803,6 +820,7 @@ LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, combine_func, \
rdma_recv_x, rdma_recv_flag, rdma_send_x, \ rdma_recv_x, rdma_recv_flag, rdma_send_x, \
x, topk_idx, topk_weights, src_info, layout_range, \ x, topk_idx, topk_weights, src_info, layout_range, \
global_atomic_counter, \ global_atomic_counter, \
combine_wait_recv_cost_stats, \
next_clean, num_next_clean_int, \ next_clean, num_next_clean_int, \
atomic_clean_flag, \ atomic_clean_flag, \
num_combined_tokens, hidden, num_topk, \ num_combined_tokens, hidden, num_topk, \
......
...@@ -901,7 +901,8 @@ class Buffer: ...@@ -901,7 +901,8 @@ class Buffer:
# noinspection PyTypeChecker # noinspection PyTypeChecker
def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor, def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor,
handle: tuple, zero_copy: bool = False, async_finish: bool = False, handle: tuple, zero_copy: bool = False, async_finish: bool = False,
return_recv_hook: bool = False, out: Optional[torch.Tensor] = None) -> \ return_recv_hook: bool = False, out: Optional[torch.Tensor] = None,
combine_wait_recv_cost_stats: Optional[torch.Tensor] = None) -> \
Tuple[torch.Tensor, EventOverlap, Callable]: Tuple[torch.Tensor, EventOverlap, Callable]:
""" """
A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA. A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA.
...@@ -927,6 +928,9 @@ class Buffer: ...@@ -927,6 +928,9 @@ class Buffer:
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival. but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
If you do not set this flag, the kernel will ensure the data's arrival. If you do not set this flag, the kernel will ensure the data's arrival.
out: the in-place output tensor, if set, the kernel will write the result to this tensor and return it directly. out: the in-place output tensor, if set, the kernel will write the result to this tensor and return it directly.
combine_wait_recv_cost_stats: an optional tensor that accumulates, per source rank, the time spent waiting to receive tokens, for statistics;
it should be contiguous, have shape `[num_ranks]` and be typed as `torch.int64`.
This is useful for detecting and precisely localizing slow anomalies.
Returns: Returns:
combined_x: the reduced token tensor, with shape `[num_combined_tokens, num_topk]` and type `torch.bfloat16`. combined_x: the reduced token tensor, with shape `[num_combined_tokens, num_topk]` and type `torch.bfloat16`.
...@@ -935,6 +939,7 @@ class Buffer: ...@@ -935,6 +939,7 @@ class Buffer:
""" """
src_info, layout_range, num_max_dispatch_tokens_per_rank, hidden, num_experts = handle src_info, layout_range, num_max_dispatch_tokens_per_rank, hidden, num_experts = handle
combined_x, event, hook = self.runtime.low_latency_combine(x, topk_idx, topk_weights, src_info, layout_range, combined_x, event, hook = self.runtime.low_latency_combine(x, topk_idx, topk_weights, src_info, layout_range,
combine_wait_recv_cost_stats,
num_max_dispatch_tokens_per_rank, num_experts, num_max_dispatch_tokens_per_rank, num_experts,
zero_copy, async_finish, return_recv_hook, out) zero_copy, async_finish, return_recv_hook, out)
tensors_to_record = (x, topk_idx, topk_weights, src_info, layout_range, combined_x) tensors_to_record = (x, topk_idx, topk_weights, src_info, layout_range, combined_x)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment