Unverified Commit 4b67064d authored by sky's avatar sky Committed by GitHub
Browse files

Add diagnosis module for efficient and precise location of slow rank (#311)



* Add diagnosis module for precise identification of slow ranks
Signed-off-by: default avatarwangfakang <fakangwang@gmail.com>

* Add tests for the slow diagnosis module
Signed-off-by: default avatarwangfakang <fakangwang@gmail.com>

* Update some comments for diagnose
Signed-off-by: default avatarwangfakang <fakangwang@gmail.com>

* Update test case for diagnose
Signed-off-by: default avatarwangfakang <fakangwang@gmail.com>

* Strip the diagnose module, thx LyricZhao and sphish.
Signed-off-by: default avatarwangfakang <fakangwang@gmail.com>

* update variable name and cumulative wait recv cost, thx sphish.
Signed-off-by: default avatarwangfakang <fakangwang@gmail.com>

* remove invalid comments.
Signed-off-by: default avatarwangfakang <fakangwang@gmail.com>

---------
Signed-off-by: default avatarwangfakang <fakangwang@gmail.com>
parent b92d0d48
...@@ -1090,6 +1090,7 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int ...@@ -1090,6 +1090,7 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>> std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx, Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats, const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
const std::optional<torch::Tensor>& dispatch_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts, int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_fp8, bool round_scale, bool use_ue8m0, bool use_fp8, bool round_scale, bool use_ue8m0,
bool async, bool return_recv_hook) { bool async, bool return_recv_hook) {
...@@ -1110,6 +1111,12 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i ...@@ -1110,6 +1111,12 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
EP_HOST_ASSERT(cumulative_local_expert_recv_stats->size(0) == num_experts / num_ranks); EP_HOST_ASSERT(cumulative_local_expert_recv_stats->size(0) == num_experts / num_ranks);
} }
if (dispatch_wait_recv_cost_stats.has_value()) {
EP_HOST_ASSERT(dispatch_wait_recv_cost_stats->scalar_type() == torch::kInt64);
EP_HOST_ASSERT(dispatch_wait_recv_cost_stats->dim() == 1 and dispatch_wait_recv_cost_stats->is_contiguous());
EP_HOST_ASSERT(dispatch_wait_recv_cost_stats->size(0) == num_ranks);
}
auto num_tokens = static_cast<int>(x.size(0)), hidden = static_cast<int>(x.size(1)); auto num_tokens = static_cast<int>(x.size(0)), hidden = static_cast<int>(x.size(1));
auto num_topk = static_cast<int>(topk_idx.size(1)); auto num_topk = static_cast<int>(topk_idx.size(1));
auto num_local_experts = num_experts / num_ranks; auto num_local_experts = num_experts / num_ranks;
...@@ -1162,6 +1169,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i ...@@ -1162,6 +1169,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
packed_recv_src_info.data_ptr<int>(), packed_recv_layout_range.data_ptr<int64_t>(), packed_recv_src_info.data_ptr<int>(), packed_recv_layout_range.data_ptr<int64_t>(),
packed_recv_count.data_ptr<int>(), packed_recv_count.data_ptr<int>(),
cumulative_local_expert_recv_stats.has_value() ? cumulative_local_expert_recv_stats->data_ptr<int>() : nullptr, cumulative_local_expert_recv_stats.has_value() ? cumulative_local_expert_recv_stats->data_ptr<int>() : nullptr,
dispatch_wait_recv_cost_stats.has_value() ? dispatch_wait_recv_cost_stats->data_ptr<int64_t>() : nullptr,
buffer.dispatch_rdma_recv_data_buffer, buffer.dispatch_rdma_recv_count_buffer, buffer.dispatch_rdma_recv_data_buffer, buffer.dispatch_rdma_recv_count_buffer,
buffer.dispatch_rdma_send_buffer, buffer.dispatch_rdma_send_buffer,
x.data_ptr(), topk_idx.data_ptr<int64_t>(), x.data_ptr(), topk_idx.data_ptr<int64_t>(),
...@@ -1200,6 +1208,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i ...@@ -1200,6 +1208,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>> std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights, Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
const torch::Tensor& src_info, const torch::Tensor& layout_range, const torch::Tensor& src_info, const torch::Tensor& layout_range,
const std::optional<torch::Tensor>& combine_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts, int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_logfmt, bool zero_copy, bool async, bool return_recv_hook, bool use_logfmt, bool zero_copy, bool async, bool return_recv_hook,
const std::optional<torch::Tensor>& out) { const std::optional<torch::Tensor>& out) {
...@@ -1222,6 +1231,13 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id ...@@ -1222,6 +1231,13 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
EP_HOST_ASSERT(layout_range.dim() == 2 and layout_range.is_contiguous()); EP_HOST_ASSERT(layout_range.dim() == 2 and layout_range.is_contiguous());
EP_HOST_ASSERT(layout_range.scalar_type() == torch::kInt64); EP_HOST_ASSERT(layout_range.scalar_type() == torch::kInt64);
EP_HOST_ASSERT(layout_range.size(0) == num_experts / num_ranks and layout_range.size(1) == num_ranks); EP_HOST_ASSERT(layout_range.size(0) == num_experts / num_ranks and layout_range.size(1) == num_ranks);
if (combine_wait_recv_cost_stats.has_value()) {
EP_HOST_ASSERT(combine_wait_recv_cost_stats->scalar_type() == torch::kInt64);
EP_HOST_ASSERT(combine_wait_recv_cost_stats->dim() == 1 and combine_wait_recv_cost_stats->is_contiguous());
EP_HOST_ASSERT(combine_wait_recv_cost_stats->size(0) == num_ranks);
}
auto hidden = static_cast<int>(x.size(2)); auto hidden = static_cast<int>(x.size(2));
auto num_topk = static_cast<int>(topk_weights.size(1)); auto num_topk = static_cast<int>(topk_weights.size(1));
auto num_combined_tokens = static_cast<int>(topk_weights.size(0)); auto num_combined_tokens = static_cast<int>(topk_weights.size(0));
...@@ -1259,6 +1275,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id ...@@ -1259,6 +1275,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
buffer.combine_rdma_send_buffer, buffer.combine_rdma_send_buffer,
x.data_ptr(), topk_idx.data_ptr<int64_t>(), topk_weights.data_ptr<float>(), x.data_ptr(), topk_idx.data_ptr<int64_t>(), topk_weights.data_ptr<float>(),
src_info.data_ptr<int>(), layout_range.data_ptr<int64_t>(), src_info.data_ptr<int>(), layout_range.data_ptr<int64_t>(),
combine_wait_recv_cost_stats.has_value() ? combine_wait_recv_cost_stats->data_ptr<int64_t>() : nullptr,
next_clean_meta.first, next_clean_meta.second, next_clean_meta.first, next_clean_meta.second,
num_combined_tokens, hidden, num_max_dispatch_tokens_per_rank, num_combined_tokens, hidden, num_max_dispatch_tokens_per_rank,
num_topk, num_experts, rank, num_ranks, num_topk, num_experts, rank, num_ranks,
......
...@@ -146,6 +146,7 @@ public: ...@@ -146,6 +146,7 @@ public:
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>> std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx, low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats, const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
const std::optional<torch::Tensor>& dispatch_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts, int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_fp8, bool round_scale, bool use_ue8m0, bool use_fp8, bool round_scale, bool use_ue8m0,
bool async, bool return_recv_hook); bool async, bool return_recv_hook);
...@@ -153,6 +154,7 @@ public: ...@@ -153,6 +154,7 @@ public:
std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>> std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights, low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
const torch::Tensor& src_info, const torch::Tensor& layout_range, const torch::Tensor& src_info, const torch::Tensor& layout_range,
const std::optional<torch::Tensor>& combine_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts, int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_logfmt, bool zero_copy, bool async, bool return_recv_hook, bool use_logfmt, bool zero_copy, bool async, bool return_recv_hook,
const std::optional<torch::Tensor>& out = std::nullopt); const std::optional<torch::Tensor>& out = std::nullopt);
......
...@@ -143,6 +143,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -143,6 +143,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int* packed_recv_src_info, int64_t* packed_recv_layout_range, int* packed_recv_src_info, int64_t* packed_recv_layout_range,
int* packed_recv_count, int* packed_recv_count,
int* cumulative_local_expert_recv_stats, int* cumulative_local_expert_recv_stats,
int64_t* dispatch_wait_recv_cost_stats,
void* rdma_recv_x, int* rdma_recv_count, void* rdma_x, void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
const void* x, const int64_t* topk_idx, const void* x, const int64_t* topk_idx,
int* next_clean, int num_next_clean_int, int* next_clean, int num_next_clean_int,
...@@ -156,6 +157,7 @@ void combine(void* combined_x, ...@@ -156,6 +157,7 @@ void combine(void* combined_x,
void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x, void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
const void* x, const int64_t* topk_idx, const float* topk_weights, const void* x, const int64_t* topk_idx, const float* topk_weights,
const int* src_info, const int64_t* layout_range, const int* src_info, const int64_t* layout_range,
int64_t* combine_wait_recv_cost_stats,
int* next_clean, int num_next_clean_int, int* next_clean, int num_next_clean_int,
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank, int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks, int num_topk, int num_experts, int rank, int num_ranks,
......
...@@ -42,6 +42,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -42,6 +42,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int* packed_recv_src_info, int64_t* packed_recv_layout_range, int* packed_recv_src_info, int64_t* packed_recv_layout_range,
int* packed_recv_count, int* packed_recv_count,
int* cumulative_local_expert_recv_stats, int* cumulative_local_expert_recv_stats,
int64_t* dispatch_wait_recv_cost_stats,
void* rdma_recv_x, int* rdma_recv_count, void* rdma_x, void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
const void* x, const int64_t* topk_idx, const void* x, const int64_t* topk_idx,
int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert, int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert,
...@@ -272,7 +273,9 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -272,7 +273,9 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int num_recv_tokens, recv_token_begin_idx; int num_recv_tokens, recv_token_begin_idx;
EP_DEVICE_ASSERT(num_warps_per_group > 1 and num_warp_groups < 15); EP_DEVICE_ASSERT(num_warps_per_group > 1 and num_warp_groups < 15);
if (sub_warp_id == 1 and lane_id == 0) { if (sub_warp_id == 1 and lane_id == 0) {
auto start_time = clock64();
while ((num_recv_tokens = ld_acquire_sys_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank)) == 0); while ((num_recv_tokens = ld_acquire_sys_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank)) == 0);
auto wait_recv_cost = clock64() - start_time;
num_recv_tokens = -num_recv_tokens - 1; num_recv_tokens = -num_recv_tokens - 1;
recv_token_begin_idx = atomicAdd(packed_recv_count + local_expert_idx, num_recv_tokens); recv_token_begin_idx = atomicAdd(packed_recv_count + local_expert_idx, num_recv_tokens);
shared_num_recv_tokens[warp_group_id] = num_recv_tokens; shared_num_recv_tokens[warp_group_id] = num_recv_tokens;
...@@ -280,6 +283,10 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -280,6 +283,10 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
recv_range[src_rank] = pack2<int, int64_t>(num_recv_tokens, recv_token_begin_idx); recv_range[src_rank] = pack2<int, int64_t>(num_recv_tokens, recv_token_begin_idx);
if (cumulative_local_expert_recv_stats != nullptr) if (cumulative_local_expert_recv_stats != nullptr)
atomicAdd(cumulative_local_expert_recv_stats + local_expert_idx, num_recv_tokens); atomicAdd(cumulative_local_expert_recv_stats + local_expert_idx, num_recv_tokens);
if (dispatch_wait_recv_cost_stats != nullptr)
atomicAdd(reinterpret_cast<unsigned long long*>(dispatch_wait_recv_cost_stats + src_rank),
wait_recv_cost);
} }
asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 2), "r"(num_warps_per_group * 32)); asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 2), "r"(num_warps_per_group * 32));
num_recv_tokens = shared_num_recv_tokens[warp_group_id]; num_recv_tokens = shared_num_recv_tokens[warp_group_id];
...@@ -330,6 +337,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -330,6 +337,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int* packed_recv_src_info, int64_t* packed_recv_layout_range, int* packed_recv_src_info, int64_t* packed_recv_layout_range,
int* packed_recv_count, int* packed_recv_count,
int* cumulative_local_expert_recv_stats, int* cumulative_local_expert_recv_stats,
int64_t* dispatch_wait_recv_cost_stats,
void* rdma_recv_x, int* rdma_recv_count, void* rdma_x, void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
const void* x, const int64_t* topk_idx, const void* x, const int64_t* topk_idx,
int* next_clean, int num_next_clean_int, int* next_clean, int num_next_clean_int,
...@@ -368,6 +376,7 @@ LAUNCH_KERNEL(&cfg, dispatch_func, \ ...@@ -368,6 +376,7 @@ LAUNCH_KERNEL(&cfg, dispatch_func, \
packed_recv_src_info, packed_recv_layout_range, \ packed_recv_src_info, packed_recv_layout_range, \
packed_recv_count, \ packed_recv_count, \
cumulative_local_expert_recv_stats, \ cumulative_local_expert_recv_stats, \
dispatch_wait_recv_cost_stats, \
rdma_recv_x, rdma_recv_count, rdma_x, \ rdma_recv_x, rdma_recv_count, rdma_x, \
x, topk_idx, \ x, topk_idx, \
atomic_counter_per_expert, atomic_finish_counter_per_expert, \ atomic_counter_per_expert, atomic_finish_counter_per_expert, \
...@@ -388,6 +397,7 @@ combine(void* combined_x, ...@@ -388,6 +397,7 @@ combine(void* combined_x,
void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x, void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
const void* x, const int64_t* topk_idx, const float* topk_weights, const void* x, const int64_t* topk_idx, const float* topk_weights,
const int* src_info, const int64_t* layout_range, const int* src_info, const int64_t* layout_range,
int64_t* combine_wait_recv_cost_stats,
int* next_clean, int num_next_clean_int, int* next_clean, int num_next_clean_int,
int* atomic_clean_flag, int* atomic_clean_flag,
int num_combined_tokens, int hidden, int num_topk, int num_combined_tokens, int hidden, int num_topk,
...@@ -618,7 +628,12 @@ combine(void* combined_x, ...@@ -618,7 +628,12 @@ combine(void* combined_x,
if (responsible_expert_idx < num_experts) { if (responsible_expert_idx < num_experts) {
EP_DEVICE_ASSERT(num_warps_per_group > 1); EP_DEVICE_ASSERT(num_warps_per_group > 1);
if (sub_warp_id == 0 and lane_id == 0) { if (sub_warp_id == 0 and lane_id == 0) {
auto start_time = clock64();
while (ld_acquire_sys_global(rdma_recv_flag + responsible_expert_idx) == 0); while (ld_acquire_sys_global(rdma_recv_flag + responsible_expert_idx) == 0);
auto wait_recv_cost = clock64() - start_time;
if (combine_wait_recv_cost_stats != nullptr)
atomicAdd(reinterpret_cast<unsigned long long*>(combine_wait_recv_cost_stats
+ responsible_expert_idx / num_local_experts), wait_recv_cost);
} }
} }
cg::this_grid().sync(); cg::this_grid().sync();
...@@ -667,6 +682,7 @@ void combine(void* combined_x, ...@@ -667,6 +682,7 @@ void combine(void* combined_x,
void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x, void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
const void* x, const int64_t* topk_idx, const float* topk_weights, const void* x, const int64_t* topk_idx, const float* topk_weights,
const int* src_info, const int64_t* layout_range, const int* src_info, const int64_t* layout_range,
int64_t* combine_wait_recv_cost_stats,
int* next_clean, int num_next_clean_int, int* next_clean, int num_next_clean_int,
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank, int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks, int num_topk, int num_experts, int rank, int num_ranks,
...@@ -701,6 +717,7 @@ LAUNCH_KERNEL(&cfg, combine_func, \ ...@@ -701,6 +717,7 @@ LAUNCH_KERNEL(&cfg, combine_func, \
combined_x, \ combined_x, \
rdma_recv_x, rdma_recv_flag, rdma_send_x, \ rdma_recv_x, rdma_recv_flag, rdma_send_x, \
x, topk_idx, topk_weights, src_info, layout_range, \ x, topk_idx, topk_weights, src_info, layout_range, \
combine_wait_recv_cost_stats, \
next_clean, num_next_clean_int, \ next_clean, num_next_clean_int, \
atomic_clean_flag, \ atomic_clean_flag, \
num_combined_tokens, hidden, num_topk, \ num_combined_tokens, hidden, num_topk, \
......
...@@ -515,6 +515,7 @@ class Buffer: ...@@ -515,6 +515,7 @@ class Buffer:
def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor, def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
num_max_dispatch_tokens_per_rank: int, num_experts: int, num_max_dispatch_tokens_per_rank: int, num_experts: int,
cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None, cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None,
dispatch_wait_recv_cost_stats: Optional[torch.Tensor] = None,
use_fp8: bool = True, round_scale: bool = False, use_ue8m0: bool = False, use_fp8: bool = True, round_scale: bool = False, use_ue8m0: bool = False,
async_finish: bool = False, return_recv_hook: bool = False) -> \ async_finish: bool = False, return_recv_hook: bool = False) -> \
Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, Tuple, EventOverlap, Callable]: Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, Tuple, EventOverlap, Callable]:
...@@ -535,6 +536,9 @@ class Buffer: ...@@ -535,6 +536,9 @@ class Buffer:
cumulative_local_expert_recv_stats: a cumulative expert count tensor for statistics, which should have shape cumulative_local_expert_recv_stats: a cumulative expert count tensor for statistics, which should have shape
`[num_local_experts]` and be typed as `torch.int`. This is useful for online service EP load balance `[num_local_experts]` and be typed as `torch.int`. This is useful for online service EP load balance
monitoring. monitoring.
dispatch_wait_recv_cost_stats: a cumulative time spent waiting to receive each token tensor for statistics,
which should have shape `[num_ranks, num_ranks]` and be typed as `torch.int64`.
This is useful for detecting and pre-cisely localizing slow anomalies.
use_fp8: whether to enable FP8 casting, with this, the received data will be a tuple of FP8 tensor and scaling factors. use_fp8: whether to enable FP8 casting, with this, the received data will be a tuple of FP8 tensor and scaling factors.
round_scale: whether round the scaling factors into power of 2. round_scale: whether round the scaling factors into power of 2.
use_ue8m0: whether use UE8M0 as scaling factor format (available only with `round_scale=True`). use_ue8m0: whether use UE8M0 as scaling factor format (available only with `round_scale=True`).
...@@ -565,6 +569,7 @@ class Buffer: ...@@ -565,6 +569,7 @@ class Buffer:
packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, hook = \ packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, hook = \
self.runtime.low_latency_dispatch(x, topk_idx, self.runtime.low_latency_dispatch(x, topk_idx,
cumulative_local_expert_recv_stats, cumulative_local_expert_recv_stats,
dispatch_wait_recv_cost_stats,
num_max_dispatch_tokens_per_rank, num_experts, num_max_dispatch_tokens_per_rank, num_experts,
use_fp8, round_scale, use_ue8m0, use_fp8, round_scale, use_ue8m0,
async_finish, return_recv_hook) async_finish, return_recv_hook)
...@@ -579,7 +584,8 @@ class Buffer: ...@@ -579,7 +584,8 @@ class Buffer:
# noinspection PyTypeChecker # noinspection PyTypeChecker
def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor, def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor,
handle: tuple, use_logfmt: bool = False, zero_copy: bool = False, async_finish: bool = False, handle: tuple, use_logfmt: bool = False, zero_copy: bool = False, async_finish: bool = False,
return_recv_hook: bool = False, out: Optional[torch.Tensor] = None) -> \ return_recv_hook: bool = False, out: Optional[torch.Tensor] = None,
combine_wait_recv_cost_stats: Optional[torch.Tensor] = None) -> \
Tuple[torch.Tensor, EventOverlap, Callable]: Tuple[torch.Tensor, EventOverlap, Callable]:
""" """
A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA. A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA.
...@@ -605,6 +611,9 @@ class Buffer: ...@@ -605,6 +611,9 @@ class Buffer:
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival. but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
If you do not set this flag, the kernel will ensure the data's arrival. If you do not set this flag, the kernel will ensure the data's arrival.
out: the in-place output tensor, if set, the kernel will write the result to this tensor and return it directly. out: the in-place output tensor, if set, the kernel will write the result to this tensor and return it directly.
combine_wait_recv_cost_stats: a cumulative time spent waiting to receive each token tensor for statistics,
which should have shape `[num_ranks, num_ranks]` and be typed as `torch.int64`.
This is useful for detecting and pre-cisely localizing slow anomalies.
Returns: Returns:
combined_x: the reduced token tensor, with shape `[num_combined_tokens, hidden]` and type `torch.bfloat16`. combined_x: the reduced token tensor, with shape `[num_combined_tokens, hidden]` and type `torch.bfloat16`.
...@@ -613,6 +622,7 @@ class Buffer: ...@@ -613,6 +622,7 @@ class Buffer:
""" """
src_info, layout_range, num_max_dispatch_tokens_per_rank, hidden, num_experts = handle src_info, layout_range, num_max_dispatch_tokens_per_rank, hidden, num_experts = handle
combined_x, event, hook = self.runtime.low_latency_combine(x, topk_idx, topk_weights, src_info, layout_range, combined_x, event, hook = self.runtime.low_latency_combine(x, topk_idx, topk_weights, src_info, layout_range,
combine_wait_recv_cost_stats,
num_max_dispatch_tokens_per_rank, num_experts, num_max_dispatch_tokens_per_rank, num_experts,
use_logfmt, zero_copy, async_finish, return_recv_hook, use_logfmt, zero_copy, async_finish, return_recv_hook,
out) out)
......
import argparse import argparse
import random import random
import time
import os
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import numpy as np
from functools import partial from functools import partial
from typing import Optional
import deep_ep import deep_ep
from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_back from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_back
...@@ -10,7 +14,7 @@ from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_to ...@@ -10,7 +14,7 @@ from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_to
def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
rank: int, num_ranks: int, group: dist.ProcessGroup, buffer: deep_ep.Buffer, rank: int, num_ranks: int, group: dist.ProcessGroup, buffer: deep_ep.Buffer,
use_logfmt: bool = False, seed: int = 0): use_logfmt: bool = False, seed: int = 0, enable_diagnose: bool = False):
torch.manual_seed(seed + rank) torch.manual_seed(seed + rank)
random.seed(seed + rank) random.seed(seed + rank)
...@@ -121,6 +125,23 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, ...@@ -121,6 +125,23 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
use_logfmt=use_logfmt, return_recv_hook=return_recv_hook) use_logfmt=use_logfmt, return_recv_hook=return_recv_hook)
large_gemm_with_hook(hook) if return_recv_hook else None large_gemm_with_hook(hook) if return_recv_hook else None
# noinspection PyShadowingNames
def test_diagnose(test_dispatch_slow: bool, slow_rank: int,
dispatch_wait_recv_cost_stats: Optional[torch.Tensor] = None,
combine_wait_recv_cost_stats: Optional[torch.Tensor] = None):
if test_dispatch_slow:
if rank == slow_rank:
time.sleep(0.001)
buffer.low_latency_dispatch(x_pure_rand, topk_idx, num_tokens, num_experts,
cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,
dispatch_wait_recv_cost_stats=dispatch_wait_recv_cost_stats,
use_fp8=True, async_finish=False)
else:
if rank == slow_rank:
time.sleep(0.001)
buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
use_logfmt=use_logfmt, return_recv_hook=False,
combine_wait_recv_cost_stats=combine_wait_recv_cost_stats)
# Calculate bandwidth # Calculate bandwidth
num_fp8_bytes, num_bf16_bytes = (hidden + hidden / 128 * 4 + 16), hidden * 2 num_fp8_bytes, num_bf16_bytes = (hidden + hidden / 128 * 4 + 16), hidden * 2
num_dispatch_comm_bytes, num_combine_comm_bytes = 0, 0 num_dispatch_comm_bytes, num_combine_comm_bytes = 0, 0
...@@ -146,6 +167,83 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, ...@@ -146,6 +167,83 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
else: else:
print(f'[rank {rank}] Dispatch send/recv time: {dispatch_t[0] * 1e6:.2f} + {dispatch_t[1] * 1e6:.2f} us | ' print(f'[rank {rank}] Dispatch send/recv time: {dispatch_t[0] * 1e6:.2f} + {dispatch_t[1] * 1e6:.2f} us | '
f'Combine send/recv time: {combine_t[0] * 1e6:.2f} + {combine_t[1] * 1e6:.2f} us', flush=True) f'Combine send/recv time: {combine_t[0] * 1e6:.2f} + {combine_t[1] * 1e6:.2f} us', flush=True)
# Diagnose test
if enable_diagnose:
def diagnose_matrix(
mat, thres_col=3.0, thres_row=3.0, thres_point=5.0,
suppress_points_in_strong_rowscols=True
):
"""
mat: 2D numpy array, mat[i, j] = the waiting time of src i waiting for dst j to receive the token
Returns abnormal columns/rows/points.
suppress_points_in_strong_rowscols: whether to remove points located in already detected abnormal rows or columns
"""
# 1. Check for abnormal columns
col_means = mat.mean(axis=0)
# z_col = (col_means - col_means.mean()) / (col_means.std() + 1e-8)
z_col = col_means / (col_means.mean() + 1e-8)
abnormal_cols = np.where(z_col > thres_col)[0].tolist()
# 2. Check for abnormal rows
row_means = mat.mean(axis=1)
# z_row = (row_means - row_means.mean()) / (row_means.std() + 1e-8)
z_row = row_means / (row_means.mean() + 1e-8)
abnormal_rows = np.where(z_row > thres_row)[0].tolist()
# 3. Check for abnormal single points
# z_all = (mat - mat.mean()) / (mat.std() + 1e-8)
z_all = mat / (mat.mean() + 1e-8)
# Get all positions with z-score > threshold
abnormal_points = [
(i, j, mat[i, j], z_all[i, j])
for i in range(mat.shape[0])
for j in range(mat.shape[1])
if z_all[i, j] > thres_point
]
# Optionally remove points that are in already detected abnormal rows
# or columns
if suppress_points_in_strong_rowscols:
abnormal_points = [
(i, j, v, z) for (i, j, v, z) in abnormal_points
if i not in abnormal_rows and j not in abnormal_cols
]
# 4. Return for automatic processing
return {
'abnormal_cols': abnormal_cols,
'abnormal_rows': abnormal_rows,
'abnormal_points': abnormal_points
}
dispatch_wait_recv_cost_stats = torch.zeros((num_ranks, ), dtype=torch.int64, device='cuda')
combine_wait_recv_cost_stats = torch.zeros((num_ranks, ), dtype=torch.int64, device='cuda')
slow_rank = [0, 1]
for i, test_dispatch_slow in enumerate([True, False]):
bench(
partial(
test_diagnose,
test_dispatch_slow=test_dispatch_slow,
slow_rank=slow_rank[i],
dispatch_wait_recv_cost_stats=dispatch_wait_recv_cost_stats,
combine_wait_recv_cost_stats=combine_wait_recv_cost_stats))
stats_list = [dispatch_wait_recv_cost_stats, combine_wait_recv_cost_stats]
stats_tensor = torch.stack(stats_list, dim=0) # (N, num_ranks)
# gather all ranks dispatch and combine diagnose stats to rank 0
gather_tensor = [
torch.zeros_like(
torch.stack(
stats_list,
dim=0)) for _ in range(
group.size())] if rank == 0 else None
dist.gather(stats_tensor, gather_list=gather_tensor, group=group, dst=0)
if rank == 0:
stats_arr = torch.stack([it.cpu() for it in gather_tensor], dim=0).numpy()
for i, name in enumerate(["Dispatch", "Combine"]):
res = diagnose_matrix(stats_arr[:, i, :])
assert slow_rank[i] in res[
'abnormal_cols'], f"[Diagnose] test failure, slow_rank {slow_rank[i]} not found in abnormal_cols {res['abnormal_cols']}"
print(
f'[Diagnose] test successful!!! [{name}] slow_rank: {slow_rank[i]} diagnose info: {res}')
return hash_value return hash_value
...@@ -162,7 +260,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): ...@@ -162,7 +260,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_qps_per_rank=num_experts // num_ranks, num_qps_per_rank=num_experts // num_ranks,
allow_nvlink_for_low_latency_mode=not args.disable_nvlink, explicitly_destroy=True) allow_nvlink_for_low_latency_mode=not args.disable_nvlink, explicitly_destroy=True)
test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer,
use_logfmt=args.use_logfmt, seed=1) use_logfmt=args.use_logfmt, seed=1, enable_diagnose=args.enable_diagnose)
do_pressure_test = args.pressure_test do_pressure_test = args.pressure_test
for seed in range(int(1e9) if do_pressure_test else 0): for seed in range(int(1e9) if do_pressure_test else 0):
...@@ -200,6 +298,8 @@ if __name__ == '__main__': ...@@ -200,6 +298,8 @@ if __name__ == '__main__':
help='Whether to test LogFMT combine') help='Whether to test LogFMT combine')
parser.add_argument("--pressure-test", action='store_true', parser.add_argument("--pressure-test", action='store_true',
help='Whether to do pressure test') help='Whether to do pressure test')
parser.add_argument('--enable-diagnose', action='store_true',
help='Whether to enable diagnose for testing')
args = parser.parse_args() args = parser.parse_args()
num_processes = args.num_processes num_processes = args.num_processes
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment