Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
969e30f8
Commit
969e30f8
authored
Dec 15, 2025
by
lishen
Browse files
完善低延迟模式int8类型的测试
parent
f08e5bf1
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
21 additions
and
36 deletions
+21
-36
tests/test_low_latency_new_int8.py
tests/test_low_latency_new_int8.py
+21
-36
No files found.
tests/test_low_latency_new_int8.py
View file @
969e30f8
...
@@ -7,7 +7,7 @@ from functools import partial
...
@@ -7,7 +7,7 @@ from functools import partial
from
typing
import
Literal
,
Set
from
typing
import
Literal
,
Set
import
deep_ep
import
deep_ep
from
utils
import
init_dist
,
bench
,
bench_kineto
,
calc_diff
,
hash_tensor
,
per_token_cast_back
,
per_token_cast_back_int8
from
utils
import
init_dist
,
bench
,
bench_kineto
,
calc_diff
,
hash_tensor
,
per_token_cast_back_int8
def
test_main
(
num_tokens
:
int
,
def
test_main
(
num_tokens
:
int
,
...
@@ -68,6 +68,7 @@ def test_main(num_tokens: int,
...
@@ -68,6 +68,7 @@ def test_main(num_tokens: int,
hook
()
if
return_recv_hook
else
event
.
current_stream_wait
()
hook
()
if
return_recv_hook
else
event
.
current_stream_wait
()
packed_recv_x
=
(
packed_recv_x
[
0
],
packed_recv_x
[
1
].
contiguous
())
if
dispatch_use_fp8
else
packed_recv_x
packed_recv_x
=
(
packed_recv_x
[
0
],
packed_recv_x
[
1
].
contiguous
())
if
dispatch_use_fp8
else
packed_recv_x
simulated_gemm_x
=
per_token_cast_back_int8
(
packed_recv_x
[
0
].
view
(
-
1
,
hidden
),
packed_recv_x
[
1
].
view
(
-
1
,
1
)).
view
(
packed_recv_x
[
0
].
shape
)
for
i
in
range
(
num_local_experts
if
do_check
else
0
):
for
i
in
range
(
num_local_experts
if
do_check
else
0
):
expert_id
=
rank
*
num_local_experts
+
i
expert_id
=
rank
*
num_local_experts
+
i
...
@@ -133,10 +134,16 @@ def test_main(num_tokens: int,
...
@@ -133,10 +134,16 @@ def test_main(num_tokens: int,
use_fp8
=
True
,
round_scale
=
False
,
use_ue8m0
=
False
,
use_int8
=
True
,
use_fp8
=
True
,
round_scale
=
False
,
use_ue8m0
=
False
,
use_int8
=
True
,
async_finish
=
False
,
return_recv_hook
=
return_recv_hook
)
async_finish
=
False
,
return_recv_hook
=
return_recv_hook
)
large_gemm_with_hook
(
hook
)
if
return_recv_hook
else
None
large_gemm_with_hook
(
hook
)
if
return_recv_hook
else
None
combined_x
,
event
,
hook
=
buffer
.
low_latency_combine
(
simulated_gemm_x
,
topk_idx
,
topk_weights
,
handle
,
return_recv_hook
=
return_recv_hook
)
large_gemm_with_hook
(
hook
)
if
return_recv_hook
else
None
# Calculate bandwidth
# Calculate bandwidth
num_fp8_bytes
,
num_bf16_bytes
=
(
hidden
+
hidden
/
128
*
4
+
16
),
hidden
*
2
scale_size
=
1
#
hidden
/ 128
num_
logfmt10
_bytes
=
hidden
*
10
/
8
+
hidden
/
128
*
4
num_
fp8_bytes
,
num_bf16
_bytes
=
(
hidden
+
scale_size
*
4
+
16
),
hidden
*
2
num_dispatch_comm_bytes
,
num_combine_comm_bytes
=
0
,
0
num_dispatch_comm_bytes
,
num_combine_comm_bytes
=
0
,
0
for
i
in
range
(
num_tokens
):
for
i
in
range
(
num_tokens
):
num_selections
=
(
topk_idx
[
i
]
!=
-
1
).
sum
().
item
()
num_selections
=
(
topk_idx
[
i
]
!=
-
1
).
sum
().
item
()
...
@@ -144,18 +151,20 @@ def test_main(num_tokens: int,
...
@@ -144,18 +151,20 @@ def test_main(num_tokens: int,
num_combine_comm_bytes
+=
num_bf16_bytes
*
num_selections
num_combine_comm_bytes
+=
num_bf16_bytes
*
num_selections
# Separate profiling
# Separate profiling
for
return_recv_hook
in
(
True
,
):
for
return_recv_hook
in
(
True
,
False
):
group
.
barrier
()
group
.
barrier
()
dispatch_t
=
bench_kineto
(
partial
(
test_func
,
return_recv_hook
=
return_recv_hook
),
dispatch_t
,
combine_t
=
bench_kineto
(
partial
(
test_func
,
return_recv_hook
=
return_recv_hook
),
kernel_names
=
'dispatch'
,
kernel_names
=
(
'dispatch'
,
'combine'
),
barrier_comm_profiling
=
True
,
barrier_comm_profiling
=
True
,
suppress_kineto_output
=
True
,
suppress_kineto_output
=
True
,
num_kernels_per_period
=
2
if
return_recv_hook
else
1
)
num_kernels_per_period
=
2
if
return_recv_hook
else
1
)
if
not
return_recv_hook
:
if
not
return_recv_hook
:
print
(
f
'[rank
{
rank
}
] Dispatch bandwidth:
{
num_dispatch_comm_bytes
/
1e9
/
dispatch_t
:.
2
f
}
GB/s, avg_t=
{
dispatch_t
*
1e6
:.
2
f
}
us'
,
print
(
f
'[rank
{
rank
}
] Dispatch bandwidth:
{
num_dispatch_comm_bytes
/
1e9
/
dispatch_t
:.
2
f
}
GB/s, avg_t=
{
dispatch_t
*
1e6
:.
2
f
}
us | '
f
'Combine bandwidth:
{
num_combine_comm_bytes
/
1e9
/
combine_t
:.
2
f
}
GB/s, avg_t=
{
combine_t
*
1e6
:.
2
f
}
us'
,
flush
=
True
)
flush
=
True
)
else
:
else
:
print
(
f
'[rank
{
rank
}
] Dispatch send/recv time:
{
dispatch_t
[
0
]
*
1e6
:.
2
f
}
+
{
dispatch_t
[
1
]
*
1e6
:.
2
f
}
us'
,
print
(
f
'[rank
{
rank
}
] Dispatch send/recv time:
{
dispatch_t
[
0
]
*
1e6
:.
2
f
}
+
{
dispatch_t
[
1
]
*
1e6
:.
2
f
}
us | '
f
'Combine send/recv time:
{
combine_t
[
0
]
*
1e6
:.
2
f
}
+
{
combine_t
[
1
]
*
1e6
:.
2
f
}
us'
,
flush
=
True
)
flush
=
True
)
return
hash_value
return
hash_value
...
@@ -178,30 +187,6 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
...
@@ -178,30 +187,6 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
allow_mnnvl
=
args
.
allow_mnnvl
)
allow_mnnvl
=
args
.
allow_mnnvl
)
test_main
(
num_tokens
,
hidden
,
num_experts
,
num_topk
,
rank
,
num_ranks
,
group
,
buffer
,
seed
=
1
)
test_main
(
num_tokens
,
hidden
,
num_experts
,
num_topk
,
rank
,
num_ranks
,
group
,
buffer
,
seed
=
1
)
# do_pressure_test = args.pressure_test
# for seed in range(int(1e9) if do_pressure_test else 0):
# if local_rank == 0:
# print(f'Testing with seed {seed} ...', flush=True)
# ref_hash = test_main(num_tokens,
# hidden,
# num_experts,
# num_topk,
# rank,
# num_ranks,
# group,
# buffer,
# seed=seed)
# for _ in range(20):
# assert test_main(num_tokens,
# hidden,
# num_experts,
# num_topk,
# rank,
# num_ranks,
# group,
# buffer,
# seed=seed) == ref_hash, f'Error: seed={seed}'
# Destroy the buffer runtime and communication group
# Destroy the buffer runtime and communication group
buffer
.
destroy
()
buffer
.
destroy
()
dist
.
barrier
()
dist
.
barrier
()
...
@@ -214,7 +199,7 @@ if __name__ == '__main__':
...
@@ -214,7 +199,7 @@ if __name__ == '__main__':
parser
=
argparse
.
ArgumentParser
(
description
=
'Test low-latency EP kernels'
)
parser
=
argparse
.
ArgumentParser
(
description
=
'Test low-latency EP kernels'
)
parser
.
add_argument
(
'--num-processes'
,
type
=
int
,
default
=
8
,
help
=
'Number of processes to spawn (default: 8)'
)
parser
.
add_argument
(
'--num-processes'
,
type
=
int
,
default
=
8
,
help
=
'Number of processes to spawn (default: 8)'
)
parser
.
add_argument
(
'--num-tokens'
,
type
=
int
,
default
=
128
,
help
=
'Number of tokens (default: 128)'
)
parser
.
add_argument
(
'--num-tokens'
,
type
=
int
,
default
=
128
,
help
=
'Number of tokens (default: 128)'
)
parser
.
add_argument
(
'--hidden'
,
type
=
int
,
default
=
2560
,
help
=
'Hidden dimension size (default: 7168)'
)
parser
.
add_argument
(
'--hidden'
,
type
=
int
,
default
=
7168
,
help
=
'Hidden dimension size (default: 7168)'
)
parser
.
add_argument
(
'--num-topk'
,
type
=
int
,
default
=
8
,
help
=
'Number of top-k experts (default: 8)'
)
parser
.
add_argument
(
'--num-topk'
,
type
=
int
,
default
=
8
,
help
=
'Number of top-k experts (default: 8)'
)
parser
.
add_argument
(
'--num-experts'
,
type
=
int
,
default
=
256
,
help
=
'Number of experts (default: 288)'
)
parser
.
add_argument
(
'--num-experts'
,
type
=
int
,
default
=
256
,
help
=
'Number of experts (default: 288)'
)
parser
.
add_argument
(
'--allow-mnnvl'
,
action
=
"store_true"
,
help
=
'Allow MNNVL for communication'
)
parser
.
add_argument
(
'--allow-mnnvl'
,
action
=
"store_true"
,
help
=
'Allow MNNVL for communication'
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment