Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
c7033854
Commit
c7033854
authored
Jul 31, 2025
by
Chenggang Zhao
Browse files
Remove the diagnosis part from tests
parent
be8053d6
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
2 additions
and
98 deletions
+2
-98
tests/test_low_latency.py
tests/test_low_latency.py
+2
-98
No files found.
tests/test_low_latency.py
View file @
c7033854
...
...
@@ -14,7 +14,7 @@ from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_to
def
test_main
(
num_tokens
:
int
,
hidden
:
int
,
num_experts
:
int
,
num_topk
:
int
,
rank
:
int
,
num_ranks
:
int
,
group
:
dist
.
ProcessGroup
,
buffer
:
deep_ep
.
Buffer
,
use_logfmt
:
bool
=
False
,
seed
:
int
=
0
,
enable_diagnose
:
bool
=
False
):
use_logfmt
:
bool
=
False
,
seed
:
int
=
0
):
torch
.
manual_seed
(
seed
+
rank
)
random
.
seed
(
seed
+
rank
)
...
...
@@ -125,23 +125,6 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
use_logfmt
=
use_logfmt
,
return_recv_hook
=
return_recv_hook
)
large_gemm_with_hook
(
hook
)
if
return_recv_hook
else
None
# noinspection PyShadowingNames
def
test_diagnose
(
test_dispatch_slow
:
bool
,
slow_rank
:
int
,
dispatch_wait_recv_cost_stats
:
Optional
[
torch
.
Tensor
]
=
None
,
combine_wait_recv_cost_stats
:
Optional
[
torch
.
Tensor
]
=
None
):
if
test_dispatch_slow
:
if
rank
==
slow_rank
:
time
.
sleep
(
0.001
)
buffer
.
low_latency_dispatch
(
x_pure_rand
,
topk_idx
,
num_tokens
,
num_experts
,
cumulative_local_expert_recv_stats
=
cumulative_local_expert_recv_stats
,
dispatch_wait_recv_cost_stats
=
dispatch_wait_recv_cost_stats
,
use_fp8
=
True
,
async_finish
=
False
)
else
:
if
rank
==
slow_rank
:
time
.
sleep
(
0.001
)
buffer
.
low_latency_combine
(
simulated_gemm_x
,
topk_idx
,
topk_weights
,
handle
,
use_logfmt
=
use_logfmt
,
return_recv_hook
=
False
,
combine_wait_recv_cost_stats
=
combine_wait_recv_cost_stats
)
# Calculate bandwidth
num_fp8_bytes
,
num_bf16_bytes
=
(
hidden
+
hidden
/
128
*
4
+
16
),
hidden
*
2
num_dispatch_comm_bytes
,
num_combine_comm_bytes
=
0
,
0
...
...
@@ -167,83 +150,6 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
else
:
print
(
f
'[rank
{
rank
}
] Dispatch send/recv time:
{
dispatch_t
[
0
]
*
1e6
:.
2
f
}
+
{
dispatch_t
[
1
]
*
1e6
:.
2
f
}
us | '
f
'Combine send/recv time:
{
combine_t
[
0
]
*
1e6
:.
2
f
}
+
{
combine_t
[
1
]
*
1e6
:.
2
f
}
us'
,
flush
=
True
)
# Diagnose test
if
enable_diagnose
:
def
diagnose_matrix
(
mat
,
thres_col
=
3.0
,
thres_row
=
3.0
,
thres_point
=
5.0
,
suppress_points_in_strong_rowscols
=
True
):
"""
mat: 2D numpy array, mat[i, j] = the waiting time of src i waiting for dst j to receive the token
Returns abnormal columns/rows/points.
suppress_points_in_strong_rowscols: whether to remove points located in already detected abnormal rows or columns
"""
# 1. Check for abnormal columns
col_means
=
mat
.
mean
(
axis
=
0
)
# z_col = (col_means - col_means.mean()) / (col_means.std() + 1e-8)
z_col
=
col_means
/
(
col_means
.
mean
()
+
1e-8
)
abnormal_cols
=
np
.
where
(
z_col
>
thres_col
)[
0
].
tolist
()
# 2. Check for abnormal rows
row_means
=
mat
.
mean
(
axis
=
1
)
# z_row = (row_means - row_means.mean()) / (row_means.std() + 1e-8)
z_row
=
row_means
/
(
row_means
.
mean
()
+
1e-8
)
abnormal_rows
=
np
.
where
(
z_row
>
thres_row
)[
0
].
tolist
()
# 3. Check for abnormal single points
# z_all = (mat - mat.mean()) / (mat.std() + 1e-8)
z_all
=
mat
/
(
mat
.
mean
()
+
1e-8
)
# Get all positions with z-score > threshold
abnormal_points
=
[
(
i
,
j
,
mat
[
i
,
j
],
z_all
[
i
,
j
])
for
i
in
range
(
mat
.
shape
[
0
])
for
j
in
range
(
mat
.
shape
[
1
])
if
z_all
[
i
,
j
]
>
thres_point
]
# Optionally remove points that are in already detected abnormal rows
# or columns
if
suppress_points_in_strong_rowscols
:
abnormal_points
=
[
(
i
,
j
,
v
,
z
)
for
(
i
,
j
,
v
,
z
)
in
abnormal_points
if
i
not
in
abnormal_rows
and
j
not
in
abnormal_cols
]
# 4. Return for automatic processing
return
{
'abnormal_cols'
:
abnormal_cols
,
'abnormal_rows'
:
abnormal_rows
,
'abnormal_points'
:
abnormal_points
}
dispatch_wait_recv_cost_stats
=
torch
.
zeros
((
num_ranks
,
),
dtype
=
torch
.
int64
,
device
=
'cuda'
)
combine_wait_recv_cost_stats
=
torch
.
zeros
((
num_ranks
,
),
dtype
=
torch
.
int64
,
device
=
'cuda'
)
slow_rank
=
[
0
,
1
]
for
i
,
test_dispatch_slow
in
enumerate
([
True
,
False
]):
bench
(
partial
(
test_diagnose
,
test_dispatch_slow
=
test_dispatch_slow
,
slow_rank
=
slow_rank
[
i
],
dispatch_wait_recv_cost_stats
=
dispatch_wait_recv_cost_stats
,
combine_wait_recv_cost_stats
=
combine_wait_recv_cost_stats
))
stats_list
=
[
dispatch_wait_recv_cost_stats
,
combine_wait_recv_cost_stats
]
stats_tensor
=
torch
.
stack
(
stats_list
,
dim
=
0
)
# (N, num_ranks)
# gather all ranks dispatch and combine diagnose stats to rank 0
gather_tensor
=
[
torch
.
zeros_like
(
torch
.
stack
(
stats_list
,
dim
=
0
))
for
_
in
range
(
group
.
size
())]
if
rank
==
0
else
None
dist
.
gather
(
stats_tensor
,
gather_list
=
gather_tensor
,
group
=
group
,
dst
=
0
)
if
rank
==
0
:
stats_arr
=
torch
.
stack
([
it
.
cpu
()
for
it
in
gather_tensor
],
dim
=
0
).
numpy
()
for
i
,
name
in
enumerate
([
"Dispatch"
,
"Combine"
]):
res
=
diagnose_matrix
(
stats_arr
[:,
i
,
:])
assert
slow_rank
[
i
]
in
res
[
'abnormal_cols'
],
f
"[Diagnose] test failure, slow_rank
{
slow_rank
[
i
]
}
not found in abnormal_cols
{
res
[
'abnormal_cols'
]
}
"
print
(
f
'[Diagnose] test successful!!! [
{
name
}
] slow_rank:
{
slow_rank
[
i
]
}
diagnose info:
{
res
}
'
)
return
hash_value
...
...
@@ -260,7 +166,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_qps_per_rank
=
num_experts
//
num_ranks
,
allow_nvlink_for_low_latency_mode
=
not
args
.
disable_nvlink
,
explicitly_destroy
=
True
)
test_main
(
num_tokens
,
hidden
,
num_experts
,
num_topk
,
rank
,
num_ranks
,
group
,
buffer
,
use_logfmt
=
args
.
use_logfmt
,
seed
=
1
,
enable_diagnose
=
args
.
enable_diagnose
)
use_logfmt
=
args
.
use_logfmt
,
seed
=
1
)
do_pressure_test
=
args
.
pressure_test
for
seed
in
range
(
int
(
1e9
)
if
do_pressure_test
else
0
):
...
...
@@ -298,8 +204,6 @@ if __name__ == '__main__':
help
=
'Whether to test LogFMT combine'
)
parser
.
add_argument
(
"--pressure-test"
,
action
=
'store_true'
,
help
=
'Whether to do pressure test'
)
parser
.
add_argument
(
'--enable-diagnose'
,
action
=
'store_true'
,
help
=
'Whether to enable diagnose for testing'
)
args
=
parser
.
parse_args
()
num_processes
=
args
.
num_processes
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment