jerrrrry / infinilm — Commit d80bd14a
Authored Nov 22, 2025 by pengcheng888

issue/78 - Add timing test scripts for Qwen3's MoE and attention modules under the test directory

Parent: 4fd9d490
Showing 2 changed files with 613 additions and 0 deletions:

  test/qwen3_atteniton_test.py    +457  -0
  test/qwen3_moe_test.py          +156  -0
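Both scripts share the same command-line interface. Going by the usage strings embedded in each script, a typical invocation looks like this (the model directory path is a placeholder; one of --cpu, --nvidia, --metax, --moore, or --iluvatar selects the backend):

    python test/qwen3_atteniton_test.py --nvidia --model_dir=<path/to/model_dir>
    python test/qwen3_moe_test.py --nvidia --model_dir=<path/to/model_dir>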
test/qwen3_atteniton_test.py  (new file, mode 100755)
import os
import time
import sys

import safetensors
import torch
from transformers import AutoConfig
from transformers import DynamicCache
from transformers.models import qwen3_moe

WARMUPS = 10
RUNS = 100

PREFILL_TESTCASES = {
    "seqlens": [64, 128, 256, 256],
    "pastlens": [512, 0, 0, 256],
}
DECODE_TESTCASES = {
    "seqlens": [1 for _ in range(16)],
    "pastlens": [50 for _ in range(4)]
    + [100 for _ in range(4)]
    + [200 for _ in range(4)]
    + [400 for _ in range(4)],
}


def get_args():
    import argparse

    parser = argparse.ArgumentParser(description="Test Operator")
    parser.add_argument(
        "--model_dir",
        action="store",
        help="The directory of the model to be tested",
    )
    parser.add_argument("--cpu", action="store_true", help="Run cpu test")
    parser.add_argument("--nvidia", action="store_true", help="Run nvidia test")
    parser.add_argument("--metax", action="store_true", help="Run metax test")
    parser.add_argument("--moore", action="store_true", help="Run moore test")
    parser.add_argument("--iluvatar", action="store_true", help="Run iluvatar test")
    return parser.parse_args()


def create_Qwen3attention_torch(dir_path, *, device, dtype=torch.bfloat16):
    config = AutoConfig.from_pretrained(dir_path)
    config.num_hidden_layers = 1
    config._attn_implementation = "sdpa"

    # --------------------------------------------------------------------- #
    # Build a model that contains only the attention module
    # --------------------------------------------------------------------- #
    model = qwen3_moe.modeling_qwen3_moe.Qwen3MoeAttention(config, layer_idx=0).to(
        device=device, dtype=dtype
    )

    # Load the layer-0 self-attention weights from the first safetensors shard
    tensors = {}
    for fname in sorted(os.listdir(dir_path)):
        if not fname.endswith(".safetensors"):
            continue
        fpath = os.path.join(dir_path, fname)
        with safetensors.safe_open(fpath, framework="pt") as f:
            for key in f.keys():
                if "model.layers.0.self_attn." in key:
                    tensors[key[len("model.layers.0.self_attn.") :]] = f.get_tensor(key)
        break
    model.load_state_dict(tensors)

    # --------------------------------------------------------------------- #
    # Build the rotary embedding module
    # --------------------------------------------------------------------- #
    rotary_emb = qwen3_moe.modeling_qwen3_moe.Qwen3MoeRotaryEmbedding(
        config, device=device
    )
    return model, rotary_emb


def generate_attention_input_torch(model, rotary_emb, testcase, device, dtype=torch.bfloat16):
    config = model.config
    hidden_size = config.hidden_size  # 2048
    head_dim = config.head_dim  # 128
    num_key_value_heads = config.num_key_value_heads
    bs = 1

    req_list = []
    for seq_lens, past_lens in zip(testcase["seqlens"], testcase["pastlens"]):
        hidden_states = torch.rand((bs, seq_lens, hidden_size), device=device, dtype=dtype)
        attention_mask = None
        # Seed the KV cache with `past_lens` random positions on layer 0
        past_key_values = DynamicCache(config=config)
        key_states = torch.rand(
            (bs, num_key_value_heads, past_lens, head_dim), device=device, dtype=dtype
        )
        value_states = torch.rand(
            (bs, num_key_value_heads, past_lens, head_dim), device=device, dtype=dtype
        )
        past_key_values.update(key_states, value_states, 0)
        req_list.append(
            {
                "hidden_states": hidden_states,
                "attention_mask": attention_mask,
                "past_key_values": past_key_values,
            }
        )
    return req_list


def benchmark_Qwen3attention_prefill_torch(model, rotary_emb, req_list, test_cases):
    """
    Benchmark Qwen3 attention prefill. Relies on the module-level `device`
    set in __main__.
    """
    req_out_list = []
    for req in req_list:
        # ----------------------------------------- #
        # Unpack the request
        # ----------------------------------------- #
        hidden_states = req["hidden_states"]
        attention_mask = req["attention_mask"]
        past_key_values = req["past_key_values"]

        # ----------------------------------------- #
        # Compute the cos/sin tables for the current positions
        # ----------------------------------------- #
        cache_lens = past_key_values.get_seq_length()  # current KV cache length
        bs, seq_len, _ = hidden_states.shape
        position_ids = torch.arange(
            cache_lens, cache_lens + seq_len, dtype=torch.int64, device=device
        ).reshape((bs, seq_len))
        cos_table, sin_table = rotary_emb(hidden_states, position_ids)
        position_embeddings = (cos_table, sin_table)  # (cos, sin): the order the module expects

        # ----------------------------------------- #
        # One forward pass
        # ----------------------------------------- #
        output_device, _ = model(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
        )

        # ----------------------------------------- #
        # Store the result
        # ----------------------------------------- #
        output_host = output_device.to("cpu")
        req_out_list.append(output_host)
    torch.cuda.synchronize()

    for _ in range(WARMUPS):
        for req in req_list:
            hidden_states = req["hidden_states"]
            attention_mask = req["attention_mask"]
            past_key_values = req["past_key_values"]

            cache_lens = past_key_values.get_seq_length()
            bs, seq_len, _ = hidden_states.shape
            position_ids = torch.arange(
                cache_lens, cache_lens + seq_len, dtype=torch.int64, device=device
            ).reshape((bs, seq_len))
            cos_table, sin_table = rotary_emb(hidden_states, position_ids)
            position_embeddings = (cos_table, sin_table)

            output_device, _ = model(
                hidden_states,
                position_embeddings=position_embeddings,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
            )

    time_consuming = 0
    for _ in range(RUNS):
        for i, req in enumerate(req_list):
            # ----------------------------------------- #
            # Restore the KV cache to its original length
            # ----------------------------------------- #
            origin_len = test_cases["pastlens"][i]
            req["past_key_values"].crop(origin_len)
            torch.cuda.synchronize()
            start_time = time.time()

            hidden_states = req["hidden_states"]
            attention_mask = req["attention_mask"]
            past_key_values = req["past_key_values"]

            cache_lens = past_key_values.get_seq_length()
            bs, seq_len, _ = hidden_states.shape
            position_ids = torch.arange(
                cache_lens, cache_lens + seq_len, dtype=torch.int64, device=device
            ).reshape((bs, seq_len))
            cos_table, sin_table = rotary_emb(hidden_states, position_ids)
            position_embeddings = (cos_table, sin_table)

            output_device, _ = model(
                hidden_states,
                position_embeddings=position_embeddings,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
            )
            torch.cuda.synchronize()
            end_time = time.time()
            time_consuming += (end_time - start_time) * 1000

    out_token_count = RUNS * len(req_list)
    latency = time_consuming / out_token_count
    print(
        f"\tWARMUPS={WARMUPS} RUNS={RUNS}, Attention Torch, average latency: "
        f"{round(latency, 2)} ms\n"
    )
    return req_out_list


def benchmark_Qwen3attention_decode_torch(model, rotary_emb, req_list, test_cases):
    """
    Benchmark Qwen3 attention decode. Relies on the module-level `device`
    set in __main__.
    """
    req_out_list = []
    for req in req_list:
        hidden_states = req["hidden_states"]
        attention_mask = req["attention_mask"]
        past_key_values = req["past_key_values"]

        cache_lens = past_key_values.get_seq_length()  # current KV cache length
        bs, seq_len, _ = hidden_states.shape
        position_ids = torch.arange(
            cache_lens, cache_lens + seq_len, dtype=torch.int64, device=device
        ).reshape((bs, seq_len))
        cos_table, sin_table = rotary_emb(hidden_states, position_ids)
        position_embeddings = (cos_table, sin_table)

        # One forward pass; store the result
        output_device, _ = model(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
        )
        output_host = output_device.to("cpu")
        req_out_list.append(output_host)
    torch.cuda.synchronize()

    for req in req_list:
        for _ in range(WARMUPS):
            hidden_states = req["hidden_states"]
            attention_mask = req["attention_mask"]
            past_key_values = req["past_key_values"]

            cache_lens = past_key_values.get_seq_length()
            bs, seq_len, _ = hidden_states.shape
            position_ids = torch.arange(
                cache_lens, cache_lens + seq_len, dtype=torch.int64, device=device
            ).reshape((bs, seq_len))
            cos_table, sin_table = rotary_emb(hidden_states, position_ids)
            position_embeddings = (cos_table, sin_table)

            output_device, _ = model(
                hidden_states,
                position_embeddings=position_embeddings,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
            )

    # ----------------------------------------- #
    # Restore the KV caches to their original lengths
    # ----------------------------------------- #
    for i, req in enumerate(req_list):
        origin_len = test_cases["pastlens"][i]
        req["past_key_values"].crop(origin_len)

    torch.cuda.synchronize()
    start_time = time.time()
    for i, req in enumerate(req_list):
        for _ in range(RUNS):
            hidden_states = req["hidden_states"]
            attention_mask = req["attention_mask"]
            past_key_values = req["past_key_values"]

            cache_lens = past_key_values.get_seq_length()
            bs, seq_len, _ = hidden_states.shape
            position_ids = torch.arange(
                cache_lens, cache_lens + seq_len, dtype=torch.int64, device=device
            ).reshape((bs, seq_len))
            cos_table, sin_table = rotary_emb(hidden_states, position_ids)
            position_embeddings = (cos_table, sin_table)

            output_device, _ = model(
                hidden_states,
                position_embeddings=position_embeddings,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
            )
            # -------------------------------------------------------------- #
            # Feed the output back as the next step's hidden_states
            # (the DynamicCache updates itself inside the forward pass)
            # -------------------------------------------------------------- #
            req["hidden_states"] = output_device
    torch.cuda.synchronize()
    end_time = time.time()

    time_consuming = end_time - start_time
    out_token_count = RUNS * len(req_list)
    throughput = out_token_count / time_consuming
    print(
        f"\tWARMUPS={WARMUPS} RUNS={RUNS} Attention Torch average throughput: "
        f"{round(throughput, 2)} tokens/s\n"
    )
    return req_out_list


if __name__ == "__main__":
    args = get_args()
    print(args)
    model_dir = args.model_dir
    dtype = torch.bfloat16

    # Map the backend flag to a torch device string
    device = "cpu"
    if args.cpu:
        device = "cpu"
    elif args.nvidia:
        device = "cuda"
    elif args.metax:
        device = "cuda"
    elif args.moore:
        device = "musa"
    elif args.iluvatar:
        device = "cuda"
    else:
        print(
            "Usage: python test/qwen3_atteniton_test.py "
            "[--cpu | --nvidia | --metax | --moore | --iluvatar] "
            "--model_dir=<path/to/model_dir>"
        )
        sys.exit(1)

    # --------------------------------------------------------------------- #
    model, rotary_emb = create_Qwen3attention_torch(model_dir, device=device, dtype=dtype)

    print("\n")
    print("*" * 130)
    print("Test Qwen3attention")
    print("*" * 130)

    print(f"Test Case PREFILL_TESTCASES : {PREFILL_TESTCASES}")
    req_list = generate_attention_input_torch(
        model, rotary_emb, PREFILL_TESTCASES, device, dtype=dtype
    )
    output_prefill = benchmark_Qwen3attention_prefill_torch(
        model, rotary_emb, req_list, PREFILL_TESTCASES
    )

    print("\n")
    print("-" * 130)
    print(f"\nTest DECODE_TESTCASES: {DECODE_TESTCASES}")
    req_list = generate_attention_input_torch(
        model, rotary_emb, DECODE_TESTCASES, device, dtype=dtype
    )
    output_decode = benchmark_Qwen3attention_decode_torch(
        model, rotary_emb, req_list, DECODE_TESTCASES
    )

    # clean up device memory
    del model
    torch.cuda.empty_cache()
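A note on the cache bookkeeping above: each forward pass appends the new key/value states to the DynamicCache, so the timed loops call crop() to rewind the cache to the test case's past length and keep every measured iteration comparable. A minimal sketch of that update/crop interplay, assuming a transformers version that provides DynamicCache.crop (shapes are illustrative):

import torch
from transformers import DynamicCache

cache = DynamicCache()
# Seed layer 0 with 50 cached positions: (bs, num_kv_heads, past_len, head_dim)
cache.update(torch.rand(1, 4, 50, 128), torch.rand(1, 4, 50, 128), 0)
assert cache.get_seq_length() == 50
# A decode step appends one position per forward pass
cache.update(torch.rand(1, 4, 1, 128), torch.rand(1, 4, 1, 128), 0)
assert cache.get_seq_length() == 51
# crop() rewinds the cache so the next measured run starts from the same state
cache.crop(50)
assert cache.get_seq_length() == 50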
test/qwen3_moe_test.py  (new file, mode 100755)
import os
import sys
import time

import safetensors
import torch
import transformers
from transformers import AutoConfig
from transformers.models import qwen3_moe

WARMUPS = 10
RUNS = 100

PREFILL_TESTCASES = {
    "seqlens": [64, 128, 256, 256],
    "pastlens": [512, 0, 0, 256],
}
DECODE_TESTCASES = {
    "seqlens": [1 for _ in range(16)],
    "pastlens": [50 for _ in range(4)]
    + [100 for _ in range(4)]
    + [200 for _ in range(4)]
    + [400 for _ in range(4)],
}


def get_args():
    import argparse

    parser = argparse.ArgumentParser(description="Test Operator")
    parser.add_argument(
        "--model_dir",
        action="store",
        help="The directory of the model to be tested",
    )
    parser.add_argument("--cpu", action="store_true", help="Run cpu test")
    parser.add_argument("--nvidia", action="store_true", help="Run nvidia test")
    parser.add_argument("--metax", action="store_true", help="Run metax test")
    parser.add_argument("--moore", action="store_true", help="Run moore test")
    parser.add_argument("--iluvatar", action="store_true", help="Run iluvatar test")
    return parser.parse_args()


def create_moe_torch(dir_path, device, dtype=torch.bfloat16):
    config = AutoConfig.from_pretrained(dir_path)
    moe = qwen3_moe.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock(config).to(
        device=device, dtype=dtype
    )

    # Load the layer-0 MLP (MoE) weights from the first safetensors shard
    tensors = {}
    for fname in sorted(os.listdir(dir_path)):
        if not fname.endswith(".safetensors"):
            continue
        fpath = os.path.join(dir_path, fname)
        with safetensors.safe_open(fpath, framework="pt") as f:
            for key in f.keys():
                if "model.layers.0.mlp." in key:
                    tensors[key[len("model.layers.0.mlp.") :]] = f.get_tensor(key)
        break
    moe.load_state_dict(tensors)
    return moe


def generate_moe_input_torch(testcase, dtype=torch.bfloat16):
    total_seqlen = sum(testcase["seqlens"])
    # Hidden size is hardcoded to 2048 to match the Qwen3 MoE config used here
    input_tensor = torch.rand((1, total_seqlen, 2048), device="cpu", dtype=dtype)
    return input_tensor


def benchmark_moe_torch(moe, input_host, device, dtype):
    """Benchmark the sparse MoE block: one verification pass, WARMUPS warmup
    iterations, then RUNS timed iterations."""
    input_device = input_host.to(device=device)
    output_device, _ = moe(input_device)
    output_host = output_device.to("cpu")

    for _ in range(WARMUPS):
        moe(input_device)
    torch.cuda.synchronize()
    start_time = time.time()
    for _ in range(RUNS):
        moe(input_device)
    torch.cuda.synchronize()
    end_time = time.time()
    print(f" MoE Torch average latency: {(end_time - start_time) * 1000 / RUNS} ms")
    return output_host


if __name__ == "__main__":
    args = get_args()
    print(args)
    model_dir = args.model_dir
    dtype = torch.bfloat16

    # Map the backend flag to a torch device string
    device = "cpu"
    if args.cpu:
        device = "cpu"
    elif args.nvidia:
        device = "cuda"
    elif args.metax:
        device = "cuda"
    elif args.moore:
        device = "musa"
    elif args.iluvatar:
        device = "cuda"
    else:
        print(
            "Usage: python test/qwen3_moe_test.py "
            "[--cpu | --nvidia | --metax | --moore | --iluvatar] "
            "--model_dir=<path/to/model_dir>"
        )
        sys.exit(1)

    # --------------------------------------------------------------------- #
    moe = create_moe_torch(model_dir, device=device, dtype=dtype)

    print("*" * 130)
    print("Test Qwen3 MoE")
    print("*" * 130)

    print(f"Test Case PREFILL_TESTCASES : {PREFILL_TESTCASES}")
    input_prefill = generate_moe_input_torch(PREFILL_TESTCASES)
    output_prefill = benchmark_moe_torch(moe, input_prefill, device=device, dtype=dtype)

    print("\n")
    print("-" * 130)
    print(f"\nTest DECODE_TESTCASES: {DECODE_TESTCASES}")
    input_decode = generate_moe_input_torch(DECODE_TESTCASES)
    output_decode = benchmark_moe_torch(moe, input_decode, device=device, dtype=dtype)

    # clean up device memory
    del moe
    torch.cuda.empty_cache()
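Both files use the same measurement discipline: run WARMUPS untimed iterations, synchronize the device, then time RUNS iterations between two more synchronizations. Factored into a standalone helper, the pattern looks roughly like the sketch below (bench is a hypothetical name, and the synchronize calls assume a CUDA-capable build, as in the scripts above):

import time
import torch

def bench(fn, warmups=10, runs=100):
    """Average latency of fn() in milliseconds, using warmup + synchronize."""
    for _ in range(warmups):
        fn()  # warm up kernels, allocators, and autotuning caches
    torch.cuda.synchronize()  # drain warmup work before starting the clock
    start = time.time()
    for _ in range(runs):
        fn()
    torch.cuda.synchronize()  # wait for all timed kernels to finish
    return (time.time() - start) * 1000 / runs

For example, bench(lambda: moe(input_device)) would reproduce the MoE latency number printed by benchmark_moe_torch.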