Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
61bc0aff
Commit
61bc0aff
authored
Mar 04, 2026
by
lishen
Browse files
修改torchrun启动的测试
parent
e195b4fe
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
248 additions
and
422 deletions
+248
-422
tests/1.sh
tests/1.sh
+6
-6
tests/2.sh
tests/2.sh
+6
-6
tests/test_low_latency.py
tests/test_low_latency.py
+236
-95
tests/test_low_latency_new.py
tests/test_low_latency_new.py
+0
-315
No files found.
1.sh
→
tests/
1.sh
View file @
61bc0aff
...
@@ -4,7 +4,7 @@ export ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX=288
...
@@ -4,7 +4,7 @@ export ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX=288
export
ROCSHMEM_MAX_NUM_CONTEXTS
=
48
export
ROCSHMEM_MAX_NUM_CONTEXTS
=
48
export
ROCSHMEM_ALLOWED_IBV_DEVICES
=
mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export
ROCSHMEM_ALLOWED_IBV_DEVICES
=
mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export
ROCSHMEM_HEAP_SIZE
=
10737418240
export
ROCSHMEM_HEAP_SIZE
=
10737418240
export
ROCSHMEM_TOPO_FILE_FORCE
=
tests
/topo.config
export
ROCSHMEM_TOPO_FILE_FORCE
=
.
/topo.config
# duSHMEM
# duSHMEM
export
LD_LIBRARY_PATH
=
/opt/dtk/dushmem/lib:
$LD_LIBRARY_PATH
export
LD_LIBRARY_PATH
=
/opt/dtk/dushmem/lib:
$LD_LIBRARY_PATH
...
@@ -13,10 +13,10 @@ export NVSHMEM_SYMMETRIC_SIZE=10737418240
...
@@ -13,10 +13,10 @@ export NVSHMEM_SYMMETRIC_SIZE=10737418240
# common
# common
export
HIP_VISIBLE_DEVICES
=
0,1,2,3,4,5,6,7
export
HIP_VISIBLE_DEVICES
=
0,1,2,3,4,5,6,7
export
PYTHONPATH
=
$(
pwd
)
export
PYTHONPATH
=
$(
pwd
)
/../
# test
# test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234
tests
/test_internode.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234
.
/test_internode.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234
tests
/test_low_latency.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234
.
/test_low_latency.py
torchrun
--nproc-per-node
=
1
--nnodes
=
2
--node-rank
=
0
--master-addr
=
"10.16.1.37"
--master-port
=
1234
tests
/test_low_latency
_new
.py
--pressure-test
torchrun
--nproc-per-node
=
1
--nnodes
=
2
--node-rank
=
0
--master-addr
=
"10.16.1.37"
--master-port
=
1234
.
/test_low_latency.py
--pressure-test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234
tests
/test_internode.py --test-ll-compatibility
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234
.
/test_internode.py --test-ll-compatibility
2.sh
→
tests/
2.sh
View file @
61bc0aff
...
@@ -4,7 +4,7 @@ export ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX=288
...
@@ -4,7 +4,7 @@ export ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX=288
export
ROCSHMEM_MAX_NUM_CONTEXTS
=
48
export
ROCSHMEM_MAX_NUM_CONTEXTS
=
48
export
ROCSHMEM_ALLOWED_IBV_DEVICES
=
mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export
ROCSHMEM_ALLOWED_IBV_DEVICES
=
mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export
ROCSHMEM_HEAP_SIZE
=
10737418240
export
ROCSHMEM_HEAP_SIZE
=
10737418240
export
ROCSHMEM_TOPO_FILE_FORCE
=
tests
/topo.config
export
ROCSHMEM_TOPO_FILE_FORCE
=
.
/topo.config
# duSHMEM
# duSHMEM
export
LD_LIBRARY_PATH
=
/opt/dtk/dushmem/lib:
$LD_LIBRARY_PATH
export
LD_LIBRARY_PATH
=
/opt/dtk/dushmem/lib:
$LD_LIBRARY_PATH
...
@@ -13,10 +13,10 @@ export NVSHMEM_SYMMETRIC_SIZE=10737418240
...
@@ -13,10 +13,10 @@ export NVSHMEM_SYMMETRIC_SIZE=10737418240
# common
# common
export
HIP_VISIBLE_DEVICES
=
0,1,2,3,4,5,6,7
export
HIP_VISIBLE_DEVICES
=
0,1,2,3,4,5,6,7
export
PYTHONPATH
=
$(
pwd
)
export
PYTHONPATH
=
$(
pwd
)
/../
# test
# test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234
tests
/test_internode.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234
.
/test_internode.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234
tests
/test_low_latency.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234
.
/test_low_latency.py
torchrun
--nproc-per-node
=
1
--nnodes
=
2
--node-rank
=
1
--master-addr
=
"10.16.1.37"
--master-port
=
1234
tests
/test_low_latency
_new
.py
--pressure-test
torchrun
--nproc-per-node
=
1
--nnodes
=
2
--node-rank
=
1
--master-addr
=
"10.16.1.37"
--master-port
=
1234
.
/test_low_latency.py
--pressure-test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234
tests
/test_internode.py --test-ll-compatibility
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234
.
/test_internode.py --test-ll-compatibility
tests/test_low_latency.py
View file @
61bc0aff
import
argparse
import
random
import
random
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
functools
import
partial
from
functools
import
partial
from
typing
import
Literal
,
Set
import
deep_ep
import
deep_ep
from
utils
import
init_dist
,
bench
,
bench_kineto
,
calc_diff
,
hash_tensor
,
per_token_cast_pg_back
from
utils
import
init_dist
,
bench
,
bench_kineto
,
calc_diff
,
hash_tensor
,
per_token_cast_pg_back
,
per_token_cast_pc_back
def
test_main
(
num_tokens
:
int
,
hidden
:
int
,
num_experts
:
int
,
num_topk
:
int
,
def
simulate_failure_and_skip
(
rank
:
int
,
api
:
Literal
[
"dispatch"
,
"combine"
,
"clean"
],
expected_masked_ranks
:
Set
[
int
]):
rank
:
int
,
num_ranks
:
int
,
group
:
dist
.
ProcessGroup
,
buffer
:
deep_ep
.
Buffer
,
seed
:
int
=
0
):
# Simulates rank failure when the rank first calls the corresponding communication API
failed_api_ranks
=
{
# API -> rank to fail (rank fails when it first calls the corresponding communication API)
'dispatch'
:
1
,
'combine'
:
3
,
'clean'
:
5
}
if
rank
in
expected_masked_ranks
:
# Rank already failed
return
True
if
api
in
failed_api_ranks
.
keys
():
expected_masked_ranks
.
add
(
failed_api_ranks
[
api
])
if
failed_api_ranks
[
api
]
==
rank
:
print
(
f
"Rank
{
rank
}
failed when first calling
{
api
}
communication API, exit..."
,
flush
=
True
)
return
True
return
False
def
query_mask_buffer_and_check
(
api
:
Literal
[
"dispatch"
,
"combine"
,
"clean"
],
buffer
:
deep_ep
.
Buffer
,
mask_status
:
torch
.
Tensor
,
expected_masked_ranks
:
Set
[
int
]):
buffer
.
low_latency_query_mask_buffer
(
mask_status
)
assert
set
(
mask_status
.
nonzero
().
squeeze
(
-
1
).
tolist
())
==
expected_masked_ranks
def
test_main
(
num_tokens
:
int
,
hidden
:
int
,
num_experts
:
int
,
num_topk
:
int
,
rank
:
int
,
num_ranks
:
int
,
group
:
dist
.
ProcessGroup
,
buffer
:
deep_ep
.
Buffer
,
use_logfmt
:
bool
=
False
,
seed
:
int
=
0
):
torch
.
manual_seed
(
seed
+
rank
)
torch
.
manual_seed
(
seed
+
rank
)
random
.
seed
(
seed
+
rank
)
random
.
seed
(
seed
+
rank
)
assert
num_experts
%
num_ranks
==
0
assert
num_experts
%
num_ranks
==
0
num_local_experts
=
num_experts
//
num_ranks
num_local_experts
=
num_experts
//
num_ranks
# NOTES: the integers greater than 256 exceed
s
the BF16 precision limit
# NOTES: the integers greater than 256 exceed the BF16 precision limit
rank_offset
=
128
rank_offset
=
128
assert
num_ranks
-
rank_offset
<
257
,
'Too many ranks (exceeding test precision limit)'
assert
num_ranks
-
rank_offset
<
257
,
'Too many ranks (exceeding test precision limit)'
x
=
torch
.
ones
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
*
(
rank
-
rank_offset
)
x
=
torch
.
ones
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
*
(
rank
-
rank_offset
)
x
[:,
-
128
:]
=
torch
.
arange
(
num_tokens
,
device
=
'cuda'
).
to
(
torch
.
bfloat16
).
view
(
-
1
,
1
)
x
[:,
-
128
:]
=
torch
.
arange
(
num_tokens
,
device
=
'cuda'
).
to
(
torch
.
bfloat16
).
view
(
-
1
,
1
)
x_list
=
[
x
]
for
_
in
range
(
4
if
use_logfmt
else
0
):
# NOTES: make more LogFMT casts and also with some BF16
x_list
.
append
(
torch
.
randn
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
*
0.5
*
random
.
random
())
# NOTES: the last one is for performance testing
# Most of the values in the perf case is lower than the threshold, casting most channels
x_list
.
append
(
torch
.
randn
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
*
0.1
)
scores
=
torch
.
randn
((
num_tokens
,
num_experts
),
dtype
=
torch
.
float32
,
device
=
'cuda'
).
abs
()
+
1
scores
=
torch
.
randn
((
num_tokens
,
num_experts
),
dtype
=
torch
.
float32
,
device
=
'cuda'
).
abs
()
+
1
topk_idx
=
torch
.
topk
(
scores
,
num_topk
,
dim
=-
1
,
largest
=
True
,
sorted
=
True
)[
1
]
topk_idx
=
torch
.
topk
(
scores
,
num_topk
,
dim
=-
1
,
largest
=
True
,
sorted
=
True
)[
1
]
topk_weights
=
torch
.
randn
((
num_tokens
,
num_topk
),
dtype
=
torch
.
float32
,
device
=
'cuda'
).
abs
()
topk_weights
=
torch
.
randn
((
num_tokens
,
num_topk
),
dtype
=
torch
.
float32
,
device
=
'cuda'
).
abs
()
# Randomly mask some positions
# Randomly mask some positions
for
i
in
range
(
10
):
for
_
in
range
(
10
):
topk_idx
[
random
.
randint
(
0
,
num_tokens
-
1
),
random
.
randint
(
0
,
num_topk
-
1
)]
=
-
1
topk_idx
[
random
.
randint
(
0
,
num_tokens
-
1
),
random
.
randint
(
0
,
num_topk
-
1
)]
=
-
1
all_topk_idx
=
torch
.
empty
((
num_ranks
,
num_tokens
,
num_topk
),
dtype
=
topk_idx
.
dtype
,
device
=
'cuda'
)
dist
.
all_gather_into_tensor
(
all_topk_idx
,
topk_idx
,
group
=
group
)
# For failure simulation and shrink testing
mask_status
=
torch
.
zeros
((
num_ranks
,),
dtype
=
torch
.
int
,
device
=
'cuda'
)
expected_masked_ranks
=
set
()
# Check dispatch correctness
# Check dispatch correctness
do_check
=
True
do_check
=
True
hash_value
,
num_times
=
0
,
0
hash_value
,
num_times
=
0
,
0
for
x_i
,
current_x
in
enumerate
(
x_list
):
for
return_recv_hook
in
(
False
,
True
):
for
quant_type
in
(
0
,
1
,
2
,
3
,
):
# 0: 不量化, 1: int8, 2: FP8_E4M3, 3: FP8_UE8M0 (仅支持round_scale=True), 4: FP8_E5M2
dispatch_use_quant
=
quant_type
>
0
for
fp8_round_scale
in
(
False
,
True
)
if
quant_type
!=
3
else
(
True
,
):
for
quant_group_size
in
(
0
,
128
,)
if
quant_type
>=
2
else
(
0
,
):
if
quant_type
==
3
and
(
fp8_round_scale
==
False
or
quant_group_size
==
0
):
continue
for
return_recv_hook
in
(
False
,
True
):
num_times
+=
1
for
dispatch_use_fp8
in
(
False
,
True
):
for
_
in
range
((
num_times
%
2
)
+
1
):
num_times
+=
1
packed_recv_x
,
packed_recv_count
,
handle
,
event
,
hook
=
\
for
i
in
range
((
num_times
%
2
)
+
1
):
buffer
.
low_latency_dispatch
(
current_x
,
topk_idx
,
num_tokens
,
num_experts
,
packed_recv_x
,
packed_recv_count
,
handle
,
event
,
hook
=
\
quant_type
=
quant_type
,
fp8_round_scale
=
fp8_round_scale
,
quant_group_size
=
quant_group_size
,
buffer
.
low_latency_dispatch
(
x
,
topk_idx
,
num_tokens
,
num_experts
,
use_fp8
=
dispatch_use_fp8
,
async_finish
=
not
return_recv_hook
,
return_recv_hook
=
return_recv_hook
)
async_finish
=
not
return_recv_hook
,
return_recv_hook
=
return_recv_hook
)
hook
()
if
return_recv_hook
else
event
.
current_stream_wait
()
hook
()
if
return_recv_hook
else
event
.
current_stream_wait
()
packed_recv_x
=
(
packed_recv_x
[
0
],
packed_recv_x
[
1
].
contiguous
())
if
dispatch_use_quant
else
packed_recv_x
if
not
dispatch_use_quant
:
# print('run {}/{}, dispatch_use_fp8={}'.format(i + 1, num_times, dispatch_use_fp8))
simulated_gemm_x
=
packed_recv_x
.
clone
()
# return
elif
quant_group_size
==
0
:
packed_recv_x
=
(
packed_recv_x
[
0
],
packed_recv_x
[
1
].
contiguous
())
if
dispatch_use_fp8
else
packed_recv_x
simulated_gemm_x
=
per_token_cast_pc_back
(
packed_recv_x
[
0
].
view
(
-
1
,
hidden
),
packed_recv_x
[
1
].
reshape
(
-
1
)).
view
(
packed_recv_x
[
0
].
shape
)
simulated_gemm_x
=
per_token_cast_pg_back
(
packed_recv_x
[
0
].
view
(
-
1
,
hidden
),
packed_recv_x
[
1
].
view
(
-
1
,
hidden
//
128
)).
view
(
packed_recv_x
[
0
].
shape
)
\
elif
quant_group_size
==
128
:
if
dispatch_use_fp8
else
packed_recv_x
.
clone
()
simulated_gemm_x
=
per_token_cast_pg_back
(
packed_recv_x
[
0
].
view
(
-
1
,
hidden
),
packed_recv_x
[
1
].
view
(
-
1
,
hidden
//
128
)).
view
(
packed_recv_x
[
0
].
shape
)
# print(f"rank{rank}: packed_recv_x[0]\n{packed_recv_x[0].cpu()}\n")
for
i
in
range
(
num_local_experts
if
do_check
else
0
):
# print(f"rank{rank}: packed_recv_x[1]\n{packed_recv_x[1].cpu()}\n")
expert_id
=
rank
*
num_local_experts
+
i
# print(f"simulated_gemm_x{simulated_gemm_x.cpu()}")
if
not
dispatch_use_quant
:
all_topk_idx
=
torch
.
empty
((
num_ranks
,
num_tokens
,
num_topk
),
dtype
=
topk_idx
.
dtype
,
device
=
'cuda'
)
recv_x
=
packed_recv_x
[
i
]
dist
.
all_gather_into_tensor
(
all_topk_idx
,
topk_idx
,
group
=
group
)
elif
quant_group_size
==
0
:
for
i
in
range
(
num_local_experts
if
do_check
else
0
):
recv_x
=
per_token_cast_pc_back
(
packed_recv_x
[
0
][
i
],
packed_recv_x
[
1
][
i
])
expert_id
=
rank
*
num_local_experts
+
i
elif
quant_group_size
==
128
:
recv_x
=
per_token_cast_pg_back
(
packed_recv_x
[
0
][
i
],
packed_recv_x
[
1
][
i
])
if
dispatch_use_fp8
else
packed_recv_x
[
i
]
recv_x
=
per_token_cast_pg_back
(
packed_recv_x
[
0
][
i
],
packed_recv_x
[
1
][
i
])
recv_count
,
recv_src_info
,
recv_layout_range
=
packed_recv_count
[
i
],
handle
[
0
][
i
],
handle
[
1
][
i
]
recv_count
,
recv_src_info
,
recv_layout_range
=
packed_recv_count
[
i
],
handle
[
0
][
i
],
handle
[
1
][
i
]
# Check expert indices
# Check expert indices
int_mask
=
(
2
**
32
)
-
1
int_mask
=
(
2
**
32
)
-
1
num_valid_tokens
=
recv_count
.
item
()
num_valid_tokens
=
recv_count
.
item
()
assert
num_valid_tokens
==
(
recv_layout_range
&
int_mask
).
sum
().
item
(),
f
'
{
num_valid_tokens
}
!=
{
recv_layout_range
&
int_mask
}
.sum().item()'
assert
num_valid_tokens
==
(
assert
num_valid_tokens
==
(
all_topk_idx
==
expert_id
).
sum
().
item
(),
f
'
{
num_valid_tokens
}
!=
{
(
all_topk_idx
==
expert_id
).
sum
().
item
()
}
'
recv_layout_range
&
int_mask
).
sum
().
item
(),
f
'
{
num_valid_tokens
}
!=
{
recv_layout_range
&
int_mask
}
.sum().item()'
# Check received data
assert
num_valid_tokens
==
(
all_topk_idx
==
expert_id
).
sum
(
dim
=
[
1
,
2
])[
mask_status
==
0
].
sum
().
item
(
recv_x
=
recv_x
[:
num_valid_tokens
]
),
f
'
{
num_valid_tokens
}
!=
{
(
all_topk_idx
==
expert_id
).
sum
(
dim
=
[
1
,
2
])[
mask_status
==
0
].
sum
().
item
()
}
'
recv_x_amin
=
recv_x
[:,
:
-
128
].
amin
(
dim
=-
1
)
recv_src_info
=
recv_src_info
[:
num_valid_tokens
]
if
num_valid_tokens
==
0
:
assert
torch
.
equal
(
recv_x_amin
,
recv_x
[:,
:
-
128
].
amax
(
dim
=-
1
))
continue
assert
(
recv_x
[:,
-
128
:]
-
recv_src_info
.
view
(
-
1
,
1
)
%
num_tokens
).
sum
().
item
()
==
0
# Check received data
for
j
in
range
(
num_ranks
):
if
current_x
is
x
:
begin_idx
,
count
=
(
recv_layout_range
[
j
]
>>
32
).
item
(),
(
recv_layout_range
[
j
]
&
int_mask
).
item
()
recv_x
=
recv_x
[:
num_valid_tokens
]
assert
(
recv_x_amin
==
j
-
rank_offset
).
sum
().
item
()
==
(
all_topk_idx
[
j
]
==
expert_id
).
sum
().
item
()
recv_x_amin
=
recv_x
[:,
:
-
128
].
amin
(
dim
=-
1
)
assert
(
recv_x
[
begin_idx
:
begin_idx
+
count
][:
-
128
]
-
j
).
sum
().
item
()
==
0
recv_x_amax
=
recv_x
[:,
:
-
128
].
amax
(
dim
=-
1
)
if
dispatch_use_fp8
:
recv_src_info
=
recv_src_info
[:
num_valid_tokens
]
hash_value
^=
hash_tensor
(
packed_recv_x
[
0
][
i
,
:
num_valid_tokens
])
assert
torch
.
equal
(
recv_x_amin
,
recv_x_amax
)
hash_value
^=
hash_tensor
(
packed_recv_x
[
1
][
i
,
:
num_valid_tokens
])
else
:
if
dispatch_use_quant
:
hash_value
^=
hash_tensor
(
packed_recv_x
[
i
,
:
num_valid_tokens
])
assert
calc_diff
(
recv_x
[:,
-
1
],
recv_src_info
.
view
(
-
1
))
<
0.007
# Check combine correctness
assert
torch
.
equal
(
recv_x_amin
,
recv_x
[:,
:
-
128
].
amax
(
dim
=-
1
))
for
zero_copy
in
(
False
,
True
):
if
quant_group_size
!=
0
:
if
zero_copy
:
if
fp8_round_scale
:
buffer
.
get_next_low_latency_combine_buffer
(
handle
)[:,
:,
:]
=
simulated_gemm_x
assert
calc_diff
(
recv_x
[:,
-
1
],
recv_src_info
.
view
(
-
1
))
<
0.007
out
=
torch
.
empty
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
else
:
combined_x
,
event
,
hook
=
buffer
.
low_latency_combine
(
simulated_gemm_x
,
topk_idx
,
topk_weights
,
handle
,
assert
(
recv_x
[:,
-
128
:]
-
recv_src_info
.
view
(
-
1
,
1
)
%
num_tokens
).
sum
().
item
()
==
0
async_finish
=
not
return_recv_hook
,
for
j
in
range
(
num_ranks
):
return_recv_hook
=
return_recv_hook
,
out
=
out
)
begin_idx
,
count
=
(
recv_layout_range
[
j
]
>>
32
).
item
(),
(
recv_layout_range
[
j
]
&
int_mask
).
item
()
hook
()
if
return_recv_hook
else
event
.
current_stream_wait
()
if
not
fp8_round_scale
:
if
do_check
:
assert
(
recv_x_amin
==
j
-
rank_offset
).
sum
().
item
()
==
(
all_topk_idx
[
j
]
==
expert_id
).
sum
().
item
()
diff
=
calc_diff
(
x
*
topk_weights
.
masked_fill
(
topk_idx
==
-
1
,
0
).
sum
(
dim
=
1
).
view
(
-
1
,
1
),
combined_x
)
assert
(
recv_x
[
begin_idx
:
begin_idx
+
count
,
:
-
128
]
-
j
+
rank_offset
).
sum
().
item
()
==
0
assert
torch
.
isnan
(
combined_x
).
sum
().
item
()
==
0
if
dispatch_use_quant
:
assert
diff
<
1e-5
,
f
'Error: diff=
{
diff
}
'
hash_value
^=
hash_tensor
(
packed_recv_x
[
0
][
i
,
:
num_valid_tokens
])
hash_value
^=
hash_tensor
(
combined_x
)
hash_value
^=
hash_tensor
(
packed_recv_x
[
1
][
i
,
:
num_valid_tokens
])
else
:
hash_value
^=
hash_tensor
(
packed_recv_x
[
i
,
:
num_valid_tokens
])
# Check combine correctness
for
zero_copy
in
(
False
,
)
if
use_logfmt
else
(
False
,
True
,
):
if
zero_copy
:
buffer
.
get_next_low_latency_combine_buffer
(
handle
)[:,
:,
:]
=
simulated_gemm_x
out
=
torch
.
empty
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
combined_x
,
event
,
hook
=
buffer
.
low_latency_combine
(
simulated_gemm_x
,
topk_idx
,
topk_weights
,
handle
,
use_logfmt
=
use_logfmt
,
async_finish
=
not
return_recv_hook
,
zero_copy
=
zero_copy
,
return_recv_hook
=
return_recv_hook
,
out
=
out
)
hook
()
if
return_recv_hook
else
event
.
current_stream_wait
()
if
do_check
:
diff
=
calc_diff
(
current_x
*
topk_weights
.
masked_fill
(
topk_idx
==
-
1
,
0
).
sum
(
dim
=
1
).
view
(
-
1
,
1
),
combined_x
)
assert
torch
.
isnan
(
combined_x
).
sum
().
item
()
==
0
# if not fp8_round_scale:
assert
diff
<
(
9e-4
if
dispatch_use_quant
else
1e-5
),
f
'Error: diff=
{
diff
}
, dispatch_use_quant=
{
dispatch_use_quant
}
, zero_copy=
{
zero_copy
}
'
hash_value
^=
hash_tensor
(
combined_x
)
if
rank
==
0
:
print
(
f
"data:
{
x_i
}
, return_recv_hook:
{
return_recv_hook
}
, quant_type:
{
quant_type
}
, "
,
f
"fp8_round_scale:
{
fp8_round_scale
}
, quant_group_size:
{
quant_group_size
}
pass"
)
# noinspection PyShadowingNames
# noinspection PyShadowingNames
def
large_gemm_with_hook
(
hook
):
def
large_gemm_with_hook
(
hook
):
...
@@ -101,19 +188,23 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
...
@@ -101,19 +188,23 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
hook
()
hook
()
# noinspection PyShadowingNames
# noinspection PyShadowingNames
def
test_func
(
zero_copy
:
bool
,
return_recv_hook
:
bool
):
def
test_func
(
return_recv_hook
:
bool
):
recv_x
,
recv_count
,
handle
,
event
,
hook
=
\
recv_x
,
recv_count
,
handle
,
event
,
hook
=
\
buffer
.
low_latency_dispatch
(
x
,
topk_idx
,
num_tokens
,
num_experts
,
buffer
.
low_latency_dispatch
(
current_x
,
topk_idx
,
num_tokens
,
num_experts
,
quant_type
=
2
,
quant_group_size
=
0
,
async_finish
=
False
,
return_recv_hook
=
return_recv_hook
)
async_finish
=
False
,
return_recv_hook
=
return_recv_hook
)
large_gemm_with_hook
(
hook
)
if
return_recv_hook
else
None
large_gemm_with_hook
(
hook
)
if
return_recv_hook
else
None
if
zero_copy
:
combined_x
,
event
,
hook
=
buffer
.
low_latency_combine
(
simulated_gemm_x
,
buffer
.
get_next_low_latency_combine_buffer
(
handle
)[:,
:,
:]
=
simulated_gemm_x
topk_idx
,
combined_x
,
event
,
hook
=
buffer
.
low_latency_combine
(
simulated_gemm_x
,
topk_idx
,
topk_weights
,
handle
,
topk_weights
,
zero_copy
=
zero_copy
,
return_recv_hook
=
return_recv_hook
)
handle
,
use_logfmt
=
use_logfmt
,
return_recv_hook
=
return_recv_hook
)
large_gemm_with_hook
(
hook
)
if
return_recv_hook
else
None
large_gemm_with_hook
(
hook
)
if
return_recv_hook
else
None
# Calculate bandwidth
# Calculate bandwidth
num_fp8_bytes
,
num_bf16_bytes
=
(
hidden
+
hidden
/
128
*
4
+
16
),
hidden
*
2
num_fp8_bytes
,
num_bf16_bytes
=
(
hidden
+
hidden
/
128
*
4
+
16
),
hidden
*
2
num_logfmt10_bytes
=
hidden
*
10
/
8
+
hidden
/
128
*
4
num_dispatch_comm_bytes
,
num_combine_comm_bytes
=
0
,
0
num_dispatch_comm_bytes
,
num_combine_comm_bytes
=
0
,
0
for
i
in
range
(
num_tokens
):
for
i
in
range
(
num_tokens
):
num_selections
=
(
topk_idx
[
i
]
!=
-
1
).
sum
().
item
()
num_selections
=
(
topk_idx
[
i
]
!=
-
1
).
sum
().
item
()
...
@@ -121,54 +212,104 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
...
@@ -121,54 +212,104 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
num_combine_comm_bytes
+=
num_bf16_bytes
*
num_selections
num_combine_comm_bytes
+=
num_bf16_bytes
*
num_selections
# Dispatch + combine testing
# Dispatch + combine testing
avg_t
,
min_t
,
max_t
=
bench
(
partial
(
test_func
,
zero_copy
=
False
,
return_recv_hook
=
False
))
avg_t
,
min_t
,
max_t
=
bench
(
partial
(
test_func
,
return_recv_hook
=
False
))
print
(
f
'[rank
{
rank
}
] Dispatch + combine bandwidth:
{
(
num_dispatch_comm_bytes
+
num_combine_comm_bytes
)
/
1e9
/
avg_t
:.
2
f
}
GB/s, '
print
(
f
'[rank
{
rank
}
] Dispatch + combine bandwidth:
{
(
num_dispatch_comm_bytes
+
num_combine_comm_bytes
)
/
1e9
/
avg_t
:.
2
f
}
GB/s, '
f
'avg_t=
{
avg_t
*
1e6
:.
2
f
}
us, min_t=
{
min_t
*
1e6
:.
2
f
}
us, max_t=
{
max_t
*
1e6
:.
2
f
}
us'
,
flush
=
True
)
f
'avg_t=
{
avg_t
*
1e6
:.
2
f
}
us, min_t=
{
min_t
*
1e6
:.
2
f
}
us, max_t=
{
max_t
*
1e6
:.
2
f
}
us'
,
flush
=
True
)
# Separate profiling
# Separate profiling
for
return_recv_hook
in
(
False
,
True
):
for
return_recv_hook
in
(
False
,
True
):
group
.
barrier
()
group
.
barrier
()
dispatch_t
,
combine_t
=
bench_kineto
(
partial
(
test_func
,
return_recv_hook
=
return_recv_hook
),
dispatch_t
,
combine_t
=
bench_kineto
(
partial
(
test_func
,
zero_copy
=
True
,
return_recv_hook
=
return_recv_hook
),
kernel_names
=
(
'dispatch'
,
'combine'
),
kernel_names
=
(
'dispatch'
,
'combine'
),
barrier_comm_profiling
=
True
,
barrier_comm_profiling
=
True
,
suppress_kineto_output
=
True
)
suppress_kineto_output
=
True
,
num_kernels_per_period
=
2
if
return_recv_hook
else
1
)
if
not
return_recv_hook
:
if
not
return_recv_hook
:
print
(
f
'[rank
{
rank
}
] Dispatch bandwidth:
{
num_dispatch_comm_bytes
/
1e9
/
dispatch_t
:.
2
f
}
GB/s, avg_t=
{
dispatch_t
*
1e6
:.
2
f
}
us | '
print
(
f
'[rank
{
rank
}
] Dispatch bandwidth:
{
num_dispatch_comm_bytes
/
1e9
/
dispatch_t
:.
2
f
}
GB/s, avg_t=
{
dispatch_t
*
1e6
:.
2
f
}
us | '
f
'Combine bandwidth:
{
num_combine_comm_bytes
/
1e9
/
combine_t
:.
2
f
}
GB/s, avg_t=
{
combine_t
*
1e6
:.
2
f
}
us'
)
f
'Combine bandwidth:
{
num_combine_comm_bytes
/
1e9
/
combine_t
:.
2
f
}
GB/s, avg_t=
{
combine_t
*
1e6
:.
2
f
}
us'
,
flush
=
True
)
else
:
else
:
print
(
f
'[rank
{
rank
}
] Dispatch send/recv time:
{
dispatch_t
*
2
*
1e6
:.
2
f
}
us | '
print
(
f
'[rank
{
rank
}
] Dispatch send/recv time:
{
dispatch_t
[
0
]
*
1e6
:.
2
f
}
+
{
dispatch_t
[
1
]
*
1e6
:.
2
f
}
us | '
f
'Combine send/recv time:
{
combine_t
*
2
*
1e6
:.
2
f
}
us'
)
f
'Combine send/recv time:
{
combine_t
[
0
]
*
1e6
:.
2
f
}
+
{
combine_t
[
1
]
*
1e6
:.
2
f
}
us'
,
flush
=
True
)
return
hash_value
return
hash_value
# noinspection PyUnboundLocalVariable
# noinspection PyUnboundLocalVariable
,PyShadowingNames
def
test_loop
(
local_rank
:
int
,
num_local_ranks
:
int
):
def
test_loop
(
local_rank
:
int
,
num_local_ranks
:
int
,
args
:
argparse
.
Namespace
):
rank
,
num_ranks
,
group
=
init_dist
(
local_rank
,
num_local_ranks
)
rank
,
num_ranks
,
group
=
init_dist
(
local_rank
,
num_local_ranks
)
# The default setting of deepEP upstream is below:
num_tokens
,
hidden
=
args
.
num_tokens
,
args
.
hidden
num_to
kens
,
hidden
,
num_topk
,
num_experts
=
128
,
7168
,
8
,
256
num_to
pk
,
num_experts
=
args
.
num_topk
,
args
.
num_experts
num_rdma_bytes
=
deep_ep
.
Buffer
.
get_low_latency_rdma_size_hint
(
num_tokens
,
hidden
,
num_ranks
,
num_experts
)
num_rdma_bytes
=
deep_ep
.
Buffer
.
get_low_latency_rdma_size_hint
(
num_tokens
,
hidden
,
num_ranks
,
num_experts
)
if
local_rank
==
0
:
if
local_rank
==
0
:
print
(
f
'Allocating buffer size:
{
num_rdma_bytes
/
1e6
}
MB ...'
,
flush
=
True
)
print
(
f
'Allocating buffer size:
{
num_rdma_bytes
/
1e6
}
MB ...'
,
flush
=
True
)
buffer
=
deep_ep
.
Buffer
(
group
,
num_rdma_bytes
=
num_rdma_bytes
,
low_latency_mode
=
True
,
buffer
=
deep_ep
.
Buffer
(
group
,
num_qps_per_rank
=
num_experts
//
num_ranks
,
explicitly_destroy
=
True
)
num_rdma_bytes
=
num_rdma_bytes
,
test_main
(
num_tokens
,
hidden
,
num_experts
,
num_topk
,
rank
,
num_ranks
,
group
,
buffer
,
seed
=
1
)
low_latency_mode
=
True
,
num_qps_per_rank
=
num_experts
//
num_ranks
,
allow_nvlink_for_low_latency_mode
=
not
args
.
disable_nvlink
,
explicitly_destroy
=
True
,
allow_mnnvl
=
args
.
allow_mnnvl
)
test_main
(
num_tokens
,
hidden
,
num_experts
,
num_topk
,
rank
,
num_ranks
,
group
,
buffer
,
use_logfmt
=
args
.
use_logfmt
,
seed
=
1
)
do_pressure_test
=
False
do_pressure_test
=
args
.
pressure_test
for
seed
in
range
(
int
(
1e9
)
if
do_pressure_test
else
0
):
for
seed
in
range
(
int
(
1e9
)
if
do_pressure_test
else
0
):
if
local_rank
==
0
:
if
local_rank
==
0
:
print
(
f
'Testing with seed
{
seed
}
...'
,
flush
=
True
)
print
(
f
'Testing with seed
{
seed
}
...'
,
flush
=
True
)
ref_hash
=
test_main
(
num_tokens
,
hidden
,
num_experts
,
num_topk
,
rank
,
num_ranks
,
group
,
buffer
,
seed
=
seed
)
ref_hash
=
test_main
(
num_tokens
,
for
i
in
range
(
20
):
hidden
,
assert
test_main
(
num_tokens
,
hidden
,
num_experts
,
num_topk
,
rank
,
num_ranks
,
group
,
buffer
,
seed
=
seed
)
==
ref_hash
,
f
'Error: seed=
{
seed
}
'
num_experts
,
num_topk
,
rank
,
num_ranks
,
group
,
buffer
,
use_logfmt
=
args
.
use_logfmt
,
seed
=
seed
)
for
_
in
range
(
20
):
assert
test_main
(
num_tokens
,
hidden
,
num_experts
,
num_topk
,
rank
,
num_ranks
,
group
,
buffer
,
use_logfmt
=
args
.
use_logfmt
,
seed
=
seed
)
==
ref_hash
,
f
'Error: seed=
{
seed
}
'
# Destroy the buffer runtime and communication group
buffer
.
destroy
()
buffer
.
destroy
()
dist
.
barrier
()
dist
.
barrier
()
dist
.
destroy_process_group
()
dist
.
destroy_process_group
()
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
print
(
"main start..."
)
# TODO: you may modify NUMA binding for less CPU overhead
# TODO: you may modify NUMA binding for less CPU overhead
num_processes
=
8
# TODO: buggy with `num_tokens=512`
torch
.
multiprocessing
.
spawn
(
test_loop
,
args
=
(
num_processes
,),
nprocs
=
num_processes
)
parser
=
argparse
.
ArgumentParser
(
description
=
'Test low-latency EP kernels'
)
parser
.
add_argument
(
'--num-processes'
,
type
=
int
,
default
=
8
,
help
=
'Number of processes to spawn (default: 8)'
)
parser
.
add_argument
(
'--num-tokens'
,
type
=
int
,
default
=
128
,
help
=
'Number of tokens (default: 128)'
)
parser
.
add_argument
(
'--hidden'
,
type
=
int
,
default
=
7168
,
help
=
'Hidden dimension size (default: 7168)'
)
parser
.
add_argument
(
'--num-topk'
,
type
=
int
,
default
=
8
,
help
=
'Number of top-k experts (default: 8)'
)
parser
.
add_argument
(
'--num-experts'
,
type
=
int
,
default
=
288
,
help
=
'Number of experts (default: 288)'
)
parser
.
add_argument
(
'--allow-mnnvl'
,
action
=
"store_true"
,
help
=
'Allow MNNVL for communication'
)
parser
.
add_argument
(
'--disable-nvlink'
,
action
=
'store_true'
,
help
=
'Whether to disable NVLink for testing'
)
parser
.
add_argument
(
"--pressure-test"
,
action
=
'store_true'
,
help
=
'Whether to do pressure test'
)
parser
.
add_argument
(
"--shrink-test"
,
action
=
'store_true'
,
help
=
'Whether to simulate failure and test shrink mode'
)
parser
.
add_argument
(
'--use-logfmt'
,
action
=
'store_true'
,
help
=
'Whether to test LogFMT combine'
)
args
=
parser
.
parse_args
()
num_processes
=
args
.
num_processes
torch
.
multiprocessing
.
spawn
(
test_loop
,
args
=
(
num_processes
,
args
),
nprocs
=
num_processes
)
tests/test_low_latency_new.py
deleted
100644 → 0
View file @
e195b4fe
import
argparse
import
random
import
torch
import
torch.distributed
as
dist
from
functools
import
partial
from
typing
import
Literal
,
Set
import
deep_ep
from
utils
import
init_dist
,
bench
,
bench_kineto
,
calc_diff
,
hash_tensor
,
per_token_cast_pg_back
,
per_token_cast_pc_back
def
simulate_failure_and_skip
(
rank
:
int
,
api
:
Literal
[
"dispatch"
,
"combine"
,
"clean"
],
expected_masked_ranks
:
Set
[
int
]):
# Simulates rank failure when the rank first calls the corresponding communication API
failed_api_ranks
=
{
# API -> rank to fail (rank fails when it first calls the corresponding communication API)
'dispatch'
:
1
,
'combine'
:
3
,
'clean'
:
5
}
if
rank
in
expected_masked_ranks
:
# Rank already failed
return
True
if
api
in
failed_api_ranks
.
keys
():
expected_masked_ranks
.
add
(
failed_api_ranks
[
api
])
if
failed_api_ranks
[
api
]
==
rank
:
print
(
f
"Rank
{
rank
}
failed when first calling
{
api
}
communication API, exit..."
,
flush
=
True
)
return
True
return
False
def
query_mask_buffer_and_check
(
api
:
Literal
[
"dispatch"
,
"combine"
,
"clean"
],
buffer
:
deep_ep
.
Buffer
,
mask_status
:
torch
.
Tensor
,
expected_masked_ranks
:
Set
[
int
]):
buffer
.
low_latency_query_mask_buffer
(
mask_status
)
assert
set
(
mask_status
.
nonzero
().
squeeze
(
-
1
).
tolist
())
==
expected_masked_ranks
def
test_main
(
num_tokens
:
int
,
hidden
:
int
,
num_experts
:
int
,
num_topk
:
int
,
rank
:
int
,
num_ranks
:
int
,
group
:
dist
.
ProcessGroup
,
buffer
:
deep_ep
.
Buffer
,
use_logfmt
:
bool
=
False
,
seed
:
int
=
0
):
torch
.
manual_seed
(
seed
+
rank
)
random
.
seed
(
seed
+
rank
)
assert
num_experts
%
num_ranks
==
0
num_local_experts
=
num_experts
//
num_ranks
# NOTES: the integers greater than 256 exceed the BF16 precision limit
rank_offset
=
128
assert
num_ranks
-
rank_offset
<
257
,
'Too many ranks (exceeding test precision limit)'
x
=
torch
.
ones
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
*
(
rank
-
rank_offset
)
x
[:,
-
128
:]
=
torch
.
arange
(
num_tokens
,
device
=
'cuda'
).
to
(
torch
.
bfloat16
).
view
(
-
1
,
1
)
x_list
=
[
x
]
for
_
in
range
(
4
if
use_logfmt
else
0
):
# NOTES: make more LogFMT casts and also with some BF16
x_list
.
append
(
torch
.
randn
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
*
0.5
*
random
.
random
())
# NOTES: the last one is for performance testing
# Most of the values in the perf case is lower than the threshold, casting most channels
x_list
.
append
(
torch
.
randn
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
*
0.1
)
scores
=
torch
.
randn
((
num_tokens
,
num_experts
),
dtype
=
torch
.
float32
,
device
=
'cuda'
).
abs
()
+
1
topk_idx
=
torch
.
topk
(
scores
,
num_topk
,
dim
=-
1
,
largest
=
True
,
sorted
=
True
)[
1
]
topk_weights
=
torch
.
randn
((
num_tokens
,
num_topk
),
dtype
=
torch
.
float32
,
device
=
'cuda'
).
abs
()
# Randomly mask some positions
for
_
in
range
(
10
):
topk_idx
[
random
.
randint
(
0
,
num_tokens
-
1
),
random
.
randint
(
0
,
num_topk
-
1
)]
=
-
1
all_topk_idx
=
torch
.
empty
((
num_ranks
,
num_tokens
,
num_topk
),
dtype
=
topk_idx
.
dtype
,
device
=
'cuda'
)
dist
.
all_gather_into_tensor
(
all_topk_idx
,
topk_idx
,
group
=
group
)
# For failure simulation and shrink testing
mask_status
=
torch
.
zeros
((
num_ranks
,),
dtype
=
torch
.
int
,
device
=
'cuda'
)
expected_masked_ranks
=
set
()
# Check dispatch correctness
do_check
=
True
hash_value
,
num_times
=
0
,
0
for
x_i
,
current_x
in
enumerate
(
x_list
):
for
return_recv_hook
in
(
False
,
True
):
for
quant_type
in
(
0
,
1
,
2
,
3
,
):
# 0: 不量化, 1: int8, 2: FP8_E4M3, 3: FP8_UE8M0 (仅支持round_scale=True), 4: FP8_E5M2
dispatch_use_quant
=
quant_type
>
0
for
fp8_round_scale
in
(
False
,
True
)
if
quant_type
!=
3
else
(
True
,
):
for
quant_group_size
in
(
0
,
128
,)
if
quant_type
>=
2
else
(
0
,
):
if
quant_type
==
3
and
(
fp8_round_scale
==
False
or
quant_group_size
==
0
):
continue
num_times
+=
1
for
_
in
range
((
num_times
%
2
)
+
1
):
packed_recv_x
,
packed_recv_count
,
handle
,
event
,
hook
=
\
buffer
.
low_latency_dispatch
(
current_x
,
topk_idx
,
num_tokens
,
num_experts
,
quant_type
=
quant_type
,
fp8_round_scale
=
fp8_round_scale
,
quant_group_size
=
quant_group_size
,
async_finish
=
not
return_recv_hook
,
return_recv_hook
=
return_recv_hook
)
hook
()
if
return_recv_hook
else
event
.
current_stream_wait
()
packed_recv_x
=
(
packed_recv_x
[
0
],
packed_recv_x
[
1
].
contiguous
())
if
dispatch_use_quant
else
packed_recv_x
if
not
dispatch_use_quant
:
simulated_gemm_x
=
packed_recv_x
.
clone
()
elif
quant_group_size
==
0
:
simulated_gemm_x
=
per_token_cast_pc_back
(
packed_recv_x
[
0
].
view
(
-
1
,
hidden
),
packed_recv_x
[
1
].
reshape
(
-
1
)).
view
(
packed_recv_x
[
0
].
shape
)
elif
quant_group_size
==
128
:
simulated_gemm_x
=
per_token_cast_pg_back
(
packed_recv_x
[
0
].
view
(
-
1
,
hidden
),
packed_recv_x
[
1
].
view
(
-
1
,
hidden
//
128
)).
view
(
packed_recv_x
[
0
].
shape
)
for
i
in
range
(
num_local_experts
if
do_check
else
0
):
expert_id
=
rank
*
num_local_experts
+
i
if
not
dispatch_use_quant
:
recv_x
=
packed_recv_x
[
i
]
elif
quant_group_size
==
0
:
recv_x
=
per_token_cast_pc_back
(
packed_recv_x
[
0
][
i
],
packed_recv_x
[
1
][
i
])
elif
quant_group_size
==
128
:
recv_x
=
per_token_cast_pg_back
(
packed_recv_x
[
0
][
i
],
packed_recv_x
[
1
][
i
])
recv_count
,
recv_src_info
,
recv_layout_range
=
packed_recv_count
[
i
],
handle
[
0
][
i
],
handle
[
1
][
i
]
# Check expert indices
int_mask
=
(
2
**
32
)
-
1
num_valid_tokens
=
recv_count
.
item
()
assert
num_valid_tokens
==
(
recv_layout_range
&
int_mask
).
sum
().
item
(),
f
'
{
num_valid_tokens
}
!=
{
recv_layout_range
&
int_mask
}
.sum().item()'
assert
num_valid_tokens
==
(
all_topk_idx
==
expert_id
).
sum
(
dim
=
[
1
,
2
])[
mask_status
==
0
].
sum
().
item
(
),
f
'
{
num_valid_tokens
}
!=
{
(
all_topk_idx
==
expert_id
).
sum
(
dim
=
[
1
,
2
])[
mask_status
==
0
].
sum
().
item
()
}
'
if
num_valid_tokens
==
0
:
continue
# Check received data
if
current_x
is
x
:
recv_x
=
recv_x
[:
num_valid_tokens
]
recv_x_amin
=
recv_x
[:,
:
-
128
].
amin
(
dim
=-
1
)
recv_x_amax
=
recv_x
[:,
:
-
128
].
amax
(
dim
=-
1
)
recv_src_info
=
recv_src_info
[:
num_valid_tokens
]
assert
torch
.
equal
(
recv_x_amin
,
recv_x_amax
)
if
dispatch_use_quant
:
assert
calc_diff
(
recv_x
[:,
-
1
],
recv_src_info
.
view
(
-
1
))
<
0.007
assert
torch
.
equal
(
recv_x_amin
,
recv_x
[:,
:
-
128
].
amax
(
dim
=-
1
))
if
quant_group_size
!=
0
:
if
fp8_round_scale
:
assert
calc_diff
(
recv_x
[:,
-
1
],
recv_src_info
.
view
(
-
1
))
<
0.007
else
:
assert
(
recv_x
[:,
-
128
:]
-
recv_src_info
.
view
(
-
1
,
1
)
%
num_tokens
).
sum
().
item
()
==
0
for
j
in
range
(
num_ranks
):
begin_idx
,
count
=
(
recv_layout_range
[
j
]
>>
32
).
item
(),
(
recv_layout_range
[
j
]
&
int_mask
).
item
()
if
not
fp8_round_scale
:
assert
(
recv_x_amin
==
j
-
rank_offset
).
sum
().
item
()
==
(
all_topk_idx
[
j
]
==
expert_id
).
sum
().
item
()
assert
(
recv_x
[
begin_idx
:
begin_idx
+
count
,
:
-
128
]
-
j
+
rank_offset
).
sum
().
item
()
==
0
if
dispatch_use_quant
:
hash_value
^=
hash_tensor
(
packed_recv_x
[
0
][
i
,
:
num_valid_tokens
])
hash_value
^=
hash_tensor
(
packed_recv_x
[
1
][
i
,
:
num_valid_tokens
])
else
:
hash_value
^=
hash_tensor
(
packed_recv_x
[
i
,
:
num_valid_tokens
])
# Check combine correctness
for
zero_copy
in
(
False
,
)
if
use_logfmt
else
(
False
,
True
,
):
if
zero_copy
:
buffer
.
get_next_low_latency_combine_buffer
(
handle
)[:,
:,
:]
=
simulated_gemm_x
out
=
torch
.
empty
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
combined_x
,
event
,
hook
=
buffer
.
low_latency_combine
(
simulated_gemm_x
,
topk_idx
,
topk_weights
,
handle
,
use_logfmt
=
use_logfmt
,
async_finish
=
not
return_recv_hook
,
zero_copy
=
zero_copy
,
return_recv_hook
=
return_recv_hook
,
out
=
out
)
hook
()
if
return_recv_hook
else
event
.
current_stream_wait
()
if
do_check
:
diff
=
calc_diff
(
current_x
*
topk_weights
.
masked_fill
(
topk_idx
==
-
1
,
0
).
sum
(
dim
=
1
).
view
(
-
1
,
1
),
combined_x
)
assert
torch
.
isnan
(
combined_x
).
sum
().
item
()
==
0
# if not fp8_round_scale:
assert
diff
<
(
9e-4
if
dispatch_use_quant
else
1e-5
),
f
'Error: diff=
{
diff
}
, dispatch_use_quant=
{
dispatch_use_quant
}
, zero_copy=
{
zero_copy
}
'
hash_value
^=
hash_tensor
(
combined_x
)
if
rank
==
0
:
print
(
f
"data:
{
x_i
}
, return_recv_hook:
{
return_recv_hook
}
, quant_type:
{
quant_type
}
, "
,
f
"fp8_round_scale:
{
fp8_round_scale
}
, quant_group_size:
{
quant_group_size
}
pass"
)
# noinspection PyShadowingNames
def
large_gemm_with_hook
(
hook
):
mat_0
=
torch
.
randn
((
8192
,
8192
),
dtype
=
torch
.
float
)
mat_1
=
torch
.
randn
((
8192
,
8192
),
dtype
=
torch
.
float
)
mat_0
@
mat_1
hook
()
# noinspection PyShadowingNames
def
test_func
(
return_recv_hook
:
bool
):
recv_x
,
recv_count
,
handle
,
event
,
hook
=
\
buffer
.
low_latency_dispatch
(
current_x
,
topk_idx
,
num_tokens
,
num_experts
,
quant_type
=
2
,
quant_group_size
=
0
,
async_finish
=
False
,
return_recv_hook
=
return_recv_hook
)
large_gemm_with_hook
(
hook
)
if
return_recv_hook
else
None
combined_x
,
event
,
hook
=
buffer
.
low_latency_combine
(
simulated_gemm_x
,
topk_idx
,
topk_weights
,
handle
,
use_logfmt
=
use_logfmt
,
return_recv_hook
=
return_recv_hook
)
large_gemm_with_hook
(
hook
)
if
return_recv_hook
else
None
# Calculate bandwidth
num_fp8_bytes
,
num_bf16_bytes
=
(
hidden
+
hidden
/
128
*
4
+
16
),
hidden
*
2
num_logfmt10_bytes
=
hidden
*
10
/
8
+
hidden
/
128
*
4
num_dispatch_comm_bytes
,
num_combine_comm_bytes
=
0
,
0
for
i
in
range
(
num_tokens
):
num_selections
=
(
topk_idx
[
i
]
!=
-
1
).
sum
().
item
()
num_dispatch_comm_bytes
+=
num_fp8_bytes
*
num_selections
num_combine_comm_bytes
+=
num_bf16_bytes
*
num_selections
# Dispatch + combine testing
avg_t
,
min_t
,
max_t
=
bench
(
partial
(
test_func
,
return_recv_hook
=
False
))
print
(
f
'[rank
{
rank
}
] Dispatch + combine bandwidth:
{
(
num_dispatch_comm_bytes
+
num_combine_comm_bytes
)
/
1e9
/
avg_t
:.
2
f
}
GB/s, '
f
'avg_t=
{
avg_t
*
1e6
:.
2
f
}
us, min_t=
{
min_t
*
1e6
:.
2
f
}
us, max_t=
{
max_t
*
1e6
:.
2
f
}
us'
,
flush
=
True
)
# Separate profiling
for
return_recv_hook
in
(
False
,
True
):
group
.
barrier
()
dispatch_t
,
combine_t
=
bench_kineto
(
partial
(
test_func
,
return_recv_hook
=
return_recv_hook
),
kernel_names
=
(
'dispatch'
,
'combine'
),
barrier_comm_profiling
=
True
,
suppress_kineto_output
=
True
,
num_kernels_per_period
=
2
if
return_recv_hook
else
1
)
if
not
return_recv_hook
:
print
(
f
'[rank
{
rank
}
] Dispatch bandwidth:
{
num_dispatch_comm_bytes
/
1e9
/
dispatch_t
:.
2
f
}
GB/s, avg_t=
{
dispatch_t
*
1e6
:.
2
f
}
us | '
f
'Combine bandwidth:
{
num_combine_comm_bytes
/
1e9
/
combine_t
:.
2
f
}
GB/s, avg_t=
{
combine_t
*
1e6
:.
2
f
}
us'
,
flush
=
True
)
else
:
print
(
f
'[rank
{
rank
}
] Dispatch send/recv time:
{
dispatch_t
[
0
]
*
1e6
:.
2
f
}
+
{
dispatch_t
[
1
]
*
1e6
:.
2
f
}
us | '
f
'Combine send/recv time:
{
combine_t
[
0
]
*
1e6
:.
2
f
}
+
{
combine_t
[
1
]
*
1e6
:.
2
f
}
us'
,
flush
=
True
)
return
hash_value
# noinspection PyUnboundLocalVariable,PyShadowingNames
def
test_loop
(
local_rank
:
int
,
num_local_ranks
:
int
,
args
:
argparse
.
Namespace
):
rank
,
num_ranks
,
group
=
init_dist
(
local_rank
,
num_local_ranks
)
num_tokens
,
hidden
=
args
.
num_tokens
,
args
.
hidden
num_topk
,
num_experts
=
args
.
num_topk
,
args
.
num_experts
num_rdma_bytes
=
deep_ep
.
Buffer
.
get_low_latency_rdma_size_hint
(
num_tokens
,
hidden
,
num_ranks
,
num_experts
)
if
local_rank
==
0
:
print
(
f
'Allocating buffer size:
{
num_rdma_bytes
/
1e6
}
MB ...'
,
flush
=
True
)
buffer
=
deep_ep
.
Buffer
(
group
,
num_rdma_bytes
=
num_rdma_bytes
,
low_latency_mode
=
True
,
num_qps_per_rank
=
num_experts
//
num_ranks
,
allow_nvlink_for_low_latency_mode
=
not
args
.
disable_nvlink
,
explicitly_destroy
=
True
,
allow_mnnvl
=
args
.
allow_mnnvl
)
test_main
(
num_tokens
,
hidden
,
num_experts
,
num_topk
,
rank
,
num_ranks
,
group
,
buffer
,
use_logfmt
=
args
.
use_logfmt
,
seed
=
1
)
do_pressure_test
=
args
.
pressure_test
for
seed
in
range
(
int
(
1e9
)
if
do_pressure_test
else
0
):
if
local_rank
==
0
:
print
(
f
'Testing with seed
{
seed
}
...'
,
flush
=
True
)
ref_hash
=
test_main
(
num_tokens
,
hidden
,
num_experts
,
num_topk
,
rank
,
num_ranks
,
group
,
buffer
,
use_logfmt
=
args
.
use_logfmt
,
seed
=
seed
)
for
_
in
range
(
20
):
assert
test_main
(
num_tokens
,
hidden
,
num_experts
,
num_topk
,
rank
,
num_ranks
,
group
,
buffer
,
use_logfmt
=
args
.
use_logfmt
,
seed
=
seed
)
==
ref_hash
,
f
'Error: seed=
{
seed
}
'
# Destroy the buffer runtime and communication group
buffer
.
destroy
()
dist
.
barrier
()
dist
.
destroy_process_group
()
if
__name__
==
'__main__'
:
# TODO: you may modify NUMA binding for less CPU overhead
# TODO: buggy with `num_tokens=512`
parser
=
argparse
.
ArgumentParser
(
description
=
'Test low-latency EP kernels'
)
parser
.
add_argument
(
'--num-processes'
,
type
=
int
,
default
=
8
,
help
=
'Number of processes to spawn (default: 8)'
)
parser
.
add_argument
(
'--num-tokens'
,
type
=
int
,
default
=
128
,
help
=
'Number of tokens (default: 128)'
)
parser
.
add_argument
(
'--hidden'
,
type
=
int
,
default
=
7168
,
help
=
'Hidden dimension size (default: 7168)'
)
parser
.
add_argument
(
'--num-topk'
,
type
=
int
,
default
=
8
,
help
=
'Number of top-k experts (default: 8)'
)
parser
.
add_argument
(
'--num-experts'
,
type
=
int
,
default
=
288
,
help
=
'Number of experts (default: 288)'
)
parser
.
add_argument
(
'--allow-mnnvl'
,
action
=
"store_true"
,
help
=
'Allow MNNVL for communication'
)
parser
.
add_argument
(
'--disable-nvlink'
,
action
=
'store_true'
,
help
=
'Whether to disable NVLink for testing'
)
parser
.
add_argument
(
"--pressure-test"
,
action
=
'store_true'
,
help
=
'Whether to do pressure test'
)
parser
.
add_argument
(
"--shrink-test"
,
action
=
'store_true'
,
help
=
'Whether to simulate failure and test shrink mode'
)
parser
.
add_argument
(
'--use-logfmt'
,
action
=
'store_true'
,
help
=
'Whether to test LogFMT combine'
)
args
=
parser
.
parse_args
()
num_processes
=
args
.
num_processes
torch
.
multiprocessing
.
spawn
(
test_loop
,
args
=
(
num_processes
,
args
),
nprocs
=
num_processes
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment