Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0a7dd237
Unverified
Commit
0a7dd237
authored
Jan 12, 2026
by
Woosuk Kwon
Committed by
GitHub
Jan 12, 2026
Browse files
[Model Runner V2] Add support for M-RoPE (#32143)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
dec28688
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
203 additions
and
7 deletions
+203
-7
vllm/v1/worker/gpu/cudagraph_utils.py
vllm/v1/worker/gpu/cudagraph_utils.py
+6
-1
vllm/v1/worker/gpu/input_batch.py
vllm/v1/worker/gpu/input_batch.py
+21
-2
vllm/v1/worker/gpu/mm/__init__.py
vllm/v1/worker/gpu/mm/__init__.py
+0
-0
vllm/v1/worker/gpu/mm/mrope_utils.py
vllm/v1/worker/gpu/mm/mrope_utils.py
+127
-0
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+49
-4
No files found.
vllm/v1/worker/gpu/cudagraph_utils.py
View file @
0a7dd237
...
...
@@ -25,10 +25,12 @@ class CudaGraphManager:
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
uses_mrope
:
bool
,
device
:
torch
.
device
,
):
self
.
vllm_config
=
vllm_config
self
.
scheduler_config
=
vllm_config
.
scheduler_config
self
.
uses_mrope
=
uses_mrope
self
.
device
=
device
self
.
max_model_len
=
vllm_config
.
model_config
.
max_model_len
...
...
@@ -79,7 +81,10 @@ class CudaGraphManager:
)
->
None
:
num_reqs
=
min
(
num_tokens
,
self
.
max_num_reqs
)
input_ids
=
input_buffers
.
input_ids
[:
num_tokens
]
if
not
self
.
uses_mrope
:
positions
=
input_buffers
.
positions
[:
num_tokens
]
else
:
positions
=
input_buffers
.
mrope_positions
[:,
:
num_tokens
]
attn_metadata
=
prepare_inputs_to_capture
(
num_reqs
,
num_tokens
,
...
...
vllm/v1/worker/gpu/input_batch.py
View file @
0a7dd237
...
...
@@ -31,6 +31,19 @@ class InputBuffers:
)
self
.
seq_lens
=
torch
.
zeros
(
max_num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
)
# NOTE: `mrope_positions` is implemented with one additional dummy
# position on purpose to make it non-contiguous so that it can work
# with torch compile.
# See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923
# NOTE: When M-RoPE is enabled, position ids are 3D regardless of
# the modality of inputs. For text-only inputs, each dimension has
# identical position IDs, making M-RoPE functionally equivalent to
# 1D-RoPE.
# See page 5 of https://arxiv.org/abs/2409.12191
self
.
mrope_positions
=
torch
.
zeros
(
(
3
,
max_num_tokens
+
1
),
dtype
=
torch
.
int64
,
device
=
device
)
@
dataclass
class
InputBatch
:
...
...
@@ -62,6 +75,8 @@ class InputBatch:
input_ids
:
torch
.
Tensor
# [num_tokens_after_padding]
positions
:
torch
.
Tensor
# [3, num_tokens_after_padding]
mrope_positions
:
torch
.
Tensor
# layer_name -> Metadata
attn_metadata
:
dict
[
str
,
Any
]
...
...
@@ -107,8 +122,11 @@ class InputBatch:
input_buffers
.
query_start_loc
[
num_reqs
+
1
:]
=
num_tokens
query_start_loc
=
input_buffers
.
query_start_loc
[:
num_reqs
+
1
]
input_ids
=
input_buffers
.
input_ids
[:
num_tokens
]
positions
=
input_buffers
.
positions
[:
num_tokens
]
input_ids
=
input_buffers
.
input_ids
[:
num_tokens
].
zero_
()
positions
=
input_buffers
.
positions
[:
num_tokens
].
zero_
()
input_buffers
.
mrope_positions
.
zero_
()
mrope_positions
=
input_buffers
.
mrope_positions
[:,
:
num_tokens
]
# attn_metadata = defaultdict(lambda: None)
logits_indices
=
query_start_loc
[
1
:]
-
1
cu_num_logits
=
torch
.
arange
(
num_reqs
+
1
,
device
=
device
,
dtype
=
torch
.
int32
)
...
...
@@ -128,6 +146,7 @@ class InputBatch:
seq_lens
=
seq_lens
,
input_ids
=
input_ids
,
positions
=
positions
,
mrope_positions
=
mrope_positions
,
attn_metadata
=
None
,
# type: ignore
logits_indices
=
logits_indices
,
cu_num_logits
=
cu_num_logits
,
...
...
vllm/v1/worker/gpu/mm/__init__.py
0 → 100644
View file @
0a7dd237
vllm/v1/worker/gpu/mm/mrope_utils.py
0 → 100644
View file @
0a7dd237
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
from
vllm.model_executor.models.interfaces
import
SupportsMRoPE
from
vllm.triton_utils
import
tl
,
triton
from
vllm.v1.worker.gpu.buffer_utils
import
StagedWriteTensor
,
UvaBackedTensor
class MRopeState:
    """Per-request M-RoPE state used to fill the 3-row position buffer.

    Two pieces of state are kept for every request slot:

    * ``prefill_mrope_positions``: the positions returned by the model's
      ``get_mrope_input_positions`` for the prefill tokens. Each request owns
      one row of length ``3 * max_model_len``; the three position dimensions
      are stored back to back, dimension ``i`` starting at offset
      ``i * max_model_len``.
    * ``prefill_mrope_delta``: one delta per request, added to the plain
      token position once the request is past its prefill.
    """

    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        device: torch.device,
    ):
        """Allocate the per-request position table and delta buffer.

        Args:
            max_num_reqs: Number of request slots (rows) to allocate.
            max_model_len: Length reserved for each of the three position
                dimensions inside a request's row.
            device: Device handed to the staged-write position table.
        """
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.device = device

        # NOTE(woosuk): This table can be extremely large (e.g., several GBs)
        # and therefore wastes a lot of CPU memory.
        table_shape = (max_num_reqs, 3 * max_model_len)
        self.prefill_mrope_positions = StagedWriteTensor(
            table_shape,
            dtype=torch.int32,
            device=device,
            uva_instead_of_gpu=True,
        )
        self.prefill_mrope_delta = UvaBackedTensor(max_num_reqs, dtype=torch.int32)

    def init_prefill_mrope_positions(
        self,
        req_idx: int,
        mrope_model: SupportsMRoPE,
        prefill_token_ids: list[int],
        mm_features: list,
    ) -> None:
        """Compute and stage the M-RoPE prefill positions of one request.

        The writes are only staged here; they are flushed by
        ``apply_staged_writes``.

        Args:
            req_idx: Row of the request in the per-request buffers.
            mrope_model: Model providing ``get_mrope_input_positions``.
            prefill_token_ids: Prefill token ids forwarded to the model.
            mm_features: Multimodal features forwarded to the model.
        """
        positions, delta = mrope_model.get_mrope_input_positions(
            prefill_token_ids,
            mm_features,
        )

        table = self.prefill_mrope_positions
        segment_len = self.max_model_len
        # Dimension `dim` goes into the `dim`-th segment of the request's row.
        for dim in range(3):
            table.stage_write(req_idx, dim * segment_len, positions[dim].tolist())

        self.prefill_mrope_delta.np[req_idx] = delta

    def apply_staged_writes(self) -> None:
        """Flush the staged position writes and copy the deltas to UVA."""
        self.prefill_mrope_positions.apply_write()
        self.prefill_mrope_delta.copy_to_uva()

    def prepare_mrope_positions(
        self,
        idx_mapping: torch.Tensor,
        query_start_loc: torch.Tensor,
        prefill_lens: torch.Tensor,
        num_computed_tokens: torch.Tensor,
        mrope_positions: torch.Tensor,
    ) -> None:
        """Fill ``mrope_positions`` in place for the current batch.

        Launches one kernel program per entry of ``idx_mapping``.

        Args:
            idx_mapping: Maps a batch index to its request-state index.
            query_start_loc: Start offsets of each request's query tokens;
                entry ``i + 1`` is read as the end of request ``i``.
            prefill_lens: Prefill length, indexed by request-state index.
            num_computed_tokens: Tokens already computed, indexed by
                request-state index.
            mrope_positions: Output buffer; row ``j`` receives the positions
                of dimension ``j``.
        """
        prefill_table = self.prefill_mrope_positions.gpu
        grid = (idx_mapping.shape[0],)
        _prepare_mrope_positions_kernel[grid](
            mrope_positions,
            mrope_positions.stride(0),
            prefill_table,
            prefill_table.stride(0),
            # "stride1" is the distance between two position dimensions
            # inside a request's row, i.e. the segment length.
            self.max_model_len,
            self.prefill_mrope_delta.gpu,
            idx_mapping,
            query_start_loc,
            prefill_lens,
            num_computed_tokens,
            BLOCK_SIZE=1024,
        )
@triton.jit
def _prepare_mrope_positions_kernel(
    mrope_positions_ptr,
    mrope_positions_stride,
    prefill_mrope_positions_ptr,
    prefill_mrope_positions_stride0,
    prefill_mrope_positions_stride1,
    prefill_mrope_delta_ptr,
    idx_mapping_ptr,
    query_start_loc_ptr,
    prefill_lens_ptr,
    num_computed_tokens_ptr,
    BLOCK_SIZE: tl.constexpr,
):
    """Write the three rows of M-RoPE positions for one request's query tokens.

    One program handles one batch entry (``tl.program_id(0)``).

    * While the request is still in prefill (``num_computed < prefill_len``),
      positions are copied from the pre-computed per-request table.
    * Otherwise every one of the three rows receives
      ``num_computed + offset + mrope_delta``.

    Args:
        mrope_positions_ptr: Output buffer; row ``j`` starts at
            ``j * mrope_positions_stride``.
        mrope_positions_stride: Row stride of the output buffer.
        prefill_mrope_positions_ptr: Pre-computed prefill positions.
        prefill_mrope_positions_stride0: Stride between two requests' rows.
        prefill_mrope_positions_stride1: Stride between two position
            dimensions inside a request's row.
        prefill_mrope_delta_ptr: Per-request delta applied after prefill.
        idx_mapping_ptr: Batch index -> request-state index.
        query_start_loc_ptr: Query start offsets; entry ``batch_idx + 1`` is
            read as the end offset of this request.
        prefill_lens_ptr: Prefill length per request state.
        num_computed_tokens_ptr: Computed-token count per request state.
        BLOCK_SIZE: Number of tokens processed per loop iteration.
    """
    batch_idx = tl.program_id(0)
    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)

    prefill_len = tl.load(prefill_lens_ptr + req_state_idx)
    num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
    is_prefill = num_computed < prefill_len

    query_start = tl.load(query_start_loc_ptr + batch_idx)
    query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
    query_len = query_end - query_start

    mrope_delta = tl.load(prefill_mrope_delta_ptr + req_state_idx)

    # Widen the request index before scaling it by the row stride. The
    # prefill table can hold several GBs of int32 positions, so computing
    # `req_state_idx * stride0` in 32-bit arithmetic (if the index is loaded
    # as a 32-bit integer) could wrap around and address the wrong row.
    prefill_row_offset = req_state_idx.to(tl.int64) * prefill_mrope_positions_stride0

    for i in range(0, query_len, BLOCK_SIZE):
        block = i + tl.arange(0, BLOCK_SIZE)
        # Only the first `query_len` lanes belong to this request.
        mask = block < query_len
        orig_pos = num_computed + block

        for j in tl.static_range(3):
            if is_prefill:
                # Read from pre-computed M-RoPE positions.
                pos = tl.load(
                    prefill_mrope_positions_ptr
                    + prefill_row_offset
                    + j * prefill_mrope_positions_stride1
                    + orig_pos,
                    mask=mask,
                )
            else:
                # Apply M-RoPE delta.
                pos = orig_pos + mrope_delta
            tl.store(
                mrope_positions_ptr + j * mrope_positions_stride + query_start + block,
                pos,
                mask=mask,
            )
vllm/v1/worker/gpu/model_runner.py
View file @
0a7dd237
...
...
@@ -47,6 +47,7 @@ from vllm.v1.worker.gpu.input_batch import (
prepare_pos_seq_lens
,
prepare_prefill_inputs
,
)
from
vllm.v1.worker.gpu.mm.mrope_utils
import
MRopeState
from
vllm.v1.worker.gpu.sample.logprob
import
compute_prompt_logprobs
from
vllm.v1.worker.gpu.sample.metadata
import
SamplingMetadata
from
vllm.v1.worker.gpu.sample.output
import
SamplerOutput
...
...
@@ -94,6 +95,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
.
max_num_reqs
=
self
.
scheduler_config
.
max_num_seqs
self
.
inputs_embeds_size
=
self
.
model_config
.
get_inputs_embeds_size
()
# Multimodal
self
.
uses_mrope
=
self
.
model_config
.
uses_mrope
if
self
.
uses_mrope
:
self
.
mrope_states
=
MRopeState
(
max_num_reqs
=
self
.
max_num_reqs
,
max_model_len
=
self
.
max_model_len
,
device
=
self
.
device
,
)
self
.
use_async_scheduling
=
self
.
scheduler_config
.
async_scheduling
self
.
output_copy_stream
=
torch
.
cuda
.
Stream
(
self
.
device
)
self
.
output_copy_event
=
torch
.
cuda
.
Event
()
...
...
@@ -132,7 +142,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
.
sampler
=
Sampler
(
logprobs_mode
=
self
.
model_config
.
logprobs_mode
)
# CUDA graphs.
self
.
cudagraph_manager
=
CudaGraphManager
(
self
.
vllm_config
,
self
.
device
)
self
.
cudagraph_manager
=
CudaGraphManager
(
self
.
vllm_config
,
self
.
uses_mrope
,
self
.
device
)
# Structured outputs worker.
self
.
structured_outputs_worker
=
StructuredOutputsWorker
(
max_num_logits
=
self
.
max_num_reqs
*
(
self
.
num_speculative_steps
+
1
),
...
...
@@ -268,6 +280,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
dp_size
=
self
.
parallel_config
.
data_parallel_size
num_tokens_across_dp
=
make_num_tokens_across_dp
(
dp_size
,
num_tokens
)
num_sampled_tokens
=
np
.
ones
(
input_batch
.
num_reqs
,
dtype
=
np
.
int32
)
if
not
self
.
uses_mrope
:
positions
=
input_batch
.
positions
else
:
positions
=
input_batch
.
mrope_positions
with
(
self
.
maybe_dummy_run_with_lora
(
self
.
lora_config
,
...
...
@@ -283,7 +299,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
):
hidden_states
=
self
.
model
(
input_ids
=
input_batch
.
input_ids
,
positions
=
input_batch
.
positions
,
positions
=
positions
,
)
sample_hidden_states
=
hidden_states
[
input_batch
.
logits_indices
]
return
hidden_states
,
sample_hidden_states
...
...
@@ -393,8 +409,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
sampling_params
=
new_req_data
.
sampling_params
,
lora_request
=
new_req_data
.
lora_request
,
)
req_index
=
self
.
req_states
.
req_id_to_index
[
req_id
]
# Pre-compute M-RoPE positions for prefill.
if
self
.
uses_mrope
:
self
.
mrope_states
.
init_prefill_mrope_positions
(
req_index
,
self
.
model
,
# type: ignore
new_req_data
.
prefill_token_ids
,
mm_features
=
[],
# TODO
)
self
.
block_tables
.
append_block_ids
(
req_index
,
new_req_data
.
block_ids
,
overwrite
=
True
)
...
...
@@ -411,6 +436,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
.
req_states
.
apply_staged_writes
()
self
.
block_tables
.
apply_staged_writes
()
if
self
.
uses_mrope
:
self
.
mrope_states
.
apply_staged_writes
()
def
prepare_inputs
(
self
,
...
...
@@ -511,6 +538,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
seq_lens
=
self
.
input_buffers
.
seq_lens
[:
num_reqs
]
# Prepare M-RoPE positions.
if
self
.
uses_mrope
:
self
.
mrope_states
.
prepare_mrope_positions
(
idx_mapping
,
query_start_loc
,
self
.
req_states
.
prefill_len
.
gpu
,
self
.
req_states
.
num_computed_tokens
.
gpu
,
self
.
input_buffers
.
mrope_positions
,
)
# Some input token ids are directly read from the last sampled tokens
# and draft tokens. Also, get the logits indices to sample tokens from.
logits_indices
=
combine_sampled_and_draft_tokens
(
...
...
@@ -546,6 +583,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
input_ids
=
self
.
input_buffers
.
input_ids
[:
num_tokens_after_padding
]
positions
=
self
.
input_buffers
.
positions
[:
num_tokens_after_padding
]
mrope_positions
=
self
.
input_buffers
.
mrope_positions
[
:,
:
num_tokens_after_padding
]
return
InputBatch
(
req_ids
=
req_ids
,
num_reqs
=
num_reqs
,
...
...
@@ -561,6 +601,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
seq_lens
=
seq_lens
,
input_ids
=
input_ids
,
positions
=
positions
,
mrope_positions
=
mrope_positions
,
attn_metadata
=
attn_metadata
,
logits_indices
=
logits_indices
,
cu_num_logits
=
cu_num_logits
,
...
...
@@ -889,6 +930,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
else
:
# Run PyTorch model in eager mode.
# TODO(woosuk): Support piecewise CUDA graph.
if
not
self
.
uses_mrope
:
positions
=
input_batch
.
positions
else
:
positions
=
input_batch
.
mrope_positions
with
set_forward_context
(
input_batch
.
attn_metadata
,
self
.
vllm_config
,
...
...
@@ -898,7 +943,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
):
hidden_states
=
self
.
model
(
input_ids
=
input_batch
.
input_ids
,
positions
=
input_batch
.
positions
,
positions
=
positions
,
)
self
.
execute_model_state
=
hidden_states
,
input_batch
,
sampling_metadata
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment