Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3ed46f37
Unverified
Commit
3ed46f37
authored
Mar 14, 2026
by
Santino Ramos
Committed by
GitHub
Mar 14, 2026
Browse files
[Model Runner V2] Add Support for XD-RoPE (#36817)
Signed-off-by:
Santino Ramos
<
elsantinoramos@gmail.com
>
parent
84868e47
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
224 additions
and
29 deletions
+224
-29
vllm/v1/worker/gpu/cudagraph_utils.py
vllm/v1/worker/gpu/cudagraph_utils.py
+3
-0
vllm/v1/worker/gpu/mm/rope.py
vllm/v1/worker/gpu/mm/rope.py
+197
-0
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+2
-1
vllm/v1/worker/gpu/model_states/default.py
vllm/v1/worker/gpu/model_states/default.py
+22
-28
No files found.
vllm/v1/worker/gpu/cudagraph_utils.py
View file @
3ed46f37
...
...
@@ -320,6 +320,9 @@ class ModelCudaGraphManager(CudaGraphManager):
model_inputs
=
{
"input_ids"
:
input_buffers
.
input_ids
[:
num_tokens
],
"positions"
:
input_buffers
.
positions
[:
num_tokens
],
# TODO: Pass intermediate_tensors for PP CUDA graph
# support (https://github.com/vllm-project/vllm/pull/35162).
"intermediate_tensors"
:
None
,
**
model_state
.
prepare_dummy_inputs
(
num_reqs
,
num_tokens
),
}
model_output
=
model
(
**
model_inputs
)
...
...
vllm/v1/worker/gpu/mm/
m
rope
_utils
.py
→
vllm/v1/worker/gpu/mm/rope.py
View file @
3ed46f37
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
cast
import
torch
import
torch.nn
as
nn
from
vllm.model_executor.models.interfaces
import
SupportsMRoPE
from
vllm.config
import
ModelConfig
from
vllm.model_executor.models.interfaces
import
SupportsMRoPE
,
SupportsXDRoPE
from
vllm.triton_utils
import
tl
,
triton
from
vllm.v1.worker.gpu.buffer_utils
import
StagedWriteTensor
,
UvaBackedTensor
class
MRopeState
:
class
RopeState
:
"""Unified state for multi-dimensional RoPE variants (M-RoPE, XD-RoPE).
M-RoPE: 3 dims, uses position delta for decode.
XD-RoPE: 3 or 4 dims, delta is 0 (decode uses orig_pos for all dims).
NOTE: `positions` is implemented with one additional dummy position on
purpose to make it non-contiguous so that it can work with torch compile.
See detailed explanation in
https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923
NOTE: When M-RoPE is enabled, position ids are 3D regardless of the
modality of inputs. For text-only inputs, each dimension has identical
position IDs, making M-RoPE functionally equivalent to 1D-RoPE.
See page 5 of https://arxiv.org/abs/2409.12191
"""
def
__init__
(
self
,
num_dims
:
int
,
has_delta
:
bool
,
max_num_reqs
:
int
,
max_num_tokens
:
int
,
max_model_len
:
int
,
device
:
torch
.
device
,
):
self
.
num_dims
=
num_dims
self
.
has_delta
=
has_delta
self
.
max_num_reqs
=
max_num_reqs
self
.
max_num_tokens
=
max_num_tokens
self
.
max_model_len
=
max_model_len
...
...
@@ -22,47 +46,51 @@ class MRopeState:
# NOTE(woosuk): This tensor can be extremely large (e.g., several GBs)
# wasting a lot of CPU memory.
self
.
prefill_
mrope_
positions
=
StagedWriteTensor
(
(
max_num_reqs
*
3
,
max_model_len
),
self
.
prefill_positions
=
StagedWriteTensor
(
(
max_num_reqs
*
num_dims
,
max_model_len
),
dtype
=
torch
.
int32
,
device
=
device
,
uva_instead_of_gpu
=
True
,
)
self
.
prefill_mrope_delta
=
UvaBackedTensor
(
max_num_reqs
,
dtype
=
torch
.
int32
)
# NOTE: `mrope_positions` is implemented with one additional dummy
# position on purpose to make it non-contiguous so that it can work
# with torch compile.
# See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923
# NOTE: When M-RoPE is enabled, position ids are 3D regardless of
# the modality of inputs. For text-only inputs, each dimension has
# identical position IDs, making M-RoPE functionally equivalent to
# 1D-RoPE.
# See page 5 of https://arxiv.org/abs/2409.12191
self
.
mrope_positions
=
torch
.
zeros
(
(
3
,
max_num_tokens
+
1
),
dtype
=
torch
.
int64
,
device
=
device
self
.
positions
=
torch
.
zeros
(
(
num_dims
,
max_num_tokens
+
1
),
dtype
=
torch
.
int64
,
device
=
device
)
def
init_prefill_mrope_positions
(
# Delta is non-zero for M-RoPE, always 0 for XD-RoPE.
self
.
prefill_delta
=
UvaBackedTensor
(
max_num_reqs
,
dtype
=
torch
.
int32
)
def init_prefill_positions(
    self,
    req_idx: int,
    model: nn.Module,
    prefill_token_ids: list[int],
    mm_features: list,
) -> None:
    """Stage the pre-computed prefill positions for one request.

    Args:
        req_idx: Row index of the request in the per-request state buffers.
        model: The model; it is cast to ``SupportsMRoPE`` when
            ``self.has_delta`` is set and to ``SupportsXDRoPE`` otherwise.
        prefill_token_ids: Token ids of the request's prefill.
        mm_features: Multimodal features forwarded to the model's
            position helper.
    """
    if self.has_delta:
        # M-RoPE path: the model returns both the positions and a delta,
        # which is recorded per request.
        mrope_model = cast(SupportsMRoPE, model)
        prefill_positions, delta = mrope_model.get_mrope_input_positions(
            prefill_token_ids, mm_features
        )
        self.prefill_delta.np[req_idx] = delta
    else:
        # XD-RoPE path: only positions are returned; no delta is written.
        xdrope_model = cast(SupportsXDRoPE, model)
        prefill_positions = xdrope_model.get_xdrope_input_positions(
            prefill_token_ids, mm_features
        )

    # Each request owns `num_dims` consecutive rows of `prefill_positions`,
    # one row per RoPE dimension.
    for i in range(self.num_dims):
        pos = prefill_positions[i].tolist()
        self.prefill_positions.stage_write(self.num_dims * req_idx + i, 0, pos)
def apply_staged_writes(self) -> None:
    """Flush staged prefill positions and, when used, the per-request delta.

    The delta buffer is only copied when ``self.has_delta`` is set; without
    it nothing is ever written to the delta buffer, so the copy is skipped.
    """
    self.prefill_positions.apply_write()
    if self.has_delta:
        self.prefill_delta.copy_to_uva()
def get_positions(self, num_tokens: int) -> torch.Tensor:
    """Return the first ``num_tokens`` columns of the position buffer.

    The result keeps every row (one per RoPE dimension) of
    ``self.positions`` and is a slice of that buffer, not a copy.
    """
    return self.positions[:, :num_tokens]
def
prepare_
mrope_
positions
(
def
prepare_positions
(
self
,
idx_mapping
:
torch
.
Tensor
,
query_start_loc
:
torch
.
Tensor
,
...
...
@@ -70,34 +98,68 @@ class MRopeState:
num_computed_tokens
:
torch
.
Tensor
,
)
->
None
:
num_reqs
=
idx_mapping
.
shape
[
0
]
_prepare_
m
rope_positions_kernel
[(
num_reqs
,)](
self
.
mrope_
positions
,
self
.
mrope_
positions
.
stride
(
0
),
self
.
prefill_
mrope_
positions
.
gpu
,
3
*
self
.
max_model_len
,
_prepare_rope_positions_kernel
[(
num_reqs
,)](
self
.
positions
,
self
.
positions
.
stride
(
0
),
self
.
prefill_positions
.
gpu
,
self
.
num_dims
*
self
.
max_model_len
,
self
.
max_model_len
,
self
.
prefill_
mrope_
delta
.
gpu
,
self
.
prefill_delta
.
gpu
,
idx_mapping
,
query_start_loc
,
prefill_lens
,
num_computed_tokens
,
BLOCK_SIZE
=
1024
,
NUM_DIMS
=
self
.
num_dims
,
)
def get_rope_state(
    model_config: ModelConfig,
    model: nn.Module,
    max_num_reqs: int,
    max_num_tokens: int,
    max_model_len: int,
    device: torch.device,
) -> RopeState | None:
    """Create a RopeState if the model uses multi-dimensional RoPE.

    Returns:
        * A 3-dimensional state with delta enabled when
          ``model_config.uses_mrope`` is set.
        * A state with ``model_config.uses_xdrope_dim`` dimensions and delta
          disabled when that value is positive.
        * ``None`` otherwise.

    In both non-``None`` cases the model is asserted to implement the
    matching interface (``SupportsMRoPE`` / ``SupportsXDRoPE``).
    """
    if model_config.uses_mrope:
        assert isinstance(model, SupportsMRoPE)
        return RopeState(
            num_dims=3,
            has_delta=True,
            max_num_reqs=max_num_reqs,
            max_num_tokens=max_num_tokens,
            max_model_len=max_model_len,
            device=device,
        )

    if model_config.uses_xdrope_dim > 0:
        assert isinstance(model, SupportsXDRoPE)
        return RopeState(
            num_dims=model_config.uses_xdrope_dim,
            has_delta=False,
            max_num_reqs=max_num_reqs,
            max_num_tokens=max_num_tokens,
            max_model_len=max_model_len,
            device=device,
        )

    return None
@
triton
.
jit
def
_prepare_
m
rope_positions_kernel
(
mrope_
positions_ptr
,
mrope_
positions_stride
,
prefill_
mrope_
positions_ptr
,
prefill_
mrope_
positions_stride0
,
prefill_
mrope_
positions_stride1
,
prefill_
mrope_
delta_ptr
,
def
_prepare_rope_positions_kernel
(
positions_ptr
,
positions_stride
,
prefill_positions_ptr
,
prefill_positions_stride0
,
prefill_positions_stride1
,
prefill_delta_ptr
,
idx_mapping_ptr
,
query_start_loc_ptr
,
prefill_lens_ptr
,
num_computed_tokens_ptr
,
BLOCK_SIZE
:
tl
.
constexpr
,
NUM_DIMS
:
tl
.
constexpr
,
):
batch_idx
=
tl
.
program_id
(
0
)
req_state_idx
=
tl
.
load
(
idx_mapping_ptr
+
batch_idx
)
...
...
@@ -110,27 +172,26 @@ def _prepare_mrope_positions_kernel(
query_end
=
tl
.
load
(
query_start_loc_ptr
+
batch_idx
+
1
)
query_len
=
query_end
-
query_start
mrope_delta
=
tl
.
load
(
prefill_mrope_delta_ptr
+
req_state_idx
)
delta
=
tl
.
load
(
prefill_delta_ptr
+
req_state_idx
)
for
i
in
range
(
0
,
query_len
,
BLOCK_SIZE
):
block
=
i
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
block
<
query_len
orig_pos
=
num_computed
+
block
for
j
in
tl
.
static_range
(
3
):
for
j
in
tl
.
static_range
(
NUM_DIMS
):
if
is_prefill
:
# Read from pre-computed M-RoPE positions.
pos
=
tl
.
load
(
prefill_
mrope_
positions_ptr
+
req_state_idx
*
prefill_
mrope_
positions_stride0
+
j
*
prefill_
mrope_
positions_stride1
prefill_positions_ptr
+
req_state_idx
*
prefill_positions_stride0
+
j
*
prefill_positions_stride1
+
orig_pos
,
mask
=
mask
,
)
else
:
# Apply M-RoPE delta.
pos
=
orig_pos
+
mrope_delta
pos
=
orig_pos
+
delta
tl
.
store
(
mrope_
positions_ptr
+
j
*
mrope_
positions_stride
+
query_start
+
block
,
positions_ptr
+
j
*
positions_stride
+
query_start
+
block
,
pos
,
mask
=
mask
,
)
vllm/v1/worker/gpu/model_runner.py
View file @
3ed46f37
...
...
@@ -992,6 +992,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
"input_ids"
:
input_batch
.
input_ids
,
"positions"
:
input_batch
.
positions
,
"inputs_embeds"
:
inputs_embeds
,
"intermediate_tensors"
:
intermediate_tensors
,
# NOTE: Values returned by `prepare_inputs` will override the default
# values above.
**
self
.
model_state
.
prepare_inputs
(
input_batch
,
self
.
req_states
),
...
...
@@ -1000,7 +1001,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Update for non-first PP ranks.
model_inputs
[
"input_ids"
]
=
None
model_inputs
[
"inputs_embeds"
]
=
None
model_inputs
[
"
intermediate_tensors
"
]
=
intermediate_tensors
assert
intermediate_tensors
is
not
None
# Run model.
if
batch_desc
.
cg_mode
==
CUDAGraphMode
.
FULL
:
...
...
vllm/v1/worker/gpu/model_states/default.py
View file @
3ed46f37
...
...
@@ -13,7 +13,7 @@ from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
from
vllm.v1.worker.gpu.input_batch
import
InputBatch
from
vllm.v1.worker.gpu.mm.encoder_cache
import
EncoderCache
from
vllm.v1.worker.gpu.mm.encoder_runner
import
EncoderRunner
from
vllm.v1.worker.gpu.mm.
m
rope
_utils
import
MR
ope
S
tate
from
vllm.v1.worker.gpu.mm.rope
import
get_r
ope
_s
tate
from
vllm.v1.worker.gpu.model_states.interface
import
ModelState
from
vllm.v1.worker.gpu.states
import
RequestState
from
vllm.v1.worker.utils
import
AttentionGroup
...
...
@@ -52,29 +52,28 @@ class DefaultModelState(ModelState):
device
=
self
.
device
,
)
self
.
uses_mrope
=
self
.
model_config
.
uses_mrope
if
self
.
uses_mrope
:
self
.
mrope_state
=
MRopeState
(
max_num_reqs
=
self
.
max_num_reqs
,
max_num_tokens
=
self
.
max_num_tokens
,
max_model_len
=
self
.
max_model_len
,
device
=
self
.
device
,
)
self
.
rope_state
=
get_rope_state
(
self
.
model_config
,
model
,
max_num_reqs
=
self
.
max_num_reqs
,
max_num_tokens
=
self
.
max_num_tokens
,
max_model_len
=
self
.
max_model_len
,
device
=
self
.
device
,
)
def add_request(self, req_index: int, new_req_data: NewRequestData) -> None:
    """Register a new request with the RoPE state, if there is one.

    When ``self.rope_state`` is ``None`` this is a no-op. Otherwise the
    request's prefill token ids (asserted to be present) and multimodal
    features are passed on so its prefill positions can be staged.
    """
    if self.rope_state is not None:
        assert new_req_data.prefill_token_ids is not None
        self.rope_state.init_prefill_positions(
            req_index,
            self.model,
            new_req_data.prefill_token_ids,
            mm_features=new_req_data.mm_features,
        )
def apply_staged_writes(self) -> None:
    """Forward the flush of staged writes to the RoPE state, if present."""
    if self.rope_state is not None:
        self.rope_state.apply_staged_writes()
def
get_mm_embeddings
(
self
,
...
...
@@ -109,31 +108,26 @@ class DefaultModelState(ModelState):
def prepare_inputs(
    self, input_batch: InputBatch, req_states: RequestState
) -> dict[str, torch.Tensor | None]:
    """Build the model-input overrides for this batch.

    Returns an empty dict when there is no RoPE state, so the caller's
    default 1D positions are used. Otherwise the multi-dimensional
    positions are refreshed for the batch and returned under
    ``"positions"``, sliced to ``input_batch.num_tokens_after_padding``.
    """
    if self.rope_state is None:
        return {}  # Common case (1D positions).

    # Refresh the multi-dimensional RoPE positions for the current batch.
    self.rope_state.prepare_positions(
        input_batch.idx_mapping,
        input_batch.query_start_loc,
        req_states.prefill_len.gpu,
        req_states.num_computed_tokens.gpu,
    )
    positions = self.rope_state.get_positions(input_batch.num_tokens_after_padding)
    return {"positions": positions}
def prepare_dummy_inputs(self, num_reqs: int, num_tokens: int) -> dict[str, Any]:
    """Build placeholder model inputs for ``num_tokens`` tokens.

    ``"inputs_embeds"`` is included when ``self.supports_mm_inputs`` is set,
    and ``"positions"`` when a RoPE state exists. ``num_reqs`` is accepted
    but not used in this body.
    """
    model_inputs = {}
    if self.supports_mm_inputs:
        inputs_embeds = self.encoder_runner.inputs_embeds[:num_tokens]
        model_inputs["inputs_embeds"] = inputs_embeds
    if self.rope_state is not None:
        model_inputs["positions"] = self.rope_state.get_positions(num_tokens)
    return model_inputs
def
prepare_attn
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment