Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
30aa7a87
Commit
30aa7a87
authored
Apr 17, 2026
by
lishen01
Browse files
完善torchrun和mpi启动的测试代码
parent
243eca85
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
114 additions
and
37 deletions
+114
-37
tests/1.sh
tests/1.sh
+4
-2
tests/2.sh
tests/2.sh
+4
-2
tests/test_intranode.py
tests/test_intranode.py
+1
-1
tests/test_low_latency.py
tests/test_low_latency.py
+4
-3
tests_mpi/test_env.sh
tests_mpi/test_env.sh
+9
-6
tests_mpi/test_internode.py
tests_mpi/test_internode.py
+9
-3
tests_mpi/test_intranode.py
tests_mpi/test_intranode.py
+4
-3
tests_mpi/test_low_latency.py
tests_mpi/test_low_latency.py
+79
-17
No files found.
tests/1.sh
View file @
30aa7a87
#!/bin/bash
# rocSHMEM
export
ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX
=
288
...
...
@@ -5,6 +6,7 @@ export ROCSHMEM_MAX_NUM_CONTEXTS=60
export
ROCSHMEM_ALLOWED_IBV_DEVICES
=
mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export
ROCSHMEM_HEAP_SIZE
=
3737418240
export
ROCSHMEM_TOPO_FILE_FORCE
=
./topo.config
# NMZ使用
# export ROCSHMEM_DISABLE_HDP_FLUSH=1
# export ROCSHMEM_GDR_DISABLE_XDP=1
...
...
@@ -18,8 +20,8 @@ export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export
PYTHONPATH
=
$(
pwd
)
/../
# test
#
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py
torchrun
--nproc-per-node
=
1
--nnodes
=
2
--node-rank
=
0
--master-addr
=
"10.16.1.37"
--master-port
=
1234 ./test_internode.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py --test-ll-compatibility
torchrun
--nproc-per-node
=
1
--nnodes
=
2
--node-rank
=
0
--master-addr
=
"10.16.1.37"
--master-port
=
1234 ./test_low_latency.py
# --pressure-test
#
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py # --pressure-test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py --use-logfmt
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py --enable-dispatch-ll-layered --enable-combine-overlap
tests/2.sh
View file @
30aa7a87
#!/bin/bash
# rocSHMEM
export
ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX
=
288
...
...
@@ -5,6 +6,7 @@ export ROCSHMEM_MAX_NUM_CONTEXTS=60
export
ROCSHMEM_ALLOWED_IBV_DEVICES
=
mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export
ROCSHMEM_HEAP_SIZE
=
3737418240
export
ROCSHMEM_TOPO_FILE_FORCE
=
./topo.config
# NMZ使用
# export ROCSHMEM_DISABLE_HDP_FLUSH=1
# export ROCSHMEM_GDR_DISABLE_XDP=1
...
...
@@ -18,8 +20,8 @@ export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export
PYTHONPATH
=
$(
pwd
)
/../
# test
#
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py
torchrun
--nproc-per-node
=
1
--nnodes
=
2
--node-rank
=
1
--master-addr
=
"10.16.1.37"
--master-port
=
1234 ./test_internode.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py --test-ll-compatibility
torchrun
--nproc-per-node
=
1
--nnodes
=
2
--node-rank
=
1
--master-addr
=
"10.16.1.37"
--master-port
=
1234 ./test_low_latency.py
# --pressure-test
#
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py # --pressure-test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py --use-logfmt
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py --enable-dispatch-ll-layered --enable-combine-overlap
tests/test_intranode.py
View file @
30aa7a87
...
...
@@ -244,7 +244,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_qps_per_rank
=
(
ll_num_experts
//
num_ranks
if
test_ll_compatibility
else
1
),
explicitly_destroy
=
True
)
torch
.
manual_seed
(
rank
)
for
i
in
(
24
,
):
for
i
in
(
60
,
):
test_main
(
args
,
i
,
local_rank
,
num_ranks
,
rank
,
buffer
,
group
)
if
local_rank
==
0
:
print
(
''
,
flush
=
True
)
...
...
tests/test_low_latency.py
View file @
30aa7a87
...
...
@@ -52,8 +52,9 @@ def test_main(num_tokens: int,
seed
:
int
=
0
):
torch
.
manual_seed
(
seed
+
rank
)
random
.
seed
(
seed
+
rank
)
if
rank
==
0
:
print
(
f
"enable_dispatch_ll_layered=
{
enable_dispatch_ll_layered
}
, enable_combine_overlap=
{
enable_combine_overlap
}
, use_logfmt=
{
use_logfmt
}
"
)
print
(
f
"enable_dispatch_ll_layered=
{
enable_dispatch_ll_layered
}
, enable_combine_overlap=
{
enable_combine_overlap
}
, use_logfmt=
{
use_logfmt
}
"
)
assert
not
(
use_logfmt
and
(
enable_dispatch_ll_layered
or
enable_combine_overlap
)),
\
"use_logfmt=True and enable_dispatch_ll_layered/enable_combine_overlap conflict"
assert
num_experts
%
num_ranks
==
0
...
...
@@ -144,7 +145,7 @@ def test_main(num_tokens: int,
recv_x_amin
=
recv_x
[:,
:
-
128
].
amin
(
dim
=-
1
)
recv_x_amax
=
recv_x
[:,
:
-
128
].
amax
(
dim
=-
1
)
if
(
enable_dispatch_ll_layered
or
enable_combine_overlap
)
:
if
enable_dispatch_ll_layered
or
enable_combine_overlap
:
recv_src_info
=
recv_src_info
[:
num_valid_tokens
]
&
int_mask
# 掩掉多余的信息
else
:
recv_src_info
=
recv_src_info
[:
num_valid_tokens
]
...
...
@@ -179,7 +180,7 @@ def test_main(num_tokens: int,
out
=
torch
.
empty
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
if
enable_combine_overlap
:
block_m
,
threshold
,
num_sms
=
64
,
10
,
3
total_num_per_expert
=
ceil_div
(
num_tokens
*
num_ranks
,
block_m
)
# 每个本地专家 总的信号数
??
total_num_per_expert
=
ceil_div
(
num_tokens
*
num_ranks
,
block_m
)
# 每个本地专家 总的信号数
comp_signal
=
torch
.
zeros
(
num_local_experts
*
total_num_per_expert
,
dtype
=
torch
.
int32
,
device
=
'cuda'
)
for
i
in
range
(
num_local_experts
):
...
...
tests_mpi/test_env.sh
View file @
30aa7a87
...
...
@@ -8,12 +8,15 @@ export PYTHONPATH=$(pwd)
# rocSHMEM
export
ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX
=
288
export
ROCSHMEM_MAX_NUM_CONTEXTS
=
48
export
ROCSHMEM_MAX_NUM_CONTEXTS
=
60
export
ROCSHMEM_ALLOWED_IBV_DEVICES
=
mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export
ROCSHMEM_HEAP_SIZE
=
10
737418240
export
ROCSHMEM_HEAP_SIZE
=
3
737418240
export
ROCSHMEM_TOPO_FILE_FORCE
=
$(
pwd
)
/tests_mpi/topo.config
# NMZ使用
# export ROCSHMEM_DISABLE_HDP_FLUSH=1
# export ROCSHMEM_GDR_DISABLE_XDP=1
# duSHMEM
export
LD_LIBRARY_PATH
=
/opt/dtk/dushmem/lib:
$LD_LIBRARY_PATH
export
DEEP_EP_DEVICE_TO_HCA_MAPPING
=
0:mlx5_2:1,1:mlx5_3:1,2:mlx5_4:1,3:mlx5_5:1,4:mlx5_6:1,5:mlx5_7:1,6:mlx5_8:1,7:mlx5_9:1
export
NVSHMEM_SYMMETRIC_SIZE
=
10737418240
#
# duSHMEM
#
export LD_LIBRARY_PATH=/opt/dtk/dushmem/lib:$LD_LIBRARY_PATH
#
export DEEP_EP_DEVICE_TO_HCA_MAPPING=0:mlx5_2:1,1:mlx5_3:1,2:mlx5_4:1,3:mlx5_5:1,4:mlx5_6:1,5:mlx5_7:1,6:mlx5_8:1,7:mlx5_9:1
#
export NVSHMEM_SYMMETRIC_SIZE=10737418240
tests_mpi/test_internode.py
View file @
30aa7a87
...
...
@@ -145,7 +145,8 @@ def test_main(args: argparse.Namespace, num_sms: int,
# Check `topk_weights`
if
not
is_rand
:
recv_topk_weights
[
recv_topk_idx
.
eq
(
-
1
)]
=
recv_topk_weights
.
amax
(
dim
=
1
,
keepdim
=
True
).
expand_as
(
recv_topk_weights
)[
recv_topk_idx
.
eq
(
-
1
)]
max_weights
=
recv_topk_weights
.
amax
(
dim
=
1
,
keepdim
=
True
)
# Shape: [Batch, 1]
recv_topk_weights
=
torch
.
where
(
recv_topk_idx
==
-
1
,
max_weights
,
recv_topk_weights
)
check_data
(
recv_topk_weights
,
recv_gbl_rank_prefix_sum
)
# Test cached dispatch (must without top-k staffs)
...
...
@@ -203,6 +204,8 @@ def test_main(args: argparse.Namespace, num_sms: int,
nvl_recv_bytes
=
(
dispatch_bf16_nvl_recv_bytes
*
fp8_factor
)
if
isinstance
(
current_x
,
tuple
)
else
dispatch_bf16_nvl_recv_bytes
for
nvl_chunk_size
in
range
(
4
,
45
,
4
):
for
rdma_chunk_size
in
range
(
4
,
33
,
4
):
if
rdma_buffer_size
%
rdma_chunk_size
!=
0
:
continue
config
=
deep_ep
.
Config
(
num_sms
,
nvl_chunk_size
,
nvl_buffer_size
,
rdma_chunk_size
,
rdma_buffer_size
)
tune_args
=
{
'x'
:
current_x
,
'handle'
:
handle
,
'config'
:
config
}
t
,
notify_t
=
bench_kineto
(
lambda
:
buffer
.
dispatch
(
**
tune_args
),
(
'dispatch'
,
'notify'
),
suppress_kineto_output
=
True
)
...
...
@@ -235,6 +238,8 @@ def test_main(args: argparse.Namespace, num_sms: int,
best_time
,
best_results
=
1e10
,
None
for
nvl_chunk_size
in
range
(
1
,
8
,
1
):
for
rdma_chunk_size
in
range
(
12
if
num_nodes
==
2
else
8
,
33
,
4
):
if
rdma_buffer_size
%
rdma_chunk_size
!=
0
:
continue
config
=
deep_ep
.
Config
(
num_sms
,
nvl_chunk_size
,
nvl_buffer_size
,
rdma_chunk_size
,
rdma_buffer_size
)
tune_args
=
{
'x'
:
recv_x
,
'handle'
:
handle
,
'config'
:
config
}
t
,
notify_t
=
bench_kineto
(
lambda
:
buffer
.
combine
(
**
tune_args
),
(
'combine'
,
'notify'
),
suppress_kineto_output
=
True
)
...
...
@@ -272,8 +277,9 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
ll_num_tokens
,
ll_hidden
,
ll_num_experts
,
ll_num_topk
=
16
,
5120
,
256
,
9
num_rdma_bytes_ll
=
deep_ep
.
Buffer
.
get_low_latency_rdma_size_hint
(
ll_num_tokens
,
ll_hidden
,
num_ranks
,
ll_num_experts
)
num_sms
=
48
num_sms
=
60
num_qps_per_rank
=
max
(
num_sms
,
ll_num_experts
//
num_ranks
if
args
.
test_ll_compatibility
else
0
)
deep_ep
.
Buffer
.
set_num_sms
(
num_sms
)
hidden_bytes
=
get_hidden_bytes
(
args
)
num_nvl_bytes
,
num_rdma_bytes
,
num_rdma_bytes_norm
=
0
,
0
,
0
...
...
@@ -299,7 +305,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
break
if
rank
==
0
:
print
(
f
'
{
ref_hash
=
}
'
)
print
(
f
'
ref_hash=
{
ref_hash
}
'
)
print
(
''
,
flush
=
True
)
for
j
in
range
(
20
):
...
...
tests_mpi/test_intranode.py
View file @
30aa7a87
...
...
@@ -119,7 +119,8 @@ def test_main(args: argparse.Namespace, num_sms: int, local_rank: int, num_ranks
# Check `topk_weights`
recv_topk_weights_clone
=
recv_topk_weights
.
clone
()
if
current_x
is
not
x_pure_rand
:
recv_topk_weights
[
recv_topk_idx
.
eq
(
-
1
)]
=
recv_topk_weights
.
amax
(
dim
=
1
,
keepdim
=
True
).
expand_as
(
recv_topk_weights
)[
recv_topk_idx
.
eq
(
-
1
)]
max_weights
=
recv_topk_weights
.
amax
(
dim
=
1
,
keepdim
=
True
)
# Shape: [Batch, 1]
recv_topk_weights
=
torch
.
where
(
recv_topk_idx
==
-
1
,
max_weights
,
recv_topk_weights
)
check_data
(
recv_topk_weights
,
rank_prefix_matrix
)
# Test `num_worst_tokens != 0`
...
...
@@ -251,7 +252,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_qps_per_rank
=
(
ll_num_experts
//
num_ranks
if
test_ll_compatibility
else
1
),
explicitly_destroy
=
True
)
torch
.
manual_seed
(
rank
)
for
i
in
(
48
,
):
for
i
in
(
60
,
):
test_main
(
args
,
i
,
local_rank
,
num_ranks
,
rank
,
buffer
,
group
)
if
local_rank
==
0
:
print
(
''
,
flush
=
True
)
...
...
@@ -269,7 +270,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
(
description
=
'Test intranode EP kernels'
)
group
=
parser
.
add_argument_group
(
title
=
'extra distributed args'
)
group
.
add_argument
(
'--rank'
,
default
=-
int
(
os
.
getenv
(
'OMPI_COMM_WORLD_RANK'
,
'0'
)),
type
=
int
,
help
=
'node rank for distributed training'
)
...
...
tests_mpi/test_low_latency.py
View file @
30aa7a87
...
...
@@ -36,6 +36,10 @@ def query_mask_buffer_and_check(api: Literal["dispatch", "combine", "clean"], bu
assert
set
(
mask_status
.
nonzero
().
squeeze
(
-
1
).
tolist
())
==
expected_masked_ranks
def
ceil_div
(
a
,
b
):
return
(
a
+
b
-
1
)
//
b
def
test_main
(
num_tokens
:
int
,
hidden
:
int
,
num_experts
:
int
,
...
...
@@ -44,11 +48,17 @@ def test_main(num_tokens: int,
num_ranks
:
int
,
group
:
dist
.
ProcessGroup
,
buffer
:
deep_ep
.
Buffer
,
enable_dispatch_ll_layered
:
bool
=
False
,
enable_combine_overlap
:
bool
=
False
,
use_logfmt
:
bool
=
False
,
seed
:
int
=
0
):
torch
.
manual_seed
(
seed
+
rank
)
random
.
seed
(
seed
+
rank
)
if
rank
==
0
:
print
(
f
"enable_dispatch_ll_layered=
{
enable_dispatch_ll_layered
}
, enable_combine_overlap=
{
enable_combine_overlap
}
, use_logfmt=
{
use_logfmt
}
"
)
assert
not
(
use_logfmt
and
(
enable_dispatch_ll_layered
or
enable_combine_overlap
)),
\
"use_logfmt=True and enable_dispatch_ll_layered/enable_combine_overlap conflict"
assert
num_experts
%
num_ranks
==
0
num_local_experts
=
num_experts
//
num_ranks
...
...
@@ -86,6 +96,9 @@ def test_main(num_tokens: int,
hash_value
,
num_times
=
0
,
0
for
x_i
,
current_x
in
enumerate
(
x_list
):
for
return_recv_hook
in
(
False
,
True
):
if
enable_combine_overlap
and
(
not
return_recv_hook
):
# return_recv_hook 为False 时,不能启用 overlop
continue
for
quant_type
in
(
0
,
1
,
2
,
3
,
):
# 0: 不量化, 1: int8, 2: FP8_E4M3, 3: FP8_UE8M0 (仅支持round_scale=True), 4: FP8_E5M2
dispatch_use_quant
=
quant_type
>
0
for
fp8_round_scale
in
(
False
,
True
)
if
quant_type
!=
3
else
(
True
,
):
...
...
@@ -133,9 +146,14 @@ def test_main(num_tokens: int,
recv_x
=
recv_x
[:
num_valid_tokens
]
recv_x_amin
=
recv_x
[:,
:
-
128
].
amin
(
dim
=-
1
)
recv_x_amax
=
recv_x
[:,
:
-
128
].
amax
(
dim
=-
1
)
recv_src_info
=
recv_src_info
[:
num_valid_tokens
]
if
enable_dispatch_ll_layered
or
enable_combine_overlap
:
recv_src_info
=
recv_src_info
[:
num_valid_tokens
]
&
int_mask
# 掩掉多余的信息
else
:
recv_src_info
=
recv_src_info
[:
num_valid_tokens
]
assert
torch
.
equal
(
recv_x_amin
,
recv_x_amax
)
if
dispatch_use_quant
:
assert
calc_diff
(
recv_x
[:,
-
1
],
recv_src_info
.
view
(
-
1
))
<
0.007
...
...
@@ -150,6 +168,7 @@ def test_main(num_tokens: int,
if
not
fp8_round_scale
:
assert
(
recv_x_amin
==
j
-
rank_offset
).
sum
().
item
()
==
(
all_topk_idx
[
j
]
==
expert_id
).
sum
().
item
()
assert
(
recv_x
[
begin_idx
:
begin_idx
+
count
,
:
-
128
]
-
j
+
rank_offset
).
sum
().
item
()
==
0
if
dispatch_use_quant
:
hash_value
^=
hash_tensor
(
packed_recv_x
[
0
][
i
,
:
num_valid_tokens
])
hash_value
^=
hash_tensor
(
packed_recv_x
[
1
][
i
,
:
num_valid_tokens
])
...
...
@@ -161,15 +180,38 @@ def test_main(num_tokens: int,
if
zero_copy
:
buffer
.
get_next_low_latency_combine_buffer
(
handle
)[:,
:,
:]
=
simulated_gemm_x
out
=
torch
.
empty
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
combined_x
,
event
,
hook
=
buffer
.
low_latency_combine
(
simulated_gemm_x
,
topk_idx
,
topk_weights
,
handle
,
use_logfmt
=
use_logfmt
,
async_finish
=
not
return_recv_hook
,
zero_copy
=
zero_copy
,
return_recv_hook
=
return_recv_hook
,
out
=
out
)
if
enable_combine_overlap
:
block_m
,
threshold
,
num_sms
=
64
,
10
,
3
total_num_per_expert
=
ceil_div
(
num_tokens
*
num_ranks
,
block_m
)
# 每个本地专家 总的信号数
comp_signal
=
torch
.
zeros
(
num_local_experts
*
total_num_per_expert
,
dtype
=
torch
.
int32
,
device
=
'cuda'
)
for
i
in
range
(
num_local_experts
):
vaild_num
=
ceil_div
(
packed_recv_count
[
i
],
block_m
)
comp_signal
[
i
*
total_num_per_expert
:
i
*
total_num_per_expert
+
vaild_num
]
=
threshold
combined_x
,
event
,
hook
=
buffer
.
low_latency_combine
(
simulated_gemm_x
,
topk_idx
,
topk_weights
,
handle
,
packed_recv_count
=
packed_recv_count
,
comp_signal
=
comp_signal
,
block_m
=
block_m
,
threshold
=
threshold
,
num_sms
=
num_sms
,
async_finish
=
not
return_recv_hook
,
zero_copy
=
zero_copy
,
return_recv_hook
=
return_recv_hook
,
out
=
out
)
else
:
combined_x
,
event
,
hook
=
buffer
.
low_latency_combine
(
simulated_gemm_x
,
topk_idx
,
topk_weights
,
handle
,
use_logfmt
=
use_logfmt
,
async_finish
=
not
return_recv_hook
,
zero_copy
=
zero_copy
,
return_recv_hook
=
return_recv_hook
,
out
=
out
)
hook
()
if
return_recv_hook
else
event
.
current_stream_wait
()
if
do_check
:
diff
=
calc_diff
(
current_x
*
topk_weights
.
masked_fill
(
topk_idx
==
-
1
,
0
).
sum
(
dim
=
1
).
view
(
-
1
,
1
),
combined_x
)
...
...
@@ -181,8 +223,10 @@ def test_main(num_tokens: int,
if
rank
==
0
:
print
(
f
"data:
{
x_i
}
, return_recv_hook:
{
return_recv_hook
}
, quant_type:
{
quant_type
}
, "
,
f
"fp8_round_scale:
{
fp8_round_scale
}
, quant_group_size:
{
quant_group_size
}
pass"
)
if
rank
==
0
:
print
(
''
,
flush
=
True
)
print
(
"deep_ep 全部正确性测试完成"
)
if
enable_dispatch_ll_layered
or
enable_combine_overlap
:
return
hash_value
# noinspection PyShadowingNames
def
large_gemm_with_hook
(
hook
):
...
...
@@ -252,9 +296,13 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_tokens
,
hidden
=
args
.
num_tokens
,
args
.
hidden
num_topk
,
num_experts
=
args
.
num_topk
,
args
.
num_experts
print
(
f
"num_tokens, hidden, num_ranks, num_experts =
{
num_tokens
}
,
{
hidden
}
,
{
num_ranks
}
,
{
num_experts
}
"
)
enable_dispatch_ll_layered
=
args
.
enable_dispatch_ll_layered
enable_combine_overlap
=
args
.
enable_combine_overlap
if
enable_dispatch_ll_layered
:
enable_combine_overlap
=
True
num_rdma_bytes
=
deep_ep
.
Buffer
.
get_low_latency_rdma_size_hint
(
num_tokens
,
hidden
,
num_ranks
,
num_experts
)
num_rdma_bytes
=
deep_ep
.
Buffer
.
get_low_latency_rdma_size_hint
(
num_tokens
,
hidden
,
num_ranks
,
num_experts
,
enable_dispatch_ll_layered
=
enable_dispatch_ll_layered
)
if
rank
==
0
:
print
(
f
'Allocating buffer size:
{
num_rdma_bytes
/
1e6
}
MB ...'
,
flush
=
True
)
buffer
=
deep_ep
.
Buffer
(
group
,
...
...
@@ -263,7 +311,11 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_qps_per_rank
=
num_experts
//
num_ranks
,
allow_nvlink_for_low_latency_mode
=
not
args
.
disable_nvlink
,
explicitly_destroy
=
True
,
allow_mnnvl
=
args
.
allow_mnnvl
)
allow_mnnvl
=
args
.
allow_mnnvl
,
enable_dispatch_ll_layered
=
enable_dispatch_ll_layered
,
enable_combine_overlap
=
enable_combine_overlap
)
print
(
"deep_ep 初始化完成"
)
test_main
(
num_tokens
,
hidden
,
num_experts
,
...
...
@@ -273,6 +325,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
group
,
buffer
,
use_logfmt
=
args
.
use_logfmt
,
enable_dispatch_ll_layered
=
enable_dispatch_ll_layered
,
enable_combine_overlap
=
enable_combine_overlap
,
seed
=
1
)
do_pressure_test
=
args
.
pressure_test
...
...
@@ -288,6 +342,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
group
,
buffer
,
use_logfmt
=
args
.
use_logfmt
,
enable_dispatch_ll_layered
=
enable_dispatch_ll_layered
,
enable_combine_overlap
=
enable_combine_overlap
,
seed
=
seed
)
for
_
in
range
(
20
):
assert
test_main
(
num_tokens
,
...
...
@@ -299,6 +355,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
group
,
buffer
,
use_logfmt
=
args
.
use_logfmt
,
enable_dispatch_ll_layered
=
enable_dispatch_ll_layered
,
enable_combine_overlap
=
enable_combine_overlap
,
seed
=
seed
)
==
ref_hash
,
f
'Error: seed=
{
seed
}
'
# Destroy the buffer runtime and communication group
...
...
@@ -310,7 +368,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
if
__name__
==
'__main__'
:
# TODO: you may modify NUMA binding for less CPU overhead
parser
=
argparse
.
ArgumentParser
(
description
=
'Test low-latency EP kernels'
)
group
=
parser
.
add_argument_group
(
title
=
'extra distributed args'
)
group
.
add_argument
(
'--rank'
,
default
=-
int
(
os
.
getenv
(
'OMPI_COMM_WORLD_RANK'
,
'0'
)),
type
=
int
,
help
=
'node rank for distributed training'
)
...
...
@@ -331,6 +389,10 @@ if __name__ == '__main__':
parser
.
add_argument
(
"--pressure-test"
,
action
=
'store_true'
,
help
=
'Whether to do pressure test'
)
parser
.
add_argument
(
"--shrink-test"
,
action
=
'store_true'
,
help
=
'Whether to simulate failure and test shrink mode'
)
parser
.
add_argument
(
'--use-logfmt'
,
action
=
'store_true'
,
help
=
'Whether to test LogFMT combine'
)
# 新版 sbo 需要的
parser
.
add_argument
(
'--enable-dispatch-ll-layered'
,
action
=
'store_true'
,
help
=
'Enable low-latency layered dispatch optimization'
)
parser
.
add_argument
(
"--enable-combine-overlap"
,
action
=
'store_true'
,
help
=
'Enable GEMM-compute/communication overlap in the combine phase'
)
args
=
parser
.
parse_args
()
if
args
.
world_size
>
args
.
num_processes
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment