Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
7705f533
Commit
7705f533
authored
Jul 02, 2025
by
Chenggang Zhao
Browse files
Refactor testing arguments
parent
6b17f4fa
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
40 additions
and
61 deletions
+40
-61
tests/test_internode.py
tests/test_internode.py
+21
-24
tests/test_intranode.py
tests/test_intranode.py
+12
-17
tests/test_low_latency.py
tests/test_low_latency.py
+7
-20
No files found.
tests/test_internode.py
View file @
7705f533
import
argparse
import
os
import
os
import
time
import
time
import
torch
import
torch
...
@@ -11,13 +12,13 @@ from utils import init_dist, bench, calc_diff, create_grouped_scores, inplace_un
...
@@ -11,13 +12,13 @@ from utils import init_dist, bench, calc_diff, create_grouped_scores, inplace_un
import
test_low_latency
import
test_low_latency
def
test_main
(
num_sms
:
int
,
local_rank
:
int
,
num_local_ranks
:
int
,
num_ranks
:
int
,
num_nodes
:
int
,
rank
:
int
,
buffer
:
deep_ep
.
Buffer
,
group
:
dist
.
ProcessGroup
,
args
):
# noinspection PyShadowingNames
def
test_main
(
args
:
argparse
.
Namespace
,
num_sms
:
int
,
local_rank
:
int
,
num_local_ranks
:
int
,
num_ranks
:
int
,
num_nodes
:
int
,
rank
:
int
,
buffer
:
deep_ep
.
Buffer
,
group
:
dist
.
ProcessGroup
):
# Settings
# Settings
num_tokens
=
args
.
num_tokens
num_tokens
,
hidden
=
args
.
num_tokens
,
args
.
hidden
hidden
=
args
.
hidden
num_topk_groups
,
num_topk
,
num_experts
=
args
.
num_topk_groups
,
args
.
num_topk
,
args
.
num_experts
num_topk_groups
=
args
.
num_topk_groups
num_topk
=
args
.
num_topk
num_experts
=
args
.
num_experts
assert
num_experts
%
num_ranks
==
0
and
num_local_ranks
==
8
assert
num_experts
%
num_ranks
==
0
and
num_local_ranks
==
8
if
local_rank
==
0
:
if
local_rank
==
0
:
...
@@ -223,29 +224,28 @@ def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: in
...
@@ -223,29 +224,28 @@ def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: in
print
(
''
,
flush
=
True
)
print
(
''
,
flush
=
True
)
# noinspection PyUnboundLocalVariable
# noinspection PyUnboundLocalVariable
,PyShadowingNames
def
test_loop
(
local_rank
:
int
,
num_local_ranks
:
int
,
args
):
def
test_loop
(
local_rank
:
int
,
num_local_ranks
:
int
,
args
:
argparse
.
Namespace
):
num_nodes
=
int
(
os
.
getenv
(
'WORLD_SIZE'
,
1
))
num_nodes
=
int
(
os
.
getenv
(
'WORLD_SIZE'
,
1
))
rank
,
num_ranks
,
group
=
init_dist
(
local_rank
,
num_local_ranks
)
rank
,
num_ranks
,
group
=
init_dist
(
local_rank
,
num_local_ranks
)
test_ll_compatibility
=
os
.
getenv
(
'EP_TEST_LL_COMPATIBILITY'
,
False
)
if
args
.
test_ll_compatibility
:
if
test_ll_compatibility
:
ll_num_tokens
,
ll_hidden
,
ll_num_experts
,
ll_num_topk
=
16
,
5120
,
256
,
9
ll_num_tokens
,
ll_hidden
,
ll_num_experts
,
ll_num_topk
=
16
,
5120
,
256
,
9
num_sms
=
24
num_sms
=
24
num_qps_per_rank
=
max
(
num_sms
,
ll_num_experts
//
num_ranks
if
test_ll_compatibility
else
0
)
num_qps_per_rank
=
max
(
num_sms
,
ll_num_experts
//
num_ranks
if
args
.
test_ll_compatibility
else
0
)
buffer
=
deep_ep
.
Buffer
(
group
,
int
(
1e9
),
int
(
1e9
),
low_latency_mode
=
test_ll_compatibility
,
buffer
=
deep_ep
.
Buffer
(
group
,
int
(
1e9
),
int
(
1e9
),
low_latency_mode
=
args
.
test_ll_compatibility
,
num_qps_per_rank
=
num_qps_per_rank
)
num_qps_per_rank
=
num_qps_per_rank
)
assert
num_local_ranks
==
8
and
num_ranks
>
8
assert
num_local_ranks
==
8
and
num_ranks
>
8
torch
.
manual_seed
(
rank
)
torch
.
manual_seed
(
rank
)
for
i
in
(
num_sms
,
):
for
i
in
(
num_sms
,
):
test_main
(
i
,
local_rank
,
num_local_ranks
,
num_ranks
,
num_nodes
,
rank
,
buffer
,
group
,
args
)
test_main
(
args
,
i
,
local_rank
,
num_local_ranks
,
num_ranks
,
num_nodes
,
rank
,
buffer
,
group
)
if
local_rank
==
0
:
if
local_rank
==
0
:
print
(
''
,
flush
=
True
)
print
(
''
,
flush
=
True
)
# Test compatibility with low latency functions
# Test compatibility with low latency functions
if
test_ll_compatibility
:
if
args
.
test_ll_compatibility
:
buffer
.
clean_low_latency_buffer
(
ll_num_tokens
,
ll_hidden
,
ll_num_experts
)
buffer
.
clean_low_latency_buffer
(
ll_num_tokens
,
ll_hidden
,
ll_num_experts
)
test_low_latency
.
test_main
(
ll_num_tokens
,
ll_hidden
,
ll_num_experts
,
ll_num_topk
,
rank
,
num_ranks
,
group
,
buffer
,
seed
=
1
)
test_low_latency
.
test_main
(
ll_num_tokens
,
ll_hidden
,
ll_num_experts
,
ll_num_topk
,
rank
,
num_ranks
,
group
,
buffer
,
seed
=
1
)
...
@@ -255,8 +255,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args):
...
@@ -255,8 +255,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
import
argparse
parser
=
argparse
.
ArgumentParser
(
description
=
'Test internode EP kernels'
)
parser
=
argparse
.
ArgumentParser
(
description
=
'Test internode expert parallel'
)
parser
.
add_argument
(
'--num-processes'
,
type
=
int
,
default
=
8
,
parser
.
add_argument
(
'--num-processes'
,
type
=
int
,
default
=
8
,
help
=
'Number of processes to spawn (default: 8)'
)
help
=
'Number of processes to spawn (default: 8)'
)
parser
.
add_argument
(
'--num-tokens'
,
type
=
int
,
default
=
4096
,
parser
.
add_argument
(
'--num-tokens'
,
type
=
int
,
default
=
4096
,
...
@@ -264,21 +263,19 @@ if __name__ == '__main__':
...
@@ -264,21 +263,19 @@ if __name__ == '__main__':
parser
.
add_argument
(
'--hidden'
,
type
=
int
,
default
=
7168
,
parser
.
add_argument
(
'--hidden'
,
type
=
int
,
default
=
7168
,
help
=
'Hidden dimension size (default: 7168)'
)
help
=
'Hidden dimension size (default: 7168)'
)
parser
.
add_argument
(
'--num-topk-groups'
,
type
=
int
,
default
=
None
,
parser
.
add_argument
(
'--num-topk-groups'
,
type
=
int
,
default
=
None
,
help
=
'Number of top-k groups (default: min(num_nodes, 4))'
)
help
=
'Number of top-k groups (default:
`
min(num_nodes, 4)
`
)'
)
parser
.
add_argument
(
'--num-topk'
,
type
=
int
,
default
=
8
,
parser
.
add_argument
(
'--num-topk'
,
type
=
int
,
default
=
8
,
help
=
'Number of top-k experts (default: 8)'
)
help
=
'Number of top-k experts (default: 8)'
)
parser
.
add_argument
(
'--num-experts'
,
type
=
int
,
default
=
None
,
parser
.
add_argument
(
'--num-experts'
,
type
=
int
,
default
=
256
,
help
=
'Number of experts (default: calculated as (256 // num_ranks) * num_ranks)'
)
help
=
'Number of experts (default: 256)'
)
parser
.
add_argument
(
'--test-ll-compatibility'
,
action
=
'store_true'
,
help
=
'whether to test compatibility with low-latency kernels'
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
# Set default num_topk_groups if not provided
# Set default
`
num_topk_groups
`
if not provided
if
args
.
num_topk_groups
is
None
:
if
args
.
num_topk_groups
is
None
:
num_nodes
=
int
(
os
.
getenv
(
'WORLD_SIZE'
,
1
))
num_nodes
=
int
(
os
.
getenv
(
'WORLD_SIZE'
,
1
))
args
.
num_topk_groups
=
min
(
num_nodes
,
4
)
args
.
num_topk_groups
=
min
(
num_nodes
,
4
)
# Set default num_experts if not provided
if
args
.
num_experts
is
None
:
args
.
num_experts
=
(
256
//
args
.
num_processes
)
*
args
.
num_processes
num_processes
=
args
.
num_processes
num_processes
=
args
.
num_processes
torch
.
multiprocessing
.
spawn
(
test_loop
,
args
=
(
num_processes
,
args
),
nprocs
=
num_processes
)
torch
.
multiprocessing
.
spawn
(
test_loop
,
args
=
(
num_processes
,
args
),
nprocs
=
num_processes
)
tests/test_intranode.py
View file @
7705f533
import
os
import
argparse
import
time
import
time
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
...
@@ -11,12 +11,12 @@ from utils import init_dist, bench, calc_diff, inplace_unique, per_token_cast_to
...
@@ -11,12 +11,12 @@ from utils import init_dist, bench, calc_diff, inplace_unique, per_token_cast_to
import
test_low_latency
import
test_low_latency
def
test_main
(
num_sms
:
int
,
local_rank
:
int
,
num_ranks
:
int
,
rank
:
int
,
buffer
:
deep_ep
.
Buffer
,
group
:
dist
.
ProcessGroup
,
args
):
# noinspection PyShadowingNames
def
test_main
(
args
:
argparse
.
Namespace
,
num_sms
:
int
,
local_rank
:
int
,
num_ranks
:
int
,
rank
:
int
,
buffer
:
deep_ep
.
Buffer
,
group
:
dist
.
ProcessGroup
):
# Settings
# Settings
num_tokens
=
args
.
num_tokens
num_tokens
,
hidden
=
args
.
num_tokens
,
args
.
hidden
hidden
=
args
.
hidden
num_topk
,
num_experts
=
args
.
num_topk
,
args
.
num_experts
num_topk
=
args
.
num_topk
num_experts
=
args
.
num_experts
assert
num_experts
%
num_ranks
==
0
assert
num_experts
%
num_ranks
==
0
if
local_rank
==
0
:
if
local_rank
==
0
:
...
@@ -229,8 +229,8 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
...
@@ -229,8 +229,8 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
print
(
''
,
flush
=
True
)
print
(
''
,
flush
=
True
)
# noinspection PyUnboundLocalVariable
# noinspection PyUnboundLocalVariable
,PyShadowingNames
def
test_loop
(
local_rank
:
int
,
num_local_ranks
:
int
,
args
):
def
test_loop
(
local_rank
:
int
,
num_local_ranks
:
int
,
args
:
argparse
.
Namespace
):
rank
,
num_ranks
,
group
=
init_dist
(
local_rank
,
num_local_ranks
)
rank
,
num_ranks
,
group
=
init_dist
(
local_rank
,
num_local_ranks
)
test_ll_compatibility
,
num_rdma_bytes
=
False
,
0
test_ll_compatibility
,
num_rdma_bytes
=
False
,
0
if
test_ll_compatibility
:
if
test_ll_compatibility
:
...
@@ -242,7 +242,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args):
...
@@ -242,7 +242,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args):
torch
.
manual_seed
(
rank
)
torch
.
manual_seed
(
rank
)
for
i
in
(
24
,
):
for
i
in
(
24
,
):
test_main
(
i
,
local_rank
,
num_ranks
,
rank
,
buffer
,
group
,
args
)
test_main
(
args
,
i
,
local_rank
,
num_ranks
,
rank
,
buffer
,
group
)
if
local_rank
==
0
:
if
local_rank
==
0
:
print
(
''
,
flush
=
True
)
print
(
''
,
flush
=
True
)
...
@@ -257,8 +257,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args):
...
@@ -257,8 +257,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
import
argparse
parser
=
argparse
.
ArgumentParser
(
description
=
'Test intranode EP kernels'
)
parser
=
argparse
.
ArgumentParser
(
description
=
'Test intranode expert parallel'
)
parser
.
add_argument
(
'--num-processes'
,
type
=
int
,
default
=
8
,
parser
.
add_argument
(
'--num-processes'
,
type
=
int
,
default
=
8
,
help
=
'Number of processes to spawn (default: 8)'
)
help
=
'Number of processes to spawn (default: 8)'
)
parser
.
add_argument
(
'--num-tokens'
,
type
=
int
,
default
=
4096
,
parser
.
add_argument
(
'--num-tokens'
,
type
=
int
,
default
=
4096
,
...
@@ -267,13 +266,9 @@ if __name__ == '__main__':
...
@@ -267,13 +266,9 @@ if __name__ == '__main__':
help
=
'Hidden dimension size (default: 7168)'
)
help
=
'Hidden dimension size (default: 7168)'
)
parser
.
add_argument
(
'--num-topk'
,
type
=
int
,
default
=
8
,
parser
.
add_argument
(
'--num-topk'
,
type
=
int
,
default
=
8
,
help
=
'Number of top-k experts (default: 8)'
)
help
=
'Number of top-k experts (default: 8)'
)
parser
.
add_argument
(
'--num-experts'
,
type
=
int
,
default
=
None
,
parser
.
add_argument
(
'--num-experts'
,
type
=
int
,
default
=
256
,
help
=
'Number of experts (default:
calculated as (256 // num_ranks) * num_ranks
)'
)
help
=
'Number of experts (default:
256
)'
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
# Set default num_experts if not provided
if
args
.
num_experts
is
None
:
args
.
num_experts
=
(
256
//
args
.
num_processes
)
*
args
.
num_processes
num_processes
=
args
.
num_processes
num_processes
=
args
.
num_processes
torch
.
multiprocessing
.
spawn
(
test_loop
,
args
=
(
num_processes
,
args
),
nprocs
=
num_processes
)
torch
.
multiprocessing
.
spawn
(
test_loop
,
args
=
(
num_processes
,
args
),
nprocs
=
num_processes
)
tests/test_low_latency.py
View file @
7705f533
import
os
import
argparse
import
random
import
random
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
...
@@ -16,7 +16,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
...
@@ -16,7 +16,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
assert
num_experts
%
num_ranks
==
0
assert
num_experts
%
num_ranks
==
0
num_local_experts
=
num_experts
//
num_ranks
num_local_experts
=
num_experts
//
num_ranks
# NOTES: the integers greater than 256 exceed
s
the BF16 precision limit
# NOTES: the integers greater than 256 exceed the BF16 precision limit
rank_offset
=
128
rank_offset
=
128
assert
num_ranks
-
rank_offset
<
257
,
'Too many ranks (exceeding test precision limit)'
assert
num_ranks
-
rank_offset
<
257
,
'Too many ranks (exceeding test precision limit)'
...
@@ -98,16 +98,6 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
...
@@ -98,16 +98,6 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
assert
diff
<
(
7e-4
if
round_scale
else
1e-5
),
f
'Error:
{
diff
=
}
,
{
zero_copy
=
}
'
assert
diff
<
(
7e-4
if
round_scale
else
1e-5
),
f
'Error:
{
diff
=
}
,
{
zero_copy
=
}
'
hash_value
^=
hash_tensor
(
combined_x
)
hash_value
^=
hash_tensor
(
combined_x
)
def
create_test_cast_with_outliers
(
num_outliers
):
tmp
=
torch
.
randn
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
tmp
/=
tmp
.
abs
().
amax
(
dim
=
1
).
view
(
-
1
,
1
)
assert
tmp
.
abs
().
amax
().
item
()
<=
1
# Create some amax outliers
for
i
in
range
(
num_outliers
):
tmp
[
random
.
randint
(
0
,
num_tokens
-
1
)]
*=
1e3
return
tmp
# noinspection PyShadowingNames
# noinspection PyShadowingNames
def
large_gemm_with_hook
(
hook
):
def
large_gemm_with_hook
(
hook
):
mat_0
=
torch
.
randn
((
8192
,
8192
),
dtype
=
torch
.
float
)
mat_0
=
torch
.
randn
((
8192
,
8192
),
dtype
=
torch
.
float
)
...
@@ -156,13 +146,11 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
...
@@ -156,13 +146,11 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
return
hash_value
return
hash_value
# noinspection PyUnboundLocalVariable
# noinspection PyUnboundLocalVariable
,PyShadowingNames
def
test_loop
(
local_rank
:
int
,
num_local_ranks
:
int
,
args
):
def
test_loop
(
local_rank
:
int
,
num_local_ranks
:
int
,
args
:
argparse
.
Namespace
):
rank
,
num_ranks
,
group
=
init_dist
(
local_rank
,
num_local_ranks
)
rank
,
num_ranks
,
group
=
init_dist
(
local_rank
,
num_local_ranks
)
num_tokens
=
args
.
num_tokens
num_tokens
,
hidden
=
args
.
num_tokens
,
args
.
hidden
hidden
=
args
.
hidden
num_topk
,
num_experts
=
args
.
num_topk
,
args
.
num_experts
num_topk
=
args
.
num_topk
num_experts
=
args
.
num_experts
num_rdma_bytes
=
deep_ep
.
Buffer
.
get_low_latency_rdma_size_hint
(
num_tokens
,
hidden
,
num_ranks
,
num_experts
)
num_rdma_bytes
=
deep_ep
.
Buffer
.
get_low_latency_rdma_size_hint
(
num_tokens
,
hidden
,
num_ranks
,
num_experts
)
if
local_rank
==
0
:
if
local_rank
==
0
:
...
@@ -186,8 +174,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args):
...
@@ -186,8 +174,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
# TODO: you may modify NUMA binding for less CPU overhead
# TODO: you may modify NUMA binding for less CPU overhead
import
argparse
parser
=
argparse
.
ArgumentParser
(
description
=
'Test low-latency EP kernels'
)
parser
=
argparse
.
ArgumentParser
(
description
=
'Test low latency expert parallel'
)
parser
.
add_argument
(
'--num-processes'
,
type
=
int
,
default
=
8
,
parser
.
add_argument
(
'--num-processes'
,
type
=
int
,
default
=
8
,
help
=
'Number of processes to spawn (default: 8)'
)
help
=
'Number of processes to spawn (default: 8)'
)
parser
.
add_argument
(
'--num-tokens'
,
type
=
int
,
default
=
128
,
parser
.
add_argument
(
'--num-tokens'
,
type
=
int
,
default
=
128
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment