Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5efd6905
Unverified
Commit
5efd6905
authored
Aug 20, 2025
by
Cyrus Leung
Committed by
GitHub
Aug 20, 2025
Browse files
[CLI][Doc] Formalize `--mm-encoder-tp-mode` (#23190)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
b17109be
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
104 additions
and
24 deletions
+104
-24
docs/configuration/optimization.md
docs/configuration/optimization.md
+45
-0
vllm/config/__init__.py
vllm/config/__init__.py
+33
-1
vllm/config/parallel.py
vllm/config/parallel.py
+0
-4
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+22
-13
vllm/model_executor/models/mllama4.py
vllm/model_executor/models/mllama4.py
+2
-2
vllm/model_executor/models/qwen2_5_vl.py
vllm/model_executor/models/qwen2_5_vl.py
+1
-2
vllm/model_executor/models/step3_vl.py
vllm/model_executor/models/step3_vl.py
+1
-2
No files found.
docs/configuration/optimization.md
View file @
5efd6905
...
...
@@ -129,6 +129,51 @@ Data parallelism replicates the entire model across multiple GPU sets and proces
Data parallelism can be combined with the other parallelism strategies and is set by
`data_parallel_size=N`
.
Note that MoE layers will be sharded according to the product of the tensor parallel size and data parallel size.
### Batch-level DP for Multi-Modal Encoders
By default, TP is used to shard the weights of multi-modal encoders just like for language decoders,
in order to reduce the memory and compute load on each GPU.
However, since the size of multi-modal encoders is very small compared to language decoders,
there is relatively little gain from TP. On the other hand, TP incurs significant communication
overhead because of all-reduce being performed after every layer.
Given this, it may be advantageous to instead shard the batched input data using TP, essentially
performing batch-level DP. This has been shown to improve the throughput by around 10% for
`tensor_parallel_size=8`
. For vision encoders that use hardware-unoptimized Conv3D operations,
batch-level DP can provide another 40% increase to throughput compared to regular TP.
Nevertheless, since the weights of the multi-modal encoder are replicated across each TP rank,
there will be a minor increase in memory consumption and may cause OOM if you can barely fit the model already.
You can enable batch-level DP by setting
`mm_encoder_tp_mode="data"`
, for example:
```
python
from
vllm
import
LLM
llm
=
LLM
(
model
=
"Qwen/Qwen2.5-VL-72B-Instruct"
,
# Create two EngineCore instances, one per DP rank
data_parallel_size
=
2
,
# Within each EngineCore instance:
# The vision encoder uses TP=4 (not DP=2) to shard the input data
# The language decoder uses TP=4 to shard the weights as usual
tensor_parallel_size
=
4
,
mm_encoder_tp_mode
=
"data"
,
)
```
!! important
Batch-level DP is not to be confused with API request-level DP
(which is instead controlled by
`data_parallel_size`
).
The availablilty of batch-level DP is based on model implementation.
Currently, the following models support
`mm_encoder_tp_mode="data"`
:
-
Llama4 (
<gh-pr:18368>
)
-
Qwen2.5-VL (
<gh-pr:22742>
)
-
Step3 (
<gh-pr:22697>
)
## Input Processing
### Parallel Processing
...
...
vllm/config/__init__.py
View file @
5efd6905
...
...
@@ -258,6 +258,7 @@ TokenizerMode = Literal["auto", "slow", "mistral", "custom"]
ModelDType
=
Literal
[
"auto"
,
"half"
,
"float16"
,
"bfloat16"
,
"float"
,
"float32"
]
LogprobsMode
=
Literal
[
"raw_logprobs"
,
"raw_logits"
,
"processed_logprobs"
,
"processed_logits"
]
MMEncoderTPMode
=
Literal
[
"weights"
,
"data"
]
@
config
...
...
@@ -438,6 +439,19 @@ class ModelConfig:
`mm_processor_cache_gb * (api_server_count + data_parallel_size)`.
Set to `0` to disable this cache completely (not recommended)."""
mm_encoder_tp_mode
:
MMEncoderTPMode
=
"weights"
"""Indicates how to optimize multi-modal encoder inference using
tensor parallelism (TP).
- `"weights"`: Within the same vLLM engine, split the weights of
each layer across TP ranks. (default TP behavior)
- `"data"`: Within the same vLLM engine, split the batched input data
across TP ranks to process the data in parallel, while hosting
the full weights on each TP rank.
This batch-level DP is not to be confused with API request-level
DP (which is controlled by `--data-parallel-size`).
This is only supported on a per-model basis and falls back to
`"weights"` if the encoder does not support DP."""
override_neuron_config
:
dict
[
str
,
Any
]
=
field
(
default_factory
=
dict
)
"""Initialize non-default neuron config or override default neuron config
that are specific to Neuron devices, this argument will be used to
...
...
@@ -856,8 +870,10 @@ class ModelConfig:
media_io_kwargs
=
self
.
media_io_kwargs
,
mm_processor_kwargs
=
self
.
mm_processor_kwargs
,
mm_processor_cache_gb
=
self
.
mm_processor_cache_gb
,
mm_encoder_tp_mode
=
self
.
mm_encoder_tp_mode
,
interleave_mm_strings
=
self
.
interleave_mm_strings
,
skip_mm_profiling
=
self
.
skip_mm_profiling
)
skip_mm_profiling
=
self
.
skip_mm_profiling
,
)
return
None
...
...
@@ -2547,6 +2563,22 @@ class MultiModalConfig:
Set to `0` to disable this cache completely (not recommended).
"""
mm_encoder_tp_mode
:
MMEncoderTPMode
=
"weights"
"""
Indicates how to optimize multi-modal encoder inference using
tensor parallelism (TP).
- `"weights"`: Within the same vLLM engine, split the weights of
each layer across TP ranks. (default TP behavior)
- `"data"`: Within the same vLLM engine, split the batched input data
across TP ranks to process the data in parallel, while hosting
the full weights on each TP rank.
This batch-level DP is not to be confused with API request-level
DP (which is controlled by `--data-parallel-size`).
This is only supported on a per-model basis and falls back to
`"weights"` if the encoder does not support DP.
"""
interleave_mm_strings
:
bool
=
False
"""
Enable fully interleaved support for multimodal prompts.
...
...
vllm/config/parallel.py
View file @
5efd6905
...
...
@@ -137,10 +137,6 @@ class ParallelConfig:
rank
:
int
=
0
"""Global rank in distributed setup."""
enable_multimodal_encoder_data_parallel
:
bool
=
False
""" Use data parallelism instead of tensor parallelism for vision encoder.
Only support LLama4 for now"""
@
property
def
world_size_across_dp
(
self
)
->
int
:
"""world_size_across_dp is TPxPPxDP, it is the size of the world
...
...
vllm/engine/arg_utils.py
View file @
5efd6905
...
...
@@ -28,12 +28,12 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
DeviceConfig
,
DistributedExecutorBackend
,
GuidedDecodingBackend
,
HfOverrides
,
KVEventsConfig
,
KVTransferConfig
,
LoadConfig
,
LogprobsMode
,
LoRAConfig
,
MambaDType
,
M
odelConfig
,
Model
DType
,
ModelImpl
,
MultiModalConfig
,
ObservabilityConfig
,
ParallelConfig
,
PoolerConfig
,
PrefixCachingHashAlgo
,
RunnerOption
,
SchedulerConfig
,
SchedulerPolicy
,
SpeculativeConfig
,
TaskOption
,
TokenizerMode
,
VllmConfig
,
get_attr_docs
,
get_field
)
LoRAConfig
,
MambaDType
,
M
MEncoderTPMode
,
Model
Config
,
ModelDType
,
ModelImpl
,
MultiModalConfig
,
ObservabilityConfig
,
ParallelConfig
,
PoolerConfig
,
PrefixCachingHashAlgo
,
RunnerOption
,
SchedulerConfig
,
SchedulerPolicy
,
SpeculativeConfig
,
TaskOption
,
TokenizerMode
,
VllmConfig
,
get_attr_docs
,
get_field
)
from
vllm.logger
import
init_logger
from
vllm.platforms
import
CpuArchEnum
,
current_platform
from
vllm.plugins
import
load_general_plugins
...
...
@@ -352,6 +352,7 @@ class EngineArgs:
MultiModalConfig
.
mm_processor_kwargs
disable_mm_preprocessor_cache
:
bool
=
False
# DEPRECATED
mm_processor_cache_gb
:
int
=
MultiModalConfig
.
mm_processor_cache_gb
mm_encoder_tp_mode
:
MMEncoderTPMode
=
MultiModalConfig
.
mm_encoder_tp_mode
skip_mm_profiling
:
bool
=
MultiModalConfig
.
skip_mm_profiling
# LoRA fields
enable_lora
:
bool
=
False
...
...
@@ -434,16 +435,14 @@ class EngineArgs:
use_tqdm_on_load
:
bool
=
LoadConfig
.
use_tqdm_on_load
pt_load_map_location
:
str
=
LoadConfig
.
pt_load_map_location
enable_multimodal_encoder_data_parallel
:
bool
=
\
ParallelConfig
.
enable_multimodal_encoder_data_parallel
# DEPRECATED
enable_multimodal_encoder_data_parallel
:
bool
=
False
logits_processors
:
Optional
[
list
[
Union
[
str
,
type
[
LogitsProcessor
]]]]
=
ModelConfig
.
logits_processors
"""Custom logitproc types"""
async_scheduling
:
bool
=
SchedulerConfig
.
async_scheduling
# DEPRECATED
enable_prompt_adapter
:
bool
=
False
kv_sharing_fast_prefill
:
bool
=
\
CacheConfig
.
kv_sharing_fast_prefill
...
...
@@ -685,7 +684,8 @@ class EngineArgs:
**
parallel_kwargs
[
"worker_extension_cls"
])
parallel_group
.
add_argument
(
"--enable-multimodal-encoder-data-parallel"
,
**
parallel_kwargs
[
"enable_multimodal_encoder_data_parallel"
])
action
=
"store_true"
,
deprecated
=
True
)
# KV cache arguments
cache_kwargs
=
get_kwargs
(
CacheConfig
)
...
...
@@ -735,6 +735,8 @@ class EngineArgs:
multimodal_group
.
add_argument
(
"--disable-mm-preprocessor-cache"
,
action
=
"store_true"
,
deprecated
=
True
)
multimodal_group
.
add_argument
(
"--mm-encoder-tp-mode"
,
**
multimodal_kwargs
[
"mm_encoder_tp_mode"
])
multimodal_group
.
add_argument
(
"--interleave-mm-strings"
,
**
multimodal_kwargs
[
"interleave_mm_strings"
])
...
...
@@ -909,6 +911,14 @@ class EngineArgs:
self
.
mm_processor_cache_gb
=
envs
.
VLLM_MM_INPUT_CACHE_GIB
if
self
.
enable_multimodal_encoder_data_parallel
:
logger
.
warning
(
"--enable-multimodal-encoder-data-parallel` is deprecated "
"and will be removed in v0.13. "
"Please use `--mm-encoder-tp-mode data` instead."
)
self
.
mm_encoder_tp_mode
=
"data"
return
ModelConfig
(
model
=
self
.
model
,
hf_config_path
=
self
.
hf_config_path
,
...
...
@@ -947,6 +957,7 @@ class EngineArgs:
config_format
=
self
.
config_format
,
mm_processor_kwargs
=
self
.
mm_processor_kwargs
,
mm_processor_cache_gb
=
self
.
mm_processor_cache_gb
,
mm_encoder_tp_mode
=
self
.
mm_encoder_tp_mode
,
override_neuron_config
=
self
.
override_neuron_config
,
override_pooler_config
=
self
.
override_pooler_config
,
logits_processor_pattern
=
self
.
logits_processor_pattern
,
...
...
@@ -1258,8 +1269,6 @@ class EngineArgs:
distributed_executor_backend
=
self
.
distributed_executor_backend
,
worker_cls
=
self
.
worker_cls
,
worker_extension_cls
=
self
.
worker_extension_cls
,
enable_multimodal_encoder_data_parallel
=
self
.
enable_multimodal_encoder_data_parallel
,
)
if
model_config
.
is_multimodal_model
:
...
...
vllm/model_executor/models/mllama4.py
View file @
5efd6905
...
...
@@ -728,8 +728,8 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
config
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
multimodal_config
=
vllm_config
.
model_config
.
multimodal_config
self
.
use_data_parallel
=
(
vllm_config
.
parallel_config
.
enable_multimodal_encoder_data_parallel
)
self
.
use_data_parallel
=
multimodal_config
.
mm_encoder_tp_mode
==
"data"
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
multimodal_config
=
multimodal_config
...
...
vllm/model_executor/models/qwen2_5_vl.py
View file @
5efd6905
...
...
@@ -877,8 +877,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
config
:
Qwen2_5_VLConfig
=
vllm_config
.
model_config
.
hf_config
multimodal_config
=
vllm_config
.
model_config
.
multimodal_config
self
.
use_data_parallel
=
(
vllm_config
.
parallel_config
.
enable_multimodal_encoder_data_parallel
)
self
.
use_data_parallel
=
multimodal_config
.
mm_encoder_tp_mode
==
"data"
self
.
config
=
config
self
.
multimodal_config
=
multimodal_config
...
...
vllm/model_executor/models/step3_vl.py
View file @
5efd6905
...
...
@@ -882,8 +882,7 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self
.
config
=
config
self
.
multimodal_config
=
multimodal_config
self
.
use_data_parallel
=
(
vllm_config
.
parallel_config
.
enable_multimodal_encoder_data_parallel
)
self
.
use_data_parallel
=
multimodal_config
.
mm_encoder_tp_mode
==
"data"
if
multimodal_config
.
get_limit_per_prompt
(
"image"
):
self
.
vision_model
=
Step3VisionTransformer
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment