Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
61bc0aff
Commit
61bc0aff
authored
Mar 04, 2026
by
lishen
Browse files
修改torchrun启动的测试
parent
e195b4fe
Changes
4
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
248 additions
and
422 deletions
+248
-422
tests/1.sh
tests/1.sh
+6
-6
tests/2.sh
tests/2.sh
+6
-6
tests/test_low_latency.py
tests/test_low_latency.py
+236
-95
tests/test_low_latency_new.py
tests/test_low_latency_new.py
+0
-315
No files found.
1.sh
→
tests/
1.sh
View file @
61bc0aff
...
@@ -4,7 +4,7 @@ export ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX=288
...
@@ -4,7 +4,7 @@ export ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX=288
export
ROCSHMEM_MAX_NUM_CONTEXTS
=
48
export
ROCSHMEM_MAX_NUM_CONTEXTS
=
48
export
ROCSHMEM_ALLOWED_IBV_DEVICES
=
mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export
ROCSHMEM_ALLOWED_IBV_DEVICES
=
mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export
ROCSHMEM_HEAP_SIZE
=
10737418240
export
ROCSHMEM_HEAP_SIZE
=
10737418240
export
ROCSHMEM_TOPO_FILE_FORCE
=
tests
/topo.config
export
ROCSHMEM_TOPO_FILE_FORCE
=
.
/topo.config
# duSHMEM
# duSHMEM
export
LD_LIBRARY_PATH
=
/opt/dtk/dushmem/lib:
$LD_LIBRARY_PATH
export
LD_LIBRARY_PATH
=
/opt/dtk/dushmem/lib:
$LD_LIBRARY_PATH
...
@@ -13,10 +13,10 @@ export NVSHMEM_SYMMETRIC_SIZE=10737418240
...
@@ -13,10 +13,10 @@ export NVSHMEM_SYMMETRIC_SIZE=10737418240
# common
# common
export
HIP_VISIBLE_DEVICES
=
0,1,2,3,4,5,6,7
export
HIP_VISIBLE_DEVICES
=
0,1,2,3,4,5,6,7
export
PYTHONPATH
=
$(
pwd
)
export
PYTHONPATH
=
$(
pwd
)
/../
# test
# test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234
tests
/test_internode.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234
.
/test_internode.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234
tests
/test_low_latency.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234
.
/test_low_latency.py
torchrun
--nproc-per-node
=
1
--nnodes
=
2
--node-rank
=
0
--master-addr
=
"10.16.1.37"
--master-port
=
1234
tests
/test_low_latency
_new
.py
--pressure-test
torchrun
--nproc-per-node
=
1
--nnodes
=
2
--node-rank
=
0
--master-addr
=
"10.16.1.37"
--master-port
=
1234
.
/test_low_latency.py
--pressure-test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234
tests
/test_internode.py --test-ll-compatibility
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234
.
/test_internode.py --test-ll-compatibility
2.sh
→
tests/
2.sh
View file @
61bc0aff
...
@@ -4,7 +4,7 @@ export ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX=288
...
@@ -4,7 +4,7 @@ export ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX=288
export
ROCSHMEM_MAX_NUM_CONTEXTS
=
48
export
ROCSHMEM_MAX_NUM_CONTEXTS
=
48
export
ROCSHMEM_ALLOWED_IBV_DEVICES
=
mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export
ROCSHMEM_ALLOWED_IBV_DEVICES
=
mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export
ROCSHMEM_HEAP_SIZE
=
10737418240
export
ROCSHMEM_HEAP_SIZE
=
10737418240
export
ROCSHMEM_TOPO_FILE_FORCE
=
tests
/topo.config
export
ROCSHMEM_TOPO_FILE_FORCE
=
.
/topo.config
# duSHMEM
# duSHMEM
export
LD_LIBRARY_PATH
=
/opt/dtk/dushmem/lib:
$LD_LIBRARY_PATH
export
LD_LIBRARY_PATH
=
/opt/dtk/dushmem/lib:
$LD_LIBRARY_PATH
...
@@ -13,10 +13,10 @@ export NVSHMEM_SYMMETRIC_SIZE=10737418240
...
@@ -13,10 +13,10 @@ export NVSHMEM_SYMMETRIC_SIZE=10737418240
# common
# common
export
HIP_VISIBLE_DEVICES
=
0,1,2,3,4,5,6,7
export
HIP_VISIBLE_DEVICES
=
0,1,2,3,4,5,6,7
export
PYTHONPATH
=
$(
pwd
)
export
PYTHONPATH
=
$(
pwd
)
/../
# test
# test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234
tests
/test_internode.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234
.
/test_internode.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234
tests
/test_low_latency.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234
.
/test_low_latency.py
torchrun
--nproc-per-node
=
1
--nnodes
=
2
--node-rank
=
1
--master-addr
=
"10.16.1.37"
--master-port
=
1234
tests
/test_low_latency
_new
.py
--pressure-test
torchrun
--nproc-per-node
=
1
--nnodes
=
2
--node-rank
=
1
--master-addr
=
"10.16.1.37"
--master-port
=
1234
.
/test_low_latency.py
--pressure-test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234
tests
/test_internode.py --test-ll-compatibility
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234
.
/test_internode.py --test-ll-compatibility
tests/test_low_latency.py
View file @
61bc0aff
This diff is collapsed.
Click to expand it.
tests/test_low_latency_new.py
deleted
100644 → 0
View file @
e195b4fe
import
argparse
import
random
import
torch
import
torch.distributed
as
dist
from
functools
import
partial
from
typing
import
Literal
,
Set
import
deep_ep
from
utils
import
init_dist
,
bench
,
bench_kineto
,
calc_diff
,
hash_tensor
,
per_token_cast_pg_back
,
per_token_cast_pc_back
def
simulate_failure_and_skip
(
rank
:
int
,
api
:
Literal
[
"dispatch"
,
"combine"
,
"clean"
],
expected_masked_ranks
:
Set
[
int
]):
# Simulates rank failure when the rank first calls the corresponding communication API
failed_api_ranks
=
{
# API -> rank to fail (rank fails when it first calls the corresponding communication API)
'dispatch'
:
1
,
'combine'
:
3
,
'clean'
:
5
}
if
rank
in
expected_masked_ranks
:
# Rank already failed
return
True
if
api
in
failed_api_ranks
.
keys
():
expected_masked_ranks
.
add
(
failed_api_ranks
[
api
])
if
failed_api_ranks
[
api
]
==
rank
:
print
(
f
"Rank
{
rank
}
failed when first calling
{
api
}
communication API, exit..."
,
flush
=
True
)
return
True
return
False
def
query_mask_buffer_and_check
(
api
:
Literal
[
"dispatch"
,
"combine"
,
"clean"
],
buffer
:
deep_ep
.
Buffer
,
mask_status
:
torch
.
Tensor
,
expected_masked_ranks
:
Set
[
int
]):
buffer
.
low_latency_query_mask_buffer
(
mask_status
)
assert
set
(
mask_status
.
nonzero
().
squeeze
(
-
1
).
tolist
())
==
expected_masked_ranks
def
test_main
(
num_tokens
:
int
,
hidden
:
int
,
num_experts
:
int
,
num_topk
:
int
,
rank
:
int
,
num_ranks
:
int
,
group
:
dist
.
ProcessGroup
,
buffer
:
deep_ep
.
Buffer
,
use_logfmt
:
bool
=
False
,
seed
:
int
=
0
):
torch
.
manual_seed
(
seed
+
rank
)
random
.
seed
(
seed
+
rank
)
assert
num_experts
%
num_ranks
==
0
num_local_experts
=
num_experts
//
num_ranks
# NOTES: the integers greater than 256 exceed the BF16 precision limit
rank_offset
=
128
assert
num_ranks
-
rank_offset
<
257
,
'Too many ranks (exceeding test precision limit)'
x
=
torch
.
ones
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
*
(
rank
-
rank_offset
)
x
[:,
-
128
:]
=
torch
.
arange
(
num_tokens
,
device
=
'cuda'
).
to
(
torch
.
bfloat16
).
view
(
-
1
,
1
)
x_list
=
[
x
]
for
_
in
range
(
4
if
use_logfmt
else
0
):
# NOTES: make more LogFMT casts and also with some BF16
x_list
.
append
(
torch
.
randn
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
*
0.5
*
random
.
random
())
# NOTES: the last one is for performance testing
# Most of the values in the perf case is lower than the threshold, casting most channels
x_list
.
append
(
torch
.
randn
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
*
0.1
)
scores
=
torch
.
randn
((
num_tokens
,
num_experts
),
dtype
=
torch
.
float32
,
device
=
'cuda'
).
abs
()
+
1
topk_idx
=
torch
.
topk
(
scores
,
num_topk
,
dim
=-
1
,
largest
=
True
,
sorted
=
True
)[
1
]
topk_weights
=
torch
.
randn
((
num_tokens
,
num_topk
),
dtype
=
torch
.
float32
,
device
=
'cuda'
).
abs
()
# Randomly mask some positions
for
_
in
range
(
10
):
topk_idx
[
random
.
randint
(
0
,
num_tokens
-
1
),
random
.
randint
(
0
,
num_topk
-
1
)]
=
-
1
all_topk_idx
=
torch
.
empty
((
num_ranks
,
num_tokens
,
num_topk
),
dtype
=
topk_idx
.
dtype
,
device
=
'cuda'
)
dist
.
all_gather_into_tensor
(
all_topk_idx
,
topk_idx
,
group
=
group
)
# For failure simulation and shrink testing
mask_status
=
torch
.
zeros
((
num_ranks
,),
dtype
=
torch
.
int
,
device
=
'cuda'
)
expected_masked_ranks
=
set
()
# Check dispatch correctness
do_check
=
True
hash_value
,
num_times
=
0
,
0
for
x_i
,
current_x
in
enumerate
(
x_list
):
for
return_recv_hook
in
(
False
,
True
):
for
quant_type
in
(
0
,
1
,
2
,
3
,
):
# 0: 不量化, 1: int8, 2: FP8_E4M3, 3: FP8_UE8M0 (仅支持round_scale=True), 4: FP8_E5M2
dispatch_use_quant
=
quant_type
>
0
for
fp8_round_scale
in
(
False
,
True
)
if
quant_type
!=
3
else
(
True
,
):
for
quant_group_size
in
(
0
,
128
,)
if
quant_type
>=
2
else
(
0
,
):
if
quant_type
==
3
and
(
fp8_round_scale
==
False
or
quant_group_size
==
0
):
continue
num_times
+=
1
for
_
in
range
((
num_times
%
2
)
+
1
):
packed_recv_x
,
packed_recv_count
,
handle
,
event
,
hook
=
\
buffer
.
low_latency_dispatch
(
current_x
,
topk_idx
,
num_tokens
,
num_experts
,
quant_type
=
quant_type
,
fp8_round_scale
=
fp8_round_scale
,
quant_group_size
=
quant_group_size
,
async_finish
=
not
return_recv_hook
,
return_recv_hook
=
return_recv_hook
)
hook
()
if
return_recv_hook
else
event
.
current_stream_wait
()
packed_recv_x
=
(
packed_recv_x
[
0
],
packed_recv_x
[
1
].
contiguous
())
if
dispatch_use_quant
else
packed_recv_x
if
not
dispatch_use_quant
:
simulated_gemm_x
=
packed_recv_x
.
clone
()
elif
quant_group_size
==
0
:
simulated_gemm_x
=
per_token_cast_pc_back
(
packed_recv_x
[
0
].
view
(
-
1
,
hidden
),
packed_recv_x
[
1
].
reshape
(
-
1
)).
view
(
packed_recv_x
[
0
].
shape
)
elif
quant_group_size
==
128
:
simulated_gemm_x
=
per_token_cast_pg_back
(
packed_recv_x
[
0
].
view
(
-
1
,
hidden
),
packed_recv_x
[
1
].
view
(
-
1
,
hidden
//
128
)).
view
(
packed_recv_x
[
0
].
shape
)
for
i
in
range
(
num_local_experts
if
do_check
else
0
):
expert_id
=
rank
*
num_local_experts
+
i
if
not
dispatch_use_quant
:
recv_x
=
packed_recv_x
[
i
]
elif
quant_group_size
==
0
:
recv_x
=
per_token_cast_pc_back
(
packed_recv_x
[
0
][
i
],
packed_recv_x
[
1
][
i
])
elif
quant_group_size
==
128
:
recv_x
=
per_token_cast_pg_back
(
packed_recv_x
[
0
][
i
],
packed_recv_x
[
1
][
i
])
recv_count
,
recv_src_info
,
recv_layout_range
=
packed_recv_count
[
i
],
handle
[
0
][
i
],
handle
[
1
][
i
]
# Check expert indices
int_mask
=
(
2
**
32
)
-
1
num_valid_tokens
=
recv_count
.
item
()
assert
num_valid_tokens
==
(
recv_layout_range
&
int_mask
).
sum
().
item
(),
f
'
{
num_valid_tokens
}
!=
{
recv_layout_range
&
int_mask
}
.sum().item()'
assert
num_valid_tokens
==
(
all_topk_idx
==
expert_id
).
sum
(
dim
=
[
1
,
2
])[
mask_status
==
0
].
sum
().
item
(
),
f
'
{
num_valid_tokens
}
!=
{
(
all_topk_idx
==
expert_id
).
sum
(
dim
=
[
1
,
2
])[
mask_status
==
0
].
sum
().
item
()
}
'
if
num_valid_tokens
==
0
:
continue
# Check received data
if
current_x
is
x
:
recv_x
=
recv_x
[:
num_valid_tokens
]
recv_x_amin
=
recv_x
[:,
:
-
128
].
amin
(
dim
=-
1
)
recv_x_amax
=
recv_x
[:,
:
-
128
].
amax
(
dim
=-
1
)
recv_src_info
=
recv_src_info
[:
num_valid_tokens
]
assert
torch
.
equal
(
recv_x_amin
,
recv_x_amax
)
if
dispatch_use_quant
:
assert
calc_diff
(
recv_x
[:,
-
1
],
recv_src_info
.
view
(
-
1
))
<
0.007
assert
torch
.
equal
(
recv_x_amin
,
recv_x
[:,
:
-
128
].
amax
(
dim
=-
1
))
if
quant_group_size
!=
0
:
if
fp8_round_scale
:
assert
calc_diff
(
recv_x
[:,
-
1
],
recv_src_info
.
view
(
-
1
))
<
0.007
else
:
assert
(
recv_x
[:,
-
128
:]
-
recv_src_info
.
view
(
-
1
,
1
)
%
num_tokens
).
sum
().
item
()
==
0
for
j
in
range
(
num_ranks
):
begin_idx
,
count
=
(
recv_layout_range
[
j
]
>>
32
).
item
(),
(
recv_layout_range
[
j
]
&
int_mask
).
item
()
if
not
fp8_round_scale
:
assert
(
recv_x_amin
==
j
-
rank_offset
).
sum
().
item
()
==
(
all_topk_idx
[
j
]
==
expert_id
).
sum
().
item
()
assert
(
recv_x
[
begin_idx
:
begin_idx
+
count
,
:
-
128
]
-
j
+
rank_offset
).
sum
().
item
()
==
0
if
dispatch_use_quant
:
hash_value
^=
hash_tensor
(
packed_recv_x
[
0
][
i
,
:
num_valid_tokens
])
hash_value
^=
hash_tensor
(
packed_recv_x
[
1
][
i
,
:
num_valid_tokens
])
else
:
hash_value
^=
hash_tensor
(
packed_recv_x
[
i
,
:
num_valid_tokens
])
# Check combine correctness
for
zero_copy
in
(
False
,
)
if
use_logfmt
else
(
False
,
True
,
):
if
zero_copy
:
buffer
.
get_next_low_latency_combine_buffer
(
handle
)[:,
:,
:]
=
simulated_gemm_x
out
=
torch
.
empty
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
combined_x
,
event
,
hook
=
buffer
.
low_latency_combine
(
simulated_gemm_x
,
topk_idx
,
topk_weights
,
handle
,
use_logfmt
=
use_logfmt
,
async_finish
=
not
return_recv_hook
,
zero_copy
=
zero_copy
,
return_recv_hook
=
return_recv_hook
,
out
=
out
)
hook
()
if
return_recv_hook
else
event
.
current_stream_wait
()
if
do_check
:
diff
=
calc_diff
(
current_x
*
topk_weights
.
masked_fill
(
topk_idx
==
-
1
,
0
).
sum
(
dim
=
1
).
view
(
-
1
,
1
),
combined_x
)
assert
torch
.
isnan
(
combined_x
).
sum
().
item
()
==
0
# if not fp8_round_scale:
assert
diff
<
(
9e-4
if
dispatch_use_quant
else
1e-5
),
f
'Error: diff=
{
diff
}
, dispatch_use_quant=
{
dispatch_use_quant
}
, zero_copy=
{
zero_copy
}
'
hash_value
^=
hash_tensor
(
combined_x
)
if
rank
==
0
:
print
(
f
"data:
{
x_i
}
, return_recv_hook:
{
return_recv_hook
}
, quant_type:
{
quant_type
}
, "
,
f
"fp8_round_scale:
{
fp8_round_scale
}
, quant_group_size:
{
quant_group_size
}
pass"
)
# noinspection PyShadowingNames
def
large_gemm_with_hook
(
hook
):
mat_0
=
torch
.
randn
((
8192
,
8192
),
dtype
=
torch
.
float
)
mat_1
=
torch
.
randn
((
8192
,
8192
),
dtype
=
torch
.
float
)
mat_0
@
mat_1
hook
()
# noinspection PyShadowingNames
def
test_func
(
return_recv_hook
:
bool
):
recv_x
,
recv_count
,
handle
,
event
,
hook
=
\
buffer
.
low_latency_dispatch
(
current_x
,
topk_idx
,
num_tokens
,
num_experts
,
quant_type
=
2
,
quant_group_size
=
0
,
async_finish
=
False
,
return_recv_hook
=
return_recv_hook
)
large_gemm_with_hook
(
hook
)
if
return_recv_hook
else
None
combined_x
,
event
,
hook
=
buffer
.
low_latency_combine
(
simulated_gemm_x
,
topk_idx
,
topk_weights
,
handle
,
use_logfmt
=
use_logfmt
,
return_recv_hook
=
return_recv_hook
)
large_gemm_with_hook
(
hook
)
if
return_recv_hook
else
None
# Calculate bandwidth
num_fp8_bytes
,
num_bf16_bytes
=
(
hidden
+
hidden
/
128
*
4
+
16
),
hidden
*
2
num_logfmt10_bytes
=
hidden
*
10
/
8
+
hidden
/
128
*
4
num_dispatch_comm_bytes
,
num_combine_comm_bytes
=
0
,
0
for
i
in
range
(
num_tokens
):
num_selections
=
(
topk_idx
[
i
]
!=
-
1
).
sum
().
item
()
num_dispatch_comm_bytes
+=
num_fp8_bytes
*
num_selections
num_combine_comm_bytes
+=
num_bf16_bytes
*
num_selections
# Dispatch + combine testing
avg_t
,
min_t
,
max_t
=
bench
(
partial
(
test_func
,
return_recv_hook
=
False
))
print
(
f
'[rank
{
rank
}
] Dispatch + combine bandwidth:
{
(
num_dispatch_comm_bytes
+
num_combine_comm_bytes
)
/
1e9
/
avg_t
:.
2
f
}
GB/s, '
f
'avg_t=
{
avg_t
*
1e6
:.
2
f
}
us, min_t=
{
min_t
*
1e6
:.
2
f
}
us, max_t=
{
max_t
*
1e6
:.
2
f
}
us'
,
flush
=
True
)
# Separate profiling
for
return_recv_hook
in
(
False
,
True
):
group
.
barrier
()
dispatch_t
,
combine_t
=
bench_kineto
(
partial
(
test_func
,
return_recv_hook
=
return_recv_hook
),
kernel_names
=
(
'dispatch'
,
'combine'
),
barrier_comm_profiling
=
True
,
suppress_kineto_output
=
True
,
num_kernels_per_period
=
2
if
return_recv_hook
else
1
)
if
not
return_recv_hook
:
print
(
f
'[rank
{
rank
}
] Dispatch bandwidth:
{
num_dispatch_comm_bytes
/
1e9
/
dispatch_t
:.
2
f
}
GB/s, avg_t=
{
dispatch_t
*
1e6
:.
2
f
}
us | '
f
'Combine bandwidth:
{
num_combine_comm_bytes
/
1e9
/
combine_t
:.
2
f
}
GB/s, avg_t=
{
combine_t
*
1e6
:.
2
f
}
us'
,
flush
=
True
)
else
:
print
(
f
'[rank
{
rank
}
] Dispatch send/recv time:
{
dispatch_t
[
0
]
*
1e6
:.
2
f
}
+
{
dispatch_t
[
1
]
*
1e6
:.
2
f
}
us | '
f
'Combine send/recv time:
{
combine_t
[
0
]
*
1e6
:.
2
f
}
+
{
combine_t
[
1
]
*
1e6
:.
2
f
}
us'
,
flush
=
True
)
return
hash_value
# noinspection PyUnboundLocalVariable,PyShadowingNames
def
test_loop
(
local_rank
:
int
,
num_local_ranks
:
int
,
args
:
argparse
.
Namespace
):
rank
,
num_ranks
,
group
=
init_dist
(
local_rank
,
num_local_ranks
)
num_tokens
,
hidden
=
args
.
num_tokens
,
args
.
hidden
num_topk
,
num_experts
=
args
.
num_topk
,
args
.
num_experts
num_rdma_bytes
=
deep_ep
.
Buffer
.
get_low_latency_rdma_size_hint
(
num_tokens
,
hidden
,
num_ranks
,
num_experts
)
if
local_rank
==
0
:
print
(
f
'Allocating buffer size:
{
num_rdma_bytes
/
1e6
}
MB ...'
,
flush
=
True
)
buffer
=
deep_ep
.
Buffer
(
group
,
num_rdma_bytes
=
num_rdma_bytes
,
low_latency_mode
=
True
,
num_qps_per_rank
=
num_experts
//
num_ranks
,
allow_nvlink_for_low_latency_mode
=
not
args
.
disable_nvlink
,
explicitly_destroy
=
True
,
allow_mnnvl
=
args
.
allow_mnnvl
)
test_main
(
num_tokens
,
hidden
,
num_experts
,
num_topk
,
rank
,
num_ranks
,
group
,
buffer
,
use_logfmt
=
args
.
use_logfmt
,
seed
=
1
)
do_pressure_test
=
args
.
pressure_test
for
seed
in
range
(
int
(
1e9
)
if
do_pressure_test
else
0
):
if
local_rank
==
0
:
print
(
f
'Testing with seed
{
seed
}
...'
,
flush
=
True
)
ref_hash
=
test_main
(
num_tokens
,
hidden
,
num_experts
,
num_topk
,
rank
,
num_ranks
,
group
,
buffer
,
use_logfmt
=
args
.
use_logfmt
,
seed
=
seed
)
for
_
in
range
(
20
):
assert
test_main
(
num_tokens
,
hidden
,
num_experts
,
num_topk
,
rank
,
num_ranks
,
group
,
buffer
,
use_logfmt
=
args
.
use_logfmt
,
seed
=
seed
)
==
ref_hash
,
f
'Error: seed=
{
seed
}
'
# Destroy the buffer runtime and communication group
buffer
.
destroy
()
dist
.
barrier
()
dist
.
destroy_process_group
()
if
__name__
==
'__main__'
:
# TODO: you may modify NUMA binding for less CPU overhead
# TODO: buggy with `num_tokens=512`
parser
=
argparse
.
ArgumentParser
(
description
=
'Test low-latency EP kernels'
)
parser
.
add_argument
(
'--num-processes'
,
type
=
int
,
default
=
8
,
help
=
'Number of processes to spawn (default: 8)'
)
parser
.
add_argument
(
'--num-tokens'
,
type
=
int
,
default
=
128
,
help
=
'Number of tokens (default: 128)'
)
parser
.
add_argument
(
'--hidden'
,
type
=
int
,
default
=
7168
,
help
=
'Hidden dimension size (default: 7168)'
)
parser
.
add_argument
(
'--num-topk'
,
type
=
int
,
default
=
8
,
help
=
'Number of top-k experts (default: 8)'
)
parser
.
add_argument
(
'--num-experts'
,
type
=
int
,
default
=
288
,
help
=
'Number of experts (default: 288)'
)
parser
.
add_argument
(
'--allow-mnnvl'
,
action
=
"store_true"
,
help
=
'Allow MNNVL for communication'
)
parser
.
add_argument
(
'--disable-nvlink'
,
action
=
'store_true'
,
help
=
'Whether to disable NVLink for testing'
)
parser
.
add_argument
(
"--pressure-test"
,
action
=
'store_true'
,
help
=
'Whether to do pressure test'
)
parser
.
add_argument
(
"--shrink-test"
,
action
=
'store_true'
,
help
=
'Whether to simulate failure and test shrink mode'
)
parser
.
add_argument
(
'--use-logfmt'
,
action
=
'store_true'
,
help
=
'Whether to test LogFMT combine'
)
args
=
parser
.
parse_args
()
num_processes
=
args
.
num_processes
torch
.
multiprocessing
.
spawn
(
test_loop
,
args
=
(
num_processes
,
args
),
nprocs
=
num_processes
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment