Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
e57e9270
Commit
e57e9270
authored
Feb 05, 2026
by
lishen
Browse files
量化测试代码修改对应的tests修改
parent
830124e1
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
17 additions
and
17 deletions
+17
-17
tests/test_internode.py
tests/test_internode.py
+3
-3
tests/test_intranode.py
tests/test_intranode.py
+4
-4
tests/test_low_latency.py
tests/test_low_latency.py
+3
-3
tests/utils.py
tests/utils.py
+7
-7
No files found.
tests/test_internode.py
View file @
e57e9270
...
@@ -6,7 +6,7 @@ import torch.distributed as dist
...
@@ -6,7 +6,7 @@ import torch.distributed as dist
# noinspection PyUnresolvedReferences
# noinspection PyUnresolvedReferences
import
deep_ep
import
deep_ep
from
utils
import
init_dist
,
bench
,
bench_kineto
,
calc_diff
,
create_grouped_scores
,
inplace_unique
,
per_token_cast_to_fp8
,
per_token_cast_back
,
hash_tensor
from
utils
import
init_dist
,
bench
,
bench_kineto
,
calc_diff
,
create_grouped_scores
,
inplace_unique
,
per_token_cast_to_fp8
,
per_token_cast_
pg_
back
,
hash_tensor
# Test compatibility with low latency functions
# Test compatibility with low latency functions
import
test_low_latency
import
test_low_latency
...
@@ -127,7 +127,7 @@ def test_main(args: argparse.Namespace, num_sms: int,
...
@@ -127,7 +127,7 @@ def test_main(args: argparse.Namespace, num_sms: int,
hash_value
+=
hash_tensor
(
recv_x
[
0
])
hash_value
+=
hash_tensor
(
recv_x
[
0
])
hash_value
+=
hash_tensor
(
recv_x
[
1
])
hash_value
+=
hash_tensor
(
recv_x
[
1
])
recv_x
=
per_token_cast_back
(
*
recv_x
)
if
isinstance
(
recv_x
,
tuple
)
else
recv_x
recv_x
=
per_token_cast_
pg_
back
(
*
recv_x
)
if
isinstance
(
recv_x
,
tuple
)
else
recv_x
# Checks
# Checks
recv_gbl_rank_prefix_sum
=
handle
[
-
4
]
recv_gbl_rank_prefix_sum
=
handle
[
-
4
]
...
@@ -153,7 +153,7 @@ def test_main(args: argparse.Namespace, num_sms: int,
...
@@ -153,7 +153,7 @@ def test_main(args: argparse.Namespace, num_sms: int,
dispatch_args
.
update
({
'previous_event'
:
buffer
.
capture
()})
dispatch_args
.
update
({
'previous_event'
:
buffer
.
capture
()})
recv_x
,
_
,
_
,
_
,
_
,
event
=
buffer
.
dispatch
(
**
dispatch_args
)
recv_x
,
_
,
_
,
_
,
_
,
event
=
buffer
.
dispatch
(
**
dispatch_args
)
event
.
current_stream_wait
()
if
async_mode
else
()
event
.
current_stream_wait
()
if
async_mode
else
()
recv_x
=
per_token_cast_back
(
*
recv_x
)
if
isinstance
(
recv_x
,
tuple
)
else
recv_x
recv_x
=
per_token_cast_
pg_
back
(
*
recv_x
)
if
isinstance
(
recv_x
,
tuple
)
else
recv_x
if
not
is_rand
:
if
not
is_rand
:
check_data
(
recv_x
,
recv_gbl_rank_prefix_sum
)
check_data
(
recv_x
,
recv_gbl_rank_prefix_sum
)
...
...
tests/test_intranode.py
View file @
e57e9270
...
@@ -5,7 +5,7 @@ import torch.distributed as dist
...
@@ -5,7 +5,7 @@ import torch.distributed as dist
# noinspection PyUnresolvedReferences
# noinspection PyUnresolvedReferences
import
deep_ep
import
deep_ep
from
utils
import
init_dist
,
bench
,
calc_diff
,
inplace_unique
,
per_token_cast_to_fp8
,
per_token_cast_back
from
utils
import
init_dist
,
bench
,
calc_diff
,
inplace_unique
,
per_token_cast_to_fp8
,
per_token_cast_
pg_
back
# Test compatibility with low latency functions
# Test compatibility with low latency functions
import
test_low_latency
import
test_low_latency
...
@@ -99,7 +99,7 @@ def test_main(args: argparse.Namespace, num_sms: int, local_rank: int, num_ranks
...
@@ -99,7 +99,7 @@ def test_main(args: argparse.Namespace, num_sms: int, local_rank: int, num_ranks
dispatch_args
.
update
({
'previous_event'
:
buffer
.
capture
()})
dispatch_args
.
update
({
'previous_event'
:
buffer
.
capture
()})
recv_x
,
recv_topk_idx
,
recv_topk_weights
,
recv_num_tokens_per_expert_list
,
handle
,
event
=
buffer
.
dispatch
(
**
dispatch_args
)
recv_x
,
recv_topk_idx
,
recv_topk_weights
,
recv_num_tokens_per_expert_list
,
handle
,
event
=
buffer
.
dispatch
(
**
dispatch_args
)
event
.
current_stream_wait
()
if
async_mode
else
()
event
.
current_stream_wait
()
if
async_mode
else
()
recv_x
=
per_token_cast_back
(
*
recv_x
)
if
isinstance
(
recv_x
,
tuple
)
else
recv_x
recv_x
=
per_token_cast_
pg_
back
(
*
recv_x
)
if
isinstance
(
recv_x
,
tuple
)
else
recv_x
# Checks
# Checks
rank_prefix_matrix
=
handle
[
0
]
rank_prefix_matrix
=
handle
[
0
]
...
@@ -126,7 +126,7 @@ def test_main(args: argparse.Namespace, num_sms: int, local_rank: int, num_ranks
...
@@ -126,7 +126,7 @@ def test_main(args: argparse.Namespace, num_sms: int, local_rank: int, num_ranks
dispatch_args
.
update
({
'num_worst_tokens'
:
num_worst_tokens
})
dispatch_args
.
update
({
'num_worst_tokens'
:
num_worst_tokens
})
recv_worst_x
,
recv_worst_topk_idx
,
recv_worst_topk_weights
,
empty_list
,
_
,
event
=
buffer
.
dispatch
(
**
dispatch_args
)
recv_worst_x
,
recv_worst_topk_idx
,
recv_worst_topk_weights
,
empty_list
,
_
,
event
=
buffer
.
dispatch
(
**
dispatch_args
)
event
.
current_stream_wait
()
if
async_mode
else
()
event
.
current_stream_wait
()
if
async_mode
else
()
recv_worst_x
=
per_token_cast_back
(
*
recv_worst_x
)
if
isinstance
(
recv_worst_x
,
tuple
)
else
recv_worst_x
recv_worst_x
=
per_token_cast_
pg_
back
(
*
recv_worst_x
)
if
isinstance
(
recv_worst_x
,
tuple
)
else
recv_worst_x
assert
len
(
empty_list
)
==
0
assert
len
(
empty_list
)
==
0
assert
num_worst_tokens
==
recv_worst_x
.
size
(
0
)
assert
num_worst_tokens
==
recv_worst_x
.
size
(
0
)
assert
num_worst_tokens
==
recv_worst_topk_idx
.
size
(
0
)
assert
num_worst_tokens
==
recv_worst_topk_idx
.
size
(
0
)
...
@@ -143,7 +143,7 @@ def test_main(args: argparse.Namespace, num_sms: int, local_rank: int, num_ranks
...
@@ -143,7 +143,7 @@ def test_main(args: argparse.Namespace, num_sms: int, local_rank: int, num_ranks
dispatch_args
.
update
({
'previous_event'
:
buffer
.
capture
()})
dispatch_args
.
update
({
'previous_event'
:
buffer
.
capture
()})
recv_x
,
_
,
_
,
_
,
_
,
event
=
buffer
.
dispatch
(
**
dispatch_args
)
recv_x
,
_
,
_
,
_
,
_
,
event
=
buffer
.
dispatch
(
**
dispatch_args
)
event
.
current_stream_wait
()
if
async_mode
else
()
event
.
current_stream_wait
()
if
async_mode
else
()
recv_x
=
per_token_cast_back
(
*
recv_x
)
if
isinstance
(
recv_x
,
tuple
)
else
recv_x
recv_x
=
per_token_cast_
pg_
back
(
*
recv_x
)
if
isinstance
(
recv_x
,
tuple
)
else
recv_x
if
current_x
is
not
x_pure_rand
:
if
current_x
is
not
x_pure_rand
:
check_data
(
recv_x
,
rank_prefix_matrix
)
check_data
(
recv_x
,
rank_prefix_matrix
)
...
...
tests/test_low_latency.py
View file @
e57e9270
...
@@ -4,7 +4,7 @@ import torch.distributed as dist
...
@@ -4,7 +4,7 @@ import torch.distributed as dist
from
functools
import
partial
from
functools
import
partial
import
deep_ep
import
deep_ep
from
utils
import
init_dist
,
bench
,
bench_kineto
,
calc_diff
,
hash_tensor
,
per_token_cast_back
from
utils
import
init_dist
,
bench
,
bench_kineto
,
calc_diff
,
hash_tensor
,
per_token_cast_
pg_
back
def
test_main
(
num_tokens
:
int
,
hidden
:
int
,
num_experts
:
int
,
num_topk
:
int
,
def
test_main
(
num_tokens
:
int
,
hidden
:
int
,
num_experts
:
int
,
num_topk
:
int
,
...
@@ -44,7 +44,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
...
@@ -44,7 +44,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
# print('run {}/{}, dispatch_use_fp8={}'.format(i + 1, num_times, dispatch_use_fp8))
# print('run {}/{}, dispatch_use_fp8={}'.format(i + 1, num_times, dispatch_use_fp8))
# return
# return
packed_recv_x
=
(
packed_recv_x
[
0
],
packed_recv_x
[
1
].
contiguous
())
if
dispatch_use_fp8
else
packed_recv_x
packed_recv_x
=
(
packed_recv_x
[
0
],
packed_recv_x
[
1
].
contiguous
())
if
dispatch_use_fp8
else
packed_recv_x
simulated_gemm_x
=
per_token_cast_back
(
packed_recv_x
[
0
].
view
(
-
1
,
hidden
),
packed_recv_x
[
1
].
view
(
-
1
,
hidden
//
128
)).
view
(
packed_recv_x
[
0
].
shape
)
\
simulated_gemm_x
=
per_token_cast_
pg_
back
(
packed_recv_x
[
0
].
view
(
-
1
,
hidden
),
packed_recv_x
[
1
].
view
(
-
1
,
hidden
//
128
)).
view
(
packed_recv_x
[
0
].
shape
)
\
if
dispatch_use_fp8
else
packed_recv_x
.
clone
()
if
dispatch_use_fp8
else
packed_recv_x
.
clone
()
# print(f"rank{rank}: packed_recv_x[0]\n{packed_recv_x[0].cpu()}\n")
# print(f"rank{rank}: packed_recv_x[0]\n{packed_recv_x[0].cpu()}\n")
# print(f"rank{rank}: packed_recv_x[1]\n{packed_recv_x[1].cpu()}\n")
# print(f"rank{rank}: packed_recv_x[1]\n{packed_recv_x[1].cpu()}\n")
...
@@ -53,7 +53,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
...
@@ -53,7 +53,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
dist
.
all_gather_into_tensor
(
all_topk_idx
,
topk_idx
,
group
=
group
)
dist
.
all_gather_into_tensor
(
all_topk_idx
,
topk_idx
,
group
=
group
)
for
i
in
range
(
num_local_experts
if
do_check
else
0
):
for
i
in
range
(
num_local_experts
if
do_check
else
0
):
expert_id
=
rank
*
num_local_experts
+
i
expert_id
=
rank
*
num_local_experts
+
i
recv_x
=
per_token_cast_back
(
packed_recv_x
[
0
][
i
],
packed_recv_x
[
1
][
i
])
if
dispatch_use_fp8
else
packed_recv_x
[
i
]
recv_x
=
per_token_cast_
pg_
back
(
packed_recv_x
[
0
][
i
],
packed_recv_x
[
1
][
i
])
if
dispatch_use_fp8
else
packed_recv_x
[
i
]
recv_count
,
recv_src_info
,
recv_layout_range
=
packed_recv_count
[
i
],
handle
[
0
][
i
],
handle
[
1
][
i
]
recv_count
,
recv_src_info
,
recv_layout_range
=
packed_recv_count
[
i
],
handle
[
0
][
i
],
handle
[
1
][
i
]
# Check expert indices
# Check expert indices
...
...
tests/utils.py
View file @
e57e9270
...
@@ -72,16 +72,16 @@ def per_token_cast_pg_back(x: torch.Tensor, x_scales: torch.Tensor):
...
@@ -72,16 +72,16 @@ def per_token_cast_pg_back(x: torch.Tensor, x_scales: torch.Tensor):
x_scales
=
x_scales
.
view
(
x
.
size
(
0
),
-
1
,
1
)
x_scales
=
x_scales
.
view
(
x
.
size
(
0
),
-
1
,
1
)
return
(
x_fp32_padded
*
x_scales
).
view
(
x_padded
.
shape
).
to
(
torch
.
bfloat16
)[:,:
n
].
contiguous
()
return
(
x_fp32_padded
*
x_scales
).
view
(
x_padded
.
shape
).
to
(
torch
.
bfloat16
)[:,:
n
].
contiguous
()
def
per_token_cast_pc_back
(
x
_int8
:
torch
.
Tensor
,
x_scales
:
torch
.
Tensor
):
def
per_token_cast_pc_back
(
x
:
torch
.
Tensor
,
x_scales
:
torch
.
Tensor
):
if
x
_int8
.
numel
()
==
0
:
if
x
.
numel
()
==
0
:
return
x
_int8
.
to
(
torch
.
bfloat16
)
return
x
.
to
(
torch
.
bfloat16
)
assert
x
_int8
.
dim
()
==
2
assert
x
.
dim
()
==
2
m
,
n
=
x
_int8
.
shape
m
,
n
=
x
.
shape
aligned_n
=
align_up
(
n
,
128
)
aligned_n
=
align_up
(
n
,
128
)
x_
int8_
padded
=
torch
.
nn
.
functional
.
pad
(
x
_int8
,
(
0
,
aligned_n
-
n
),
mode
=
'constant'
,
value
=
0
)
x_padded
=
torch
.
nn
.
functional
.
pad
(
x
,
(
0
,
aligned_n
-
n
),
mode
=
'constant'
,
value
=
0
)
x_fp32_padded
=
x_
int8_
padded
.
to
(
torch
.
float32
).
view
(
m
,
-
1
,
1
)
x_fp32_padded
=
x_padded
.
to
(
torch
.
float32
).
view
(
m
,
-
1
,
1
)
x_scales
=
x_scales
.
view
(
m
,
-
1
,
1
).
to
(
torch
.
float32
)
x_scales
=
x_scales
.
view
(
m
,
-
1
,
1
).
to
(
torch
.
float32
)
x_deq
=
(
x_fp32_padded
*
x_scales
).
view
(
m
,
aligned_n
)
x_deq
=
(
x_fp32_padded
*
x_scales
).
view
(
m
,
aligned_n
)
return
x_deq
[:,
:
n
].
to
(
torch
.
bfloat16
).
contiguous
()
return
x_deq
[:,
:
n
].
to
(
torch
.
bfloat16
).
contiguous
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment