Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9dc40d38
Commit
9dc40d38
authored
Mar 24, 2026
by
wangmin6
Browse files
Merge branch 'gy_v015-1dmrope' into 'v0.15.1-dev'
Gy v015 1dmrope See merge request dcutoolkit/deeplearing/vllm!530
parents
06185134
c07d9253
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
83 additions
and
16 deletions
+83
-16
vllm/envs.py
vllm/envs.py
+6
-0
vllm/v1/core/encoder_cache_manager.py
vllm/v1/core/encoder_cache_manager.py
+18
-3
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+59
-13
No files found.
vllm/envs.py
View file @
9dc40d38
...
@@ -157,6 +157,8 @@ if TYPE_CHECKING:
...
@@ -157,6 +157,8 @@ if TYPE_CHECKING:
VLLM_MXFP4_USE_MARLIN
:
bool
|
None
=
None
VLLM_MXFP4_USE_MARLIN
:
bool
|
None
=
None
VLLM_DEEPEPLL_NVFP4_DISPATCH
:
bool
=
False
VLLM_DEEPEPLL_NVFP4_DISPATCH
:
bool
=
False
VLLM_V1_USE_OUTLINES_CACHE
:
bool
=
False
VLLM_V1_USE_OUTLINES_CACHE
:
bool
=
False
VLLM_1D_MROPE
:
bool
=
False
VLLM_ENCODER_CACHE_SIZE
:
int
|
None
=
None
VLLM_TPU_BUCKET_PADDING_GAP
:
int
=
0
VLLM_TPU_BUCKET_PADDING_GAP
:
int
=
0
VLLM_TPU_MOST_MODEL_LEN
:
int
|
None
=
None
VLLM_TPU_MOST_MODEL_LEN
:
int
|
None
=
None
VLLM_TPU_USING_PATHWAYS
:
bool
=
False
VLLM_TPU_USING_PATHWAYS
:
bool
=
False
...
@@ -1925,6 +1927,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -1925,6 +1927,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_MOE_W16A16_TRITON"
:
"VLLM_USE_MOE_W16A16_TRITON"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_MOE_W16A16_TRITON"
,
"0"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_MOE_W16A16_TRITON"
,
"0"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
"VLLM_1D_MROPE"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_1D_MROPE"
,
"0"
).
lower
()
in
(
"true"
,
"1"
)),
"VLLM_ENCODER_CACHE_SIZE"
:
lambda
:
maybe_convert_int
(
os
.
environ
.
get
(
"VLLM_ENCODER_CACHE_SIZE"
,
None
)),
#If set to 1/True, enable the V1 fast token-id copy path in InputBatch.
#If set to 1/True, enable the V1 fast token-id copy path in InputBatch.
"VLLM_V1_FAST_TOKEN_ID_COPY"
:
"VLLM_V1_FAST_TOKEN_ID_COPY"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_V1_FAST_TOKEN_ID_COPY"
,
"False"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_V1_FAST_TOKEN_ID_COPY"
,
"False"
).
lower
()
in
...
...
vllm/v1/core/encoder_cache_manager.py
View file @
9dc40d38
...
@@ -5,6 +5,7 @@ from collections import OrderedDict
...
@@ -5,6 +5,7 @@ from collections import OrderedDict
from
collections.abc
import
Mapping
from
collections.abc
import
Mapping
from
typing
import
TYPE_CHECKING
from
typing
import
TYPE_CHECKING
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
MultiModalRegistry
from
vllm.multimodal
import
MultiModalRegistry
from
vllm.v1.request
import
Request
from
vllm.v1.request
import
Request
...
@@ -15,6 +16,16 @@ if TYPE_CHECKING:
...
@@ -15,6 +16,16 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
def
_get_encoder_cache_size_override
()
->
int
|
None
:
encoder_cache_size_override
=
envs
.
VLLM_ENCODER_CACHE_SIZE
if
encoder_cache_size_override
is
not
None
:
logger
.
info_once
(
"Using VLLM_ENCODER_CACHE_SIZE=%d to override encoder cache size."
,
encoder_cache_size_override
,
)
return
encoder_cache_size_override
class
EncoderCacheManager
:
class
EncoderCacheManager
:
"""Manages caching of encoder outputs for multimodal models in vLLM V1.
"""Manages caching of encoder outputs for multimodal models in vLLM V1.
...
@@ -342,9 +353,13 @@ def compute_mm_encoder_budget(
...
@@ -342,9 +353,13 @@ def compute_mm_encoder_budget(
encoder_compute_budget
=
max
(
encoder_compute_budget
=
max
(
scheduler_config
.
max_num_encoder_input_tokens
,
max_tokens_per_mm_item
scheduler_config
.
max_num_encoder_input_tokens
,
max_tokens_per_mm_item
)
)
encoder_cache_size
=
max
(
encoder_cache_size_override
=
_get_encoder_cache_size_override
()
scheduler_config
.
encoder_cache_size
,
max_tokens_per_mm_item
if
encoder_cache_size_override
is
not
None
:
)
encoder_cache_size
=
encoder_cache_size_override
else
:
encoder_cache_size
=
max
(
scheduler_config
.
encoder_cache_size
,
max_tokens_per_mm_item
)
return
encoder_compute_budget
,
encoder_cache_size
return
encoder_compute_budget
,
encoder_cache_size
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
9dc40d38
...
@@ -398,6 +398,7 @@ class GPUModelRunner(
...
@@ -398,6 +398,7 @@ class GPUModelRunner(
# Multi-modal data support
# Multi-modal data support
self
.
mm_registry
=
MULTIMODAL_REGISTRY
self
.
mm_registry
=
MULTIMODAL_REGISTRY
self
.
uses_mrope
=
model_config
.
uses_mrope
self
.
uses_mrope
=
model_config
.
uses_mrope
self
.
use_1d_mrope
=
self
.
uses_mrope
and
envs
.
VLLM_1D_MROPE
self
.
uses_xdrope_dim
=
model_config
.
uses_xdrope_dim
self
.
uses_xdrope_dim
=
model_config
.
uses_xdrope_dim
self
.
supports_mm_inputs
=
self
.
mm_registry
.
supports_multimodal_inputs
(
self
.
supports_mm_inputs
=
self
.
mm_registry
.
supports_multimodal_inputs
(
model_config
model_config
...
@@ -613,9 +614,14 @@ class GPUModelRunner(
...
@@ -613,9 +614,14 @@ class GPUModelRunner(
# identical position IDs, making M-RoPE functionally equivalent to
# identical position IDs, making M-RoPE functionally equivalent to
# 1D-RoPE.
# 1D-RoPE.
# See page 5 of https://arxiv.org/abs/2409.12191
# See page 5 of https://arxiv.org/abs/2409.12191
self
.
mrope_positions
=
self
.
_make_buffer
(
if
self
.
use_1d_mrope
:
(
3
,
self
.
max_num_tokens
+
1
),
dtype
=
torch
.
int64
self
.
mrope_positions
=
self
.
_make_buffer
(
)
3
*
(
self
.
max_num_tokens
+
1
),
dtype
=
torch
.
int64
)
else
:
self
.
mrope_positions
=
self
.
_make_buffer
(
(
3
,
self
.
max_num_tokens
+
1
),
dtype
=
torch
.
int64
)
# Only relevant for models using XD-RoPE (e.g, HunYuan-VL)
# Only relevant for models using XD-RoPE (e.g, HunYuan-VL)
if
self
.
uses_xdrope_dim
>
0
:
if
self
.
uses_xdrope_dim
>
0
:
...
@@ -771,12 +777,18 @@ class GPUModelRunner(
...
@@ -771,12 +777,18 @@ class GPUModelRunner(
def
_get_positions
(
self
,
num_tokens
:
Any
):
def
_get_positions
(
self
,
num_tokens
:
Any
):
if
isinstance
(
num_tokens
,
int
):
if
isinstance
(
num_tokens
,
int
):
if
self
.
uses_mrope
:
if
self
.
uses_mrope
:
if
self
.
use_1d_mrope
:
return
self
.
mrope_positions
.
gpu
[:
3
*
num_tokens
].
view
(
num_tokens
,
3
).
T
return
self
.
mrope_positions
.
gpu
[:,
:
num_tokens
]
return
self
.
mrope_positions
.
gpu
[:,
:
num_tokens
]
if
self
.
uses_xdrope_dim
>
0
:
if
self
.
uses_xdrope_dim
>
0
:
return
self
.
xdrope_positions
.
gpu
[:,
:
num_tokens
]
return
self
.
xdrope_positions
.
gpu
[:,
:
num_tokens
]
return
self
.
positions
.
gpu
[:
num_tokens
]
return
self
.
positions
.
gpu
[:
num_tokens
]
else
:
else
:
if
self
.
uses_mrope
:
if
self
.
uses_mrope
:
if
self
.
use_1d_mrope
:
return
self
.
mrope_positions
.
gpu
.
view
(
-
1
,
3
)[
num_tokens
].
T
return
self
.
mrope_positions
.
gpu
[:,
num_tokens
]
return
self
.
mrope_positions
.
gpu
[:,
num_tokens
]
if
self
.
uses_xdrope_dim
>
0
:
if
self
.
uses_xdrope_dim
>
0
:
return
self
.
xdrope_positions
.
gpu
[:,
num_tokens
]
return
self
.
xdrope_positions
.
gpu
[:,
num_tokens
]
...
@@ -797,6 +809,13 @@ class GPUModelRunner(
...
@@ -797,6 +809,13 @@ class GPUModelRunner(
def
_copy_mrope_positions_to_gpu
(
self
,
num_tokens
:
int
)
->
None
:
def
_copy_mrope_positions_to_gpu
(
self
,
num_tokens
:
int
)
->
None
:
if
not
self
.
uses_mrope
:
if
not
self
.
uses_mrope
:
return
return
if
self
.
use_1d_mrope
:
num_values
=
3
*
num_tokens
self
.
mrope_positions
.
gpu
[:
num_values
].
copy_
(
self
.
mrope_positions
.
cpu
[:
num_values
],
non_blocking
=
True
,
)
return
self
.
mrope_positions
.
gpu
[:,
:
num_tokens
].
copy_
(
self
.
mrope_positions
.
gpu
[:,
:
num_tokens
].
copy_
(
self
.
mrope_positions
.
cpu
[:,
:
num_tokens
],
self
.
mrope_positions
.
cpu
[:,
:
num_tokens
],
non_blocking
=
True
,
non_blocking
=
True
,
...
@@ -2111,6 +2130,13 @@ class GPUModelRunner(
...
@@ -2111,6 +2130,13 @@ class GPUModelRunner(
def
_calc_mrope_positions
(
self
,
scheduler_output
:
"SchedulerOutput"
):
def
_calc_mrope_positions
(
self
,
scheduler_output
:
"SchedulerOutput"
):
mrope_pos_ptr
=
0
mrope_pos_ptr
=
0
if
self
.
use_1d_mrope
:
mrope_positions_token_major
=
self
.
mrope_positions
.
cpu
.
view
(
self
.
max_num_tokens
+
1
,
3
)
mrope_positions_token_major_np
=
self
.
mrope_positions
.
np
.
reshape
(
self
.
max_num_tokens
+
1
,
3
)
for
index
,
req_id
in
enumerate
(
self
.
input_batch
.
req_ids
):
for
index
,
req_id
in
enumerate
(
self
.
input_batch
.
req_ids
):
req
=
self
.
requests
[
req_id
]
req
=
self
.
requests
[
req_id
]
assert
req
.
mrope_positions
is
not
None
assert
req
.
mrope_positions
is
not
None
...
@@ -2137,9 +2163,14 @@ class GPUModelRunner(
...
@@ -2137,9 +2163,14 @@ class GPUModelRunner(
src_start
=
num_computed_tokens
src_start
=
num_computed_tokens
src_end
=
num_computed_tokens
+
prompt_part_len
src_end
=
num_computed_tokens
+
prompt_part_len
self
.
mrope_positions
.
cpu
[:,
dst_start
:
dst_end
]
=
req
.
mrope_positions
[
if
self
.
use_1d_mrope
:
:,
src_start
:
src_end
mrope_positions_token_major
[
dst_start
:
dst_end
,
:].
copy_
(
]
req
.
mrope_positions
[:,
src_start
:
src_end
].
transpose
(
0
,
1
)
)
else
:
self
.
mrope_positions
.
cpu
[:,
dst_start
:
dst_end
]
=
(
req
.
mrope_positions
[:,
src_start
:
src_end
]
)
mrope_pos_ptr
+=
prompt_part_len
mrope_pos_ptr
+=
prompt_part_len
if
completion_part_len
>
0
:
if
completion_part_len
>
0
:
...
@@ -2148,13 +2179,28 @@ class GPUModelRunner(
...
@@ -2148,13 +2179,28 @@ class GPUModelRunner(
dst_end
=
mrope_pos_ptr
+
completion_part_len
dst_end
=
mrope_pos_ptr
+
completion_part_len
assert
req
.
mrope_position_delta
is
not
None
assert
req
.
mrope_position_delta
is
not
None
MRotaryEmbedding
.
get_next_input_positions_tensor
(
if
self
.
use_1d_mrope
:
out
=
self
.
mrope_positions
.
np
,
values
=
np
.
arange
(
out_offset
=
dst_start
,
req
.
mrope_position_delta
mrope_position_delta
=
req
.
mrope_position_delta
,
+
num_computed_tokens
context_len
=
num_computed_tokens
+
prompt_part_len
,
+
prompt_part_len
,
num_new_tokens
=
completion_part_len
,
req
.
mrope_position_delta
)
+
num_computed_tokens
+
prompt_part_len
+
completion_part_len
,
dtype
=
mrope_positions_token_major_np
.
dtype
,
)
mrope_positions_token_major_np
[
dst_start
:
dst_end
,
:]
=
values
[
:,
None
]
else
:
MRotaryEmbedding
.
get_next_input_positions_tensor
(
out
=
self
.
mrope_positions
.
np
,
out_offset
=
dst_start
,
mrope_position_delta
=
req
.
mrope_position_delta
,
context_len
=
num_computed_tokens
+
prompt_part_len
,
num_new_tokens
=
completion_part_len
,
)
mrope_pos_ptr
+=
completion_part_len
mrope_pos_ptr
+=
completion_part_len
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment