Commit c7033854 authored by Chenggang Zhao's avatar Chenggang Zhao
Browse files

Remove the diagnosis part from tests

parent be8053d6
...@@ -14,7 +14,7 @@ from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_to ...@@ -14,7 +14,7 @@ from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_to
def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
rank: int, num_ranks: int, group: dist.ProcessGroup, buffer: deep_ep.Buffer, rank: int, num_ranks: int, group: dist.ProcessGroup, buffer: deep_ep.Buffer,
use_logfmt: bool = False, seed: int = 0, enable_diagnose: bool = False): use_logfmt: bool = False, seed: int = 0):
torch.manual_seed(seed + rank) torch.manual_seed(seed + rank)
random.seed(seed + rank) random.seed(seed + rank)
...@@ -125,23 +125,6 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, ...@@ -125,23 +125,6 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
use_logfmt=use_logfmt, return_recv_hook=return_recv_hook) use_logfmt=use_logfmt, return_recv_hook=return_recv_hook)
large_gemm_with_hook(hook) if return_recv_hook else None large_gemm_with_hook(hook) if return_recv_hook else None
# noinspection PyShadowingNames
def test_diagnose(test_dispatch_slow: bool, slow_rank: int,
dispatch_wait_recv_cost_stats: Optional[torch.Tensor] = None,
combine_wait_recv_cost_stats: Optional[torch.Tensor] = None):
if test_dispatch_slow:
if rank == slow_rank:
time.sleep(0.001)
buffer.low_latency_dispatch(x_pure_rand, topk_idx, num_tokens, num_experts,
cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,
dispatch_wait_recv_cost_stats=dispatch_wait_recv_cost_stats,
use_fp8=True, async_finish=False)
else:
if rank == slow_rank:
time.sleep(0.001)
buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
use_logfmt=use_logfmt, return_recv_hook=False,
combine_wait_recv_cost_stats=combine_wait_recv_cost_stats)
# Calculate bandwidth # Calculate bandwidth
num_fp8_bytes, num_bf16_bytes = (hidden + hidden / 128 * 4 + 16), hidden * 2 num_fp8_bytes, num_bf16_bytes = (hidden + hidden / 128 * 4 + 16), hidden * 2
num_dispatch_comm_bytes, num_combine_comm_bytes = 0, 0 num_dispatch_comm_bytes, num_combine_comm_bytes = 0, 0
...@@ -167,83 +150,6 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, ...@@ -167,83 +150,6 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
else: else:
print(f'[rank {rank}] Dispatch send/recv time: {dispatch_t[0] * 1e6:.2f} + {dispatch_t[1] * 1e6:.2f} us | ' print(f'[rank {rank}] Dispatch send/recv time: {dispatch_t[0] * 1e6:.2f} + {dispatch_t[1] * 1e6:.2f} us | '
f'Combine send/recv time: {combine_t[0] * 1e6:.2f} + {combine_t[1] * 1e6:.2f} us', flush=True) f'Combine send/recv time: {combine_t[0] * 1e6:.2f} + {combine_t[1] * 1e6:.2f} us', flush=True)
# Diagnose test
if enable_diagnose:
def diagnose_matrix(
mat, thres_col=3.0, thres_row=3.0, thres_point=5.0,
suppress_points_in_strong_rowscols=True
):
"""
mat: 2D numpy array, mat[i, j] = the waiting time of src i waiting for dst j to receive the token
Returns abnormal columns/rows/points.
suppress_points_in_strong_rowscols: whether to remove points located in already detected abnormal rows or columns
"""
# 1. Check for abnormal columns
col_means = mat.mean(axis=0)
# z_col = (col_means - col_means.mean()) / (col_means.std() + 1e-8)
z_col = col_means / (col_means.mean() + 1e-8)
abnormal_cols = np.where(z_col > thres_col)[0].tolist()
# 2. Check for abnormal rows
row_means = mat.mean(axis=1)
# z_row = (row_means - row_means.mean()) / (row_means.std() + 1e-8)
z_row = row_means / (row_means.mean() + 1e-8)
abnormal_rows = np.where(z_row > thres_row)[0].tolist()
# 3. Check for abnormal single points
# z_all = (mat - mat.mean()) / (mat.std() + 1e-8)
z_all = mat / (mat.mean() + 1e-8)
# Get all positions with z-score > threshold
abnormal_points = [
(i, j, mat[i, j], z_all[i, j])
for i in range(mat.shape[0])
for j in range(mat.shape[1])
if z_all[i, j] > thres_point
]
# Optionally remove points that are in already detected abnormal rows
# or columns
if suppress_points_in_strong_rowscols:
abnormal_points = [
(i, j, v, z) for (i, j, v, z) in abnormal_points
if i not in abnormal_rows and j not in abnormal_cols
]
# 4. Return for automatic processing
return {
'abnormal_cols': abnormal_cols,
'abnormal_rows': abnormal_rows,
'abnormal_points': abnormal_points
}
dispatch_wait_recv_cost_stats = torch.zeros((num_ranks, ), dtype=torch.int64, device='cuda')
combine_wait_recv_cost_stats = torch.zeros((num_ranks, ), dtype=torch.int64, device='cuda')
slow_rank = [0, 1]
for i, test_dispatch_slow in enumerate([True, False]):
bench(
partial(
test_diagnose,
test_dispatch_slow=test_dispatch_slow,
slow_rank=slow_rank[i],
dispatch_wait_recv_cost_stats=dispatch_wait_recv_cost_stats,
combine_wait_recv_cost_stats=combine_wait_recv_cost_stats))
stats_list = [dispatch_wait_recv_cost_stats, combine_wait_recv_cost_stats]
stats_tensor = torch.stack(stats_list, dim=0) # (N, num_ranks)
# gather all ranks dispatch and combine diagnose stats to rank 0
gather_tensor = [
torch.zeros_like(
torch.stack(
stats_list,
dim=0)) for _ in range(
group.size())] if rank == 0 else None
dist.gather(stats_tensor, gather_list=gather_tensor, group=group, dst=0)
if rank == 0:
stats_arr = torch.stack([it.cpu() for it in gather_tensor], dim=0).numpy()
for i, name in enumerate(["Dispatch", "Combine"]):
res = diagnose_matrix(stats_arr[:, i, :])
assert slow_rank[i] in res[
'abnormal_cols'], f"[Diagnose] test failure, slow_rank {slow_rank[i]} not found in abnormal_cols {res['abnormal_cols']}"
print(
f'[Diagnose] test successful!!! [{name}] slow_rank: {slow_rank[i]} diagnose info: {res}')
return hash_value return hash_value
...@@ -260,7 +166,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): ...@@ -260,7 +166,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_qps_per_rank=num_experts // num_ranks, num_qps_per_rank=num_experts // num_ranks,
allow_nvlink_for_low_latency_mode=not args.disable_nvlink, explicitly_destroy=True) allow_nvlink_for_low_latency_mode=not args.disable_nvlink, explicitly_destroy=True)
test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer,
use_logfmt=args.use_logfmt, seed=1, enable_diagnose=args.enable_diagnose) use_logfmt=args.use_logfmt, seed=1)
do_pressure_test = args.pressure_test do_pressure_test = args.pressure_test
for seed in range(int(1e9) if do_pressure_test else 0): for seed in range(int(1e9) if do_pressure_test else 0):
...@@ -298,8 +204,6 @@ if __name__ == '__main__': ...@@ -298,8 +204,6 @@ if __name__ == '__main__':
help='Whether to test LogFMT combine') help='Whether to test LogFMT combine')
parser.add_argument("--pressure-test", action='store_true', parser.add_argument("--pressure-test", action='store_true',
help='Whether to do pressure test') help='Whether to do pressure test')
parser.add_argument('--enable-diagnose', action='store_true',
help='Whether to enable diagnose for testing')
args = parser.parse_args() args = parser.parse_args()
num_processes = args.num_processes num_processes = args.num_processes
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment