Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
57201a6a
Unverified
Commit
57201a6a
authored
Nov 10, 2025
by
Xin Yang
Committed by
GitHub
Nov 10, 2025
Browse files
Fix rotary embedding benchmark script (#28323)
Signed-off-by:
Xin Yang
<
xyangx@amazon.com
>
parent
f2d9ad06
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
64 additions
and
90 deletions
+64
-90
benchmarks/kernels/benchmark_rope.py
benchmarks/kernels/benchmark_rope.py
+64
-90
No files found.
benchmarks/kernels/benchmark_rope.py
View file @
57201a6a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
itertools
import
accumulate
import
itertools
import
nvtx
import
torch
from
vllm.model_executor.layers.rotary_embedding
import
RotaryEmbedding
,
get_rope
from
vllm.
platforms
import
current_platform
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.
triton_utils
import
triton
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
batch_size_range
=
[
2
**
i
for
i
in
range
(
0
,
8
,
2
)]
seq_len_range
=
[
2
**
i
for
i
in
range
(
6
,
10
,
1
)]
num_heads_range
=
[
32
,
48
]
configs
=
list
(
itertools
.
product
(
batch_size_range
,
seq_len_range
,
num_heads_range
))
def
benchmark_rope_kernels_multi_lora
(
is_neox_style
:
bool
,
batch_size
:
int
,
seq_len
:
int
,
num_heads
:
int
,
head_size
:
int
,
rotary_dim
:
int
|
None
,
dtype
:
torch
.
dtype
,
seed
:
int
,
device
:
str
,
max_position
:
int
=
8192
,
base
:
float
=
10000
,
)
->
None
:
current_platform
.
seed_everything
(
seed
)
torch
.
set_default_device
(
device
)
if
rotary_dim
is
None
:
rotary_dim
=
head_size
# silulating serving 4 LoRAs
scaling_factors
=
[
1
,
2
,
4
,
8
]
# batched RoPE can take multiple scaling factors
batched_rope
=
get_rope
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
{
"rope_type"
:
"linear"
,
"factor"
:
tuple
(
scaling_factors
)},
)
# non-batched RoPE takes only one scaling factor, we create multiple
# instances to simulate the same behavior
non_batched_ropes
:
list
[
RotaryEmbedding
]
=
[]
for
scaling_factor
in
scaling_factors
:
non_batched_ropes
.
append
(
get_rope
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
{
"rope_type"
:
"linear"
,
"factor"
:
(
scaling_factor
,)},
def
get_benchmark
(
head_size
,
rotary_dim
,
is_neox_style
,
device
):
@
triton
.
testing
.
perf_report
(
triton
.
testing
.
Benchmark
(
x_names
=
[
"batch_size"
,
"seq_len"
,
"num_heads"
],
x_vals
=
[
list
(
_
)
for
_
in
configs
],
line_arg
=
"provider"
,
line_vals
=
[
"torch"
,
"flashinfer"
,
"vllm"
],
line_names
=
[
"PyTorch"
,
"FlashInfer"
,
"vLLM"
],
styles
=
[(
"blue"
,
"-"
),
(
"green"
,
"-"
),
(
"red"
,
"-"
)],
ylabel
=
"us"
,
plot_name
=
f
"rope-perf
{
'-neox-style'
if
is_neox_style
else
''
}
"
,
args
=
{},
)
)
def
benchmark
(
batch_size
,
seq_len
,
num_heads
,
provider
):
dtype
=
torch
.
bfloat16
max_position
=
8192
base
=
10000
rope
=
get_rope
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
)
rope
=
rope
.
to
(
dtype
=
dtype
,
device
=
device
)
cos_sin_cache
=
rope
.
cos_sin_cache
.
to
(
dtype
=
torch
.
float
,
device
=
device
)
positions
=
torch
.
randint
(
0
,
max_position
,
(
batch_size
,
seq_len
))
query
=
torch
.
randn
(
batch_size
,
seq_len
,
num_heads
*
head_size
,
dtype
=
dtype
)
positions
=
torch
.
randint
(
0
,
max_position
,
(
batch_size
,
seq_len
),
device
=
device
)
query
=
torch
.
randn
(
(
batch_size
,
seq_len
,
num_heads
*
head_size
),
dtype
=
dtype
,
device
=
device
)
key
=
torch
.
randn_like
(
query
)
# create query offsets for batched RoPE, we concat multiple kv cache
# together and each query needs to find the right kv cache of its type
offset_map
=
torch
.
tensor
(
list
(
accumulate
(
[
0
]
+
[
max_position
*
scaling_factor
*
2
for
scaling_factor
in
scaling_factors
[:
-
1
]
]
)
quantiles
=
[
0.5
,
0.2
,
0.8
]
if
provider
==
"torch"
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
rope
.
forward_native
(
positions
,
query
.
clone
(),
key
.
clone
()),
quantiles
=
quantiles
,
)
elif
provider
==
"flashinfer"
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
torch
.
ops
.
vllm
.
flashinfer_rotary_embedding
(
positions
,
query
.
clone
(),
key
.
clone
(),
head_size
,
cos_sin_cache
,
is_neox_style
,
),
quantiles
=
quantiles
,
)
query_types
=
torch
.
randint
(
0
,
len
(
scaling_factors
),
(
batch_size
,
seq_len
),
device
=
device
else
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
rope
.
forward_cuda
(
positions
,
query
.
clone
(),
key
.
clone
()),
quantiles
=
quantiles
,
)
# map query types to offsets
query_offsets
=
offset_map
[
query_types
]
# the kernel takes flattened offsets
flatten_offsets
=
query_offsets
.
flatten
()
# batched queries of the same type together for non-batched RoPE
queries
=
[
query
[
query_types
==
i
]
for
i
in
range
(
len
(
scaling_factors
))]
keys
=
[
key
[
query_types
==
i
]
for
i
in
range
(
len
(
scaling_factors
))]
packed_qkr
=
zip
(
queries
,
keys
,
non_batched_ropes
)
# synchronize before start timing
torch
.
cuda
.
synchronize
()
with
nvtx
.
annotate
(
"non-batched"
,
color
=
"yellow"
):
for
q
,
k
,
r
in
packed_qkr
:
r
.
forward
(
positions
,
q
,
k
)
torch
.
cuda
.
synchronize
()
with
nvtx
.
annotate
(
"batched"
,
color
=
"green"
):
batched_rope
.
forward
(
positions
,
query
,
key
,
flatten_offsets
)
torch
.
cuda
.
synchronize
()
return
1000
*
ms
,
1000
*
max_ms
,
1000
*
min_ms
return
benchmark
if
__name__
==
"__main__"
:
...
...
@@ -116,17 +95,12 @@ if __name__ == "__main__":
parser
.
add_argument
(
"--device"
,
type
=
str
,
choices
=
[
"cuda:0"
,
"cuda:1"
],
default
=
"cuda:0"
)
parser
.
add_argument
(
"--save-path"
,
type
=
str
,
default
=
"./configs/rope/"
)
args
=
parser
.
parse_args
()
print
(
args
)
benchmark_rope_kernels_multi_lora
(
is_neox_style
=
args
.
is_neox_style
,
batch_size
=
args
.
batch_size
,
seq_len
=
args
.
seq_len
,
num_heads
=
args
.
num_heads
,
head_size
=
args
.
head_size
,
rotary_dim
=
args
.
rotary_dim
,
dtype
=
getattr
(
torch
,
args
.
dtype
),
seed
=
args
.
seed
,
device
=
args
.
device
,
# Get the benchmark function
benchmark
=
get_benchmark
(
args
.
head_size
,
args
.
rotary_dim
,
args
.
is_neox_style
,
args
.
device
)
# Run performance benchmark
benchmark
.
run
(
print_data
=
True
,
save_path
=
args
.
save_path
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment