Unverified Commit 4b67064d authored by sky's avatar sky Committed by GitHub
Browse files

Add diagnosis module for efficient and precise location of slow rank (#311)



* Add diagnosis module for precise identification of slow ranks
Signed-off-by: wangfakang <fakangwang@gmail.com>

* Add tests for the slow diagnosis module
Signed-off-by: wangfakang <fakangwang@gmail.com>

* Update some comments for diagnose
Signed-off-by: wangfakang <fakangwang@gmail.com>

* Update test case for diagnose
Signed-off-by: wangfakang <fakangwang@gmail.com>

* Strip the diagnose module, thx LyricZhao and sphish.
Signed-off-by: wangfakang <fakangwang@gmail.com>

* update variable name and cumulative wait recv cost, thx sphish.
Signed-off-by: wangfakang <fakangwang@gmail.com>

* remove invalid comments.
Signed-off-by: wangfakang <fakangwang@gmail.com>

---------
Signed-off-by: wangfakang <fakangwang@gmail.com>
parent b92d0d48
......@@ -1090,6 +1090,7 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
const std::optional<torch::Tensor>& dispatch_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_fp8, bool round_scale, bool use_ue8m0,
bool async, bool return_recv_hook) {
......@@ -1110,6 +1111,12 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
EP_HOST_ASSERT(cumulative_local_expert_recv_stats->size(0) == num_experts / num_ranks);
}
if (dispatch_wait_recv_cost_stats.has_value()) {
EP_HOST_ASSERT(dispatch_wait_recv_cost_stats->scalar_type() == torch::kInt64);
EP_HOST_ASSERT(dispatch_wait_recv_cost_stats->dim() == 1 and dispatch_wait_recv_cost_stats->is_contiguous());
EP_HOST_ASSERT(dispatch_wait_recv_cost_stats->size(0) == num_ranks);
}
auto num_tokens = static_cast<int>(x.size(0)), hidden = static_cast<int>(x.size(1));
auto num_topk = static_cast<int>(topk_idx.size(1));
auto num_local_experts = num_experts / num_ranks;
......@@ -1162,6 +1169,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
packed_recv_src_info.data_ptr<int>(), packed_recv_layout_range.data_ptr<int64_t>(),
packed_recv_count.data_ptr<int>(),
cumulative_local_expert_recv_stats.has_value() ? cumulative_local_expert_recv_stats->data_ptr<int>() : nullptr,
dispatch_wait_recv_cost_stats.has_value() ? dispatch_wait_recv_cost_stats->data_ptr<int64_t>() : nullptr,
buffer.dispatch_rdma_recv_data_buffer, buffer.dispatch_rdma_recv_count_buffer,
buffer.dispatch_rdma_send_buffer,
x.data_ptr(), topk_idx.data_ptr<int64_t>(),
......@@ -1200,6 +1208,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
const torch::Tensor& src_info, const torch::Tensor& layout_range,
const std::optional<torch::Tensor>& combine_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_logfmt, bool zero_copy, bool async, bool return_recv_hook,
const std::optional<torch::Tensor>& out) {
......@@ -1222,6 +1231,13 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
EP_HOST_ASSERT(layout_range.dim() == 2 and layout_range.is_contiguous());
EP_HOST_ASSERT(layout_range.scalar_type() == torch::kInt64);
EP_HOST_ASSERT(layout_range.size(0) == num_experts / num_ranks and layout_range.size(1) == num_ranks);
if (combine_wait_recv_cost_stats.has_value()) {
EP_HOST_ASSERT(combine_wait_recv_cost_stats->scalar_type() == torch::kInt64);
EP_HOST_ASSERT(combine_wait_recv_cost_stats->dim() == 1 and combine_wait_recv_cost_stats->is_contiguous());
EP_HOST_ASSERT(combine_wait_recv_cost_stats->size(0) == num_ranks);
}
auto hidden = static_cast<int>(x.size(2));
auto num_topk = static_cast<int>(topk_weights.size(1));
auto num_combined_tokens = static_cast<int>(topk_weights.size(0));
......@@ -1259,6 +1275,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
buffer.combine_rdma_send_buffer,
x.data_ptr(), topk_idx.data_ptr<int64_t>(), topk_weights.data_ptr<float>(),
src_info.data_ptr<int>(), layout_range.data_ptr<int64_t>(),
combine_wait_recv_cost_stats.has_value() ? combine_wait_recv_cost_stats->data_ptr<int64_t>() : nullptr,
next_clean_meta.first, next_clean_meta.second,
num_combined_tokens, hidden, num_max_dispatch_tokens_per_rank,
num_topk, num_experts, rank, num_ranks,
......
......@@ -146,6 +146,7 @@ public:
// Low-latency dispatch entry point.
// `cumulative_local_expert_recv_stats` and `dispatch_wait_recv_cost_stats` are optional statistics
// outputs; passing `std::nullopt` forwards a null pointer to the kernel, which then skips the update.
// `dispatch_wait_recv_cost_stats`, when provided, must be a contiguous 1-D `torch::kInt64` tensor with
// `num_ranks` entries (checked with `EP_HOST_ASSERT` in the definition). The dispatch kernel atomically
// adds, per source rank, the `clock64()` ticks it spent spinning on that rank's receive counter.
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
const std::optional<torch::Tensor>& dispatch_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_fp8, bool round_scale, bool use_ue8m0,
bool async, bool return_recv_hook);
......@@ -153,6 +154,7 @@ public:
// Low-latency combine entry point.
// `combine_wait_recv_cost_stats` is an optional statistics output; passing `std::nullopt` forwards a
// null pointer to the kernel, which then skips the update.
// When provided, it must be a contiguous 1-D `torch::kInt64` tensor with `num_ranks` entries (checked
// with `EP_HOST_ASSERT` in the definition). The combine kernel atomically adds the `clock64()` ticks
// spent spinning on each receive flag into entry `responsible_expert_idx / num_local_experts`.
std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
const torch::Tensor& src_info, const torch::Tensor& layout_range,
const std::optional<torch::Tensor>& combine_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_logfmt, bool zero_copy, bool async, bool return_recv_hook,
const std::optional<torch::Tensor>& out = std::nullopt);
......
......@@ -143,6 +143,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
int* packed_recv_count,
int* cumulative_local_expert_recv_stats,
int64_t* dispatch_wait_recv_cost_stats,
void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
const void* x, const int64_t* topk_idx,
int* next_clean, int num_next_clean_int,
......@@ -156,6 +157,7 @@ void combine(void* combined_x,
void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
const void* x, const int64_t* topk_idx, const float* topk_weights,
const int* src_info, const int64_t* layout_range,
int64_t* combine_wait_recv_cost_stats,
int* next_clean, int num_next_clean_int,
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
......
......@@ -42,6 +42,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
int* packed_recv_count,
int* cumulative_local_expert_recv_stats,
int64_t* dispatch_wait_recv_cost_stats,
void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
const void* x, const int64_t* topk_idx,
int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert,
......@@ -272,7 +273,9 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int num_recv_tokens, recv_token_begin_idx;
EP_DEVICE_ASSERT(num_warps_per_group > 1 and num_warp_groups < 15);
if (sub_warp_id == 1 and lane_id == 0) {
auto start_time = clock64();
while ((num_recv_tokens = ld_acquire_sys_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank)) == 0);
auto wait_recv_cost = clock64() - start_time;
num_recv_tokens = -num_recv_tokens - 1;
recv_token_begin_idx = atomicAdd(packed_recv_count + local_expert_idx, num_recv_tokens);
shared_num_recv_tokens[warp_group_id] = num_recv_tokens;
......@@ -280,6 +283,10 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
recv_range[src_rank] = pack2<int, int64_t>(num_recv_tokens, recv_token_begin_idx);
if (cumulative_local_expert_recv_stats != nullptr)
atomicAdd(cumulative_local_expert_recv_stats + local_expert_idx, num_recv_tokens);
if (dispatch_wait_recv_cost_stats != nullptr)
atomicAdd(reinterpret_cast<unsigned long long*>(dispatch_wait_recv_cost_stats + src_rank),
wait_recv_cost);
}
asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 2), "r"(num_warps_per_group * 32));
num_recv_tokens = shared_num_recv_tokens[warp_group_id];
......@@ -330,6 +337,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
int* packed_recv_count,
int* cumulative_local_expert_recv_stats,
int64_t* dispatch_wait_recv_cost_stats,
void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
const void* x, const int64_t* topk_idx,
int* next_clean, int num_next_clean_int,
......@@ -368,6 +376,7 @@ LAUNCH_KERNEL(&cfg, dispatch_func, \
packed_recv_src_info, packed_recv_layout_range, \
packed_recv_count, \
cumulative_local_expert_recv_stats, \
dispatch_wait_recv_cost_stats, \
rdma_recv_x, rdma_recv_count, rdma_x, \
x, topk_idx, \
atomic_counter_per_expert, atomic_finish_counter_per_expert, \
......@@ -388,6 +397,7 @@ combine(void* combined_x,
void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
const void* x, const int64_t* topk_idx, const float* topk_weights,
const int* src_info, const int64_t* layout_range,
int64_t* combine_wait_recv_cost_stats,
int* next_clean, int num_next_clean_int,
int* atomic_clean_flag,
int num_combined_tokens, int hidden, int num_topk,
......@@ -618,7 +628,12 @@ combine(void* combined_x,
if (responsible_expert_idx < num_experts) {
EP_DEVICE_ASSERT(num_warps_per_group > 1);
if (sub_warp_id == 0 and lane_id == 0) {
auto start_time = clock64();
while (ld_acquire_sys_global(rdma_recv_flag + responsible_expert_idx) == 0);
auto wait_recv_cost = clock64() - start_time;
if (combine_wait_recv_cost_stats != nullptr)
atomicAdd(reinterpret_cast<unsigned long long*>(combine_wait_recv_cost_stats
+ responsible_expert_idx / num_local_experts), wait_recv_cost);
}
}
cg::this_grid().sync();
......@@ -667,6 +682,7 @@ void combine(void* combined_x,
void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
const void* x, const int64_t* topk_idx, const float* topk_weights,
const int* src_info, const int64_t* layout_range,
int64_t* combine_wait_recv_cost_stats,
int* next_clean, int num_next_clean_int,
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
......@@ -701,6 +717,7 @@ LAUNCH_KERNEL(&cfg, combine_func, \
combined_x, \
rdma_recv_x, rdma_recv_flag, rdma_send_x, \
x, topk_idx, topk_weights, src_info, layout_range, \
combine_wait_recv_cost_stats, \
next_clean, num_next_clean_int, \
atomic_clean_flag, \
num_combined_tokens, hidden, num_topk, \
......
......@@ -515,6 +515,7 @@ class Buffer:
def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
num_max_dispatch_tokens_per_rank: int, num_experts: int,
cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None,
dispatch_wait_recv_cost_stats: Optional[torch.Tensor] = None,
use_fp8: bool = True, round_scale: bool = False, use_ue8m0: bool = False,
async_finish: bool = False, return_recv_hook: bool = False) -> \
Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, Tuple, EventOverlap, Callable]:
......@@ -535,6 +536,9 @@ class Buffer:
cumulative_local_expert_recv_stats: a cumulative expert count tensor for statistics, which should have shape
`[num_local_experts]` and be typed as `torch.int`. This is useful for online service EP load balance
monitoring.
dispatch_wait_recv_cost_stats: a cumulative tensor of the time spent waiting to receive tokens from each rank, for statistics,
which should have shape `[num_ranks]` and be typed as `torch.int64`.
This is useful for detecting and precisely localizing slow anomalies.
use_fp8: whether to enable FP8 casting, with this, the received data will be a tuple of FP8 tensor and scaling factors.
round_scale: whether round the scaling factors into power of 2.
use_ue8m0: whether use UE8M0 as scaling factor format (available only with `round_scale=True`).
......@@ -565,6 +569,7 @@ class Buffer:
packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, hook = \
self.runtime.low_latency_dispatch(x, topk_idx,
cumulative_local_expert_recv_stats,
dispatch_wait_recv_cost_stats,
num_max_dispatch_tokens_per_rank, num_experts,
use_fp8, round_scale, use_ue8m0,
async_finish, return_recv_hook)
......@@ -579,7 +584,8 @@ class Buffer:
# noinspection PyTypeChecker
def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor,
handle: tuple, use_logfmt: bool = False, zero_copy: bool = False, async_finish: bool = False,
return_recv_hook: bool = False, out: Optional[torch.Tensor] = None) -> \
return_recv_hook: bool = False, out: Optional[torch.Tensor] = None,
combine_wait_recv_cost_stats: Optional[torch.Tensor] = None) -> \
Tuple[torch.Tensor, EventOverlap, Callable]:
"""
A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA.
......@@ -605,6 +611,9 @@ class Buffer:
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
If you do not set this flag, the kernel will ensure the data's arrival.
out: the in-place output tensor, if set, the kernel will write the result to this tensor and return it directly.
combine_wait_recv_cost_stats: a cumulative tensor of the time spent waiting to receive tokens from each rank, for statistics,
which should have shape `[num_ranks]` and be typed as `torch.int64`.
This is useful for detecting and precisely localizing slow anomalies.
Returns:
combined_x: the reduced token tensor, with shape `[num_combined_tokens, hidden]` and type `torch.bfloat16`.
......@@ -613,6 +622,7 @@ class Buffer:
"""
src_info, layout_range, num_max_dispatch_tokens_per_rank, hidden, num_experts = handle
combined_x, event, hook = self.runtime.low_latency_combine(x, topk_idx, topk_weights, src_info, layout_range,
combine_wait_recv_cost_stats,
num_max_dispatch_tokens_per_rank, num_experts,
use_logfmt, zero_copy, async_finish, return_recv_hook,
out)
......
import argparse
import random
import time
import os
import torch
import torch.distributed as dist
import numpy as np
from functools import partial
from typing import Optional
import deep_ep
from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_back
......@@ -10,7 +14,7 @@ from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_to
def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
rank: int, num_ranks: int, group: dist.ProcessGroup, buffer: deep_ep.Buffer,
use_logfmt: bool = False, seed: int = 0):
use_logfmt: bool = False, seed: int = 0, enable_diagnose: bool = False):
torch.manual_seed(seed + rank)
random.seed(seed + rank)
......@@ -121,6 +125,23 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
use_logfmt=use_logfmt, return_recv_hook=return_recv_hook)
large_gemm_with_hook(hook) if return_recv_hook else None
# noinspection PyShadowingNames
def test_diagnose(test_dispatch_slow: bool, slow_rank: int,
                  dispatch_wait_recv_cost_stats: Optional[torch.Tensor] = None,
                  combine_wait_recv_cost_stats: Optional[torch.Tensor] = None):
    """
    Run one low-latency dispatch or combine with an artificial delay on one rank.

    Arguments:
        test_dispatch_slow: if `True`, issue a dispatch; otherwise issue a combine.
        slow_rank: the rank that sleeps for 1 ms on the host before issuing the call.
        dispatch_wait_recv_cost_stats: optional stats tensor forwarded to `low_latency_dispatch`.
        combine_wait_recv_cost_stats: optional stats tensor forwarded to `low_latency_combine`.
    """
    # Both paths inject the same host-side delay on the designated rank before communicating
    if slow_rank == rank:
        time.sleep(0.001)

    if not test_dispatch_slow:
        buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
                                   use_logfmt=use_logfmt, return_recv_hook=False,
                                   combine_wait_recv_cost_stats=combine_wait_recv_cost_stats)
        return

    buffer.low_latency_dispatch(x_pure_rand, topk_idx, num_tokens, num_experts,
                                cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,
                                dispatch_wait_recv_cost_stats=dispatch_wait_recv_cost_stats,
                                use_fp8=True, async_finish=False)
# Calculate bandwidth
num_fp8_bytes, num_bf16_bytes = (hidden + hidden / 128 * 4 + 16), hidden * 2
num_dispatch_comm_bytes, num_combine_comm_bytes = 0, 0
......@@ -146,6 +167,83 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
else:
print(f'[rank {rank}] Dispatch send/recv time: {dispatch_t[0] * 1e6:.2f} + {dispatch_t[1] * 1e6:.2f} us | '
f'Combine send/recv time: {combine_t[0] * 1e6:.2f} + {combine_t[1] * 1e6:.2f} us', flush=True)
# Diagnose test
if enable_diagnose:
def diagnose_matrix(
    mat, thres_col=3.0, thres_row=3.0, thres_point=5.0,
    suppress_points_in_strong_rowscols=True
):
    """
    Flag abnormally slow columns, rows and single entries of a waiting-time matrix.

    Every score is a ratio to the mean (value / mean), not a statistical z-score:
    a column/row/entry is abnormal when it exceeds its threshold times the average.

    Arguments:
        mat: 2D array-like, `mat[i, j]` = the waiting time of src i waiting for dst j to receive the token.
        thres_col: a column is abnormal if its mean exceeds `thres_col` times the average column mean.
        thres_row: a row is abnormal if its mean exceeds `thres_row` times the average row mean.
        thres_point: an entry is abnormal if it exceeds `thres_point` times the matrix mean.
        suppress_points_in_strong_rowscols: whether to drop points located in already detected
            abnormal rows or columns.

    Returns:
        A dict with keys `abnormal_cols` (list of column indices), `abnormal_rows` (list of row
        indices) and `abnormal_points` (list of `(i, j, value, ratio)` tuples in row-major order).

    Raises:
        ValueError: if `mat` is not two-dimensional.
    """
    mat = np.asarray(mat)
    if mat.ndim != 2:
        raise ValueError(f'diagnose_matrix expects a 2D matrix, got {mat.ndim}D input')
    # An empty matrix has no mean; report nothing instead of propagating NaNs and warnings
    if mat.size == 0:
        return {'abnormal_cols': [], 'abnormal_rows': [], 'abnormal_points': []}

    eps = 1e-8  # Keeps the ratios finite when every waiting time is zero

    # 1. Check for abnormal columns
    col_means = mat.mean(axis=0)
    col_ratio = col_means / (col_means.mean() + eps)
    abnormal_cols = np.flatnonzero(col_ratio > thres_col).tolist()

    # 2. Check for abnormal rows
    row_means = mat.mean(axis=1)
    row_ratio = row_means / (row_means.mean() + eps)
    abnormal_rows = np.flatnonzero(row_ratio > thres_row).tolist()

    # 3. Check for abnormal single points
    point_ratio = mat / (mat.mean() + eps)
    point_mask = point_ratio > thres_point
    # Optionally remove points that are in already detected abnormal rows or columns
    if suppress_points_in_strong_rowscols:
        if abnormal_rows:
            point_mask[abnormal_rows, :] = False
        if abnormal_cols:
            point_mask[:, abnormal_cols] = False
    # `np.argwhere` yields indices in row-major order, matching a nested (i, j) scan
    abnormal_points = [
        (int(i), int(j), mat[i, j], point_ratio[i, j])
        for i, j in np.argwhere(point_mask)
    ]

    # 4. Return for automatic processing
    return {
        'abnormal_cols': abnormal_cols,
        'abnormal_rows': abnormal_rows,
        'abnormal_points': abnormal_points
    }
dispatch_wait_recv_cost_stats = torch.zeros((num_ranks, ), dtype=torch.int64, device='cuda')
combine_wait_recv_cost_stats = torch.zeros((num_ranks, ), dtype=torch.int64, device='cuda')
slow_rank = [0, 1]
for i, test_dispatch_slow in enumerate([True, False]):
bench(
partial(
test_diagnose,
test_dispatch_slow=test_dispatch_slow,
slow_rank=slow_rank[i],
dispatch_wait_recv_cost_stats=dispatch_wait_recv_cost_stats,
combine_wait_recv_cost_stats=combine_wait_recv_cost_stats))
stats_list = [dispatch_wait_recv_cost_stats, combine_wait_recv_cost_stats]
stats_tensor = torch.stack(stats_list, dim=0) # (N, num_ranks)
# gather all ranks dispatch and combine diagnose stats to rank 0
gather_tensor = [
torch.zeros_like(
torch.stack(
stats_list,
dim=0)) for _ in range(
group.size())] if rank == 0 else None
dist.gather(stats_tensor, gather_list=gather_tensor, group=group, dst=0)
if rank == 0:
stats_arr = torch.stack([it.cpu() for it in gather_tensor], dim=0).numpy()
for i, name in enumerate(["Dispatch", "Combine"]):
res = diagnose_matrix(stats_arr[:, i, :])
assert slow_rank[i] in res[
'abnormal_cols'], f"[Diagnose] test failure, slow_rank {slow_rank[i]} not found in abnormal_cols {res['abnormal_cols']}"
print(
f'[Diagnose] test successful!!! [{name}] slow_rank: {slow_rank[i]} diagnose info: {res}')
return hash_value
......@@ -162,7 +260,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_qps_per_rank=num_experts // num_ranks,
allow_nvlink_for_low_latency_mode=not args.disable_nvlink, explicitly_destroy=True)
test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer,
use_logfmt=args.use_logfmt, seed=1)
use_logfmt=args.use_logfmt, seed=1, enable_diagnose=args.enable_diagnose)
do_pressure_test = args.pressure_test
for seed in range(int(1e9) if do_pressure_test else 0):
......@@ -200,6 +298,8 @@ if __name__ == '__main__':
help='Whether to test LogFMT combine')
parser.add_argument("--pressure-test", action='store_true',
help='Whether to do pressure test')
parser.add_argument('--enable-diagnose', action='store_true',
help='Whether to enable diagnose for testing')
args = parser.parse_args()
num_processes = args.num_processes
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment