Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9a3b8832
Unverified
Commit
9a3b8832
authored
Jun 24, 2025
by
Vadim Gimpelson
Committed by
GitHub
Jun 23, 2025
Browse files
[PERF] Speedup of MRoPE prepare inputs (#19939)
Signed-off-by:
Vadim Gimpelson
<
vadim.gimpelson@centml.ai
>
parent
3014c920
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
17 additions
and
18 deletions
+17
-18
vllm/model_executor/layers/rotary_embedding.py
vllm/model_executor/layers/rotary_embedding.py
+9
-9
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+8
-9
No files found.
vllm/model_executor/layers/rotary_embedding.py
View file @
9a3b8832
...
...
@@ -26,6 +26,7 @@
import
math
from
typing
import
Any
,
Optional
,
Union
import
numpy
as
np
import
torch
import
torch.nn
as
nn
from
transformers
import
PretrainedConfig
...
...
@@ -1458,15 +1459,14 @@ class MRotaryEmbedding(RotaryEmbedding):
]
@
staticmethod
def get_next_input_positions_tensor(
    mrope_position_delta: int,
    context_len: int,
    seq_len: int,
) -> torch.Tensor:
    """Build the position ids for the tokens between two sequence lengths.

    Args:
        mrope_position_delta: Offset added to both ends of the range.
        context_len: Start of the range before the offset is applied.
        seq_len: Exclusive end of the range before the offset is applied.

    Returns:
        A tensor of shape ``(3, seq_len - context_len)`` in which every row
        holds the same consecutive integers, starting at
        ``mrope_position_delta + context_len``.
    """
    first_position = mrope_position_delta + context_len
    end_position = mrope_position_delta + seq_len
    positions = torch.arange(first_position, end_position)
    # expand() broadcasts the single row to three rows without copying.
    return positions.expand(3, -1)
def get_next_input_positions_tensor(out: np.ndarray, out_offset: int,
                                    mrope_position_delta: int,
                                    context_len: int, num_new_tokens: int):
    """Write the position ids of the next tokens into ``out`` in place.

    Args:
        out: Destination array; its second axis is indexed by token slot.
        out_offset: First column of ``out`` to overwrite.
        mrope_position_delta: Offset added to the starting position.
        context_len: Starting position before the offset is applied.
        num_new_tokens: Number of consecutive positions to write.

    Nothing is returned. Columns ``out_offset`` up to (but excluding)
    ``out_offset + num_new_tokens`` of every row of ``out`` receive the
    consecutive integers starting at ``mrope_position_delta + context_len``.
    """
    first_position = mrope_position_delta + context_len
    target_columns = slice(out_offset, out_offset + num_new_tokens)
    # Generate the range directly in the destination dtype; the single row
    # is broadcast across all rows of `out` by the assignment.
    out[:, target_columns] = np.arange(first_position,
                                       first_position + num_new_tokens,
                                       dtype=out.dtype)
@
classmethod
def
omni_get_updates_use_audio_in_video
(
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
9a3b8832
...
...
@@ -262,6 +262,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
dtype
=
torch
.
int64
,
device
=
"cpu"
,
pin_memory
=
self
.
pin_memory
)
self
.
mrope_positions_np
=
self
.
mrope_positions_cpu
.
numpy
()
# Only relevant for models using ALiBi (e.g., MPT)
self
.
use_alibi
=
check_use_alibi
(
model_config
)
...
...
@@ -889,15 +890,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
dst_start
=
mrope_pos_ptr
dst_end
=
mrope_pos_ptr
+
completion_part_len
self
.
mrope_positions_cpu
[:,
dst_start
:
dst_end
]
=
\
MRotaryEmbedding
.
get_next_input_positions_tensor
(
req
.
mrope_position_delta
,
context_len
=
num_computed_tokens
+
prompt_part_len
,
seq_len
=
num_computed_tokens
+
prompt_part_len
+
completion_part_len
,
)
MRotaryEmbedding
.
get_next_input_positions_tensor
(
out
=
self
.
mrope_positions_np
,
out_offset
=
dst_start
,
mrope_position_delta
=
req
.
mrope_position_delta
,
context_len
=
num_computed_tokens
+
prompt_part_len
,
num_new_tokens
=
completion_part_len
,
)
mrope_pos_ptr
+=
completion_part_len
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment