Unverified commit fe8d7b6f, authored by Cyrus Leung, committed by GitHub
Browse files

[Model] Interface to enable batch-level DP support (#23733)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent 16dc4052
...@@ -168,8 +168,11 @@ llm = LLM( ...@@ -168,8 +168,11 @@ llm = LLM(
Batch-level DP is not to be confused with API request-level DP Batch-level DP is not to be confused with API request-level DP
(which is instead controlled by `data_parallel_size`). (which is instead controlled by `data_parallel_size`).
The availability of batch-level DP is based on model implementation. Batch-level DP needs to be implemented on a per-model basis,
Currently, the following models support `mm_encoder_tp_mode="data"`: and enabled by setting `supports_encoder_tp_data = True` in the model class.
Regardless, you need to set `mm_encoder_tp_mode="data"` in engine arguments to use this feature.
Known supported models:
- Llama4 (<gh-pr:18368>) - Llama4 (<gh-pr:18368>)
- MiniCPM-V-4 (<gh-pr:23327>) - MiniCPM-V-4 (<gh-pr:23327>)
......
...@@ -872,6 +872,13 @@ class ModelConfig: ...@@ -872,6 +872,13 @@ class ModelConfig:
def _init_multimodal_config(self) -> Optional["MultiModalConfig"]: def _init_multimodal_config(self) -> Optional["MultiModalConfig"]:
if self._model_info.supports_multimodal: if self._model_info.supports_multimodal:
if (self.mm_encoder_tp_mode == "data" and
not self._model_info.supports_multimodal_encoder_tp_data):
logger.warning_once(
"This model does not support `--mm-encoder-tp-mode data`. "
"Falling back to `--mm-encoder-tp-mode weights`.")
self.mm_encoder_tp_mode = "weights"
return MultiModalConfig( return MultiModalConfig(
limit_per_prompt=self.limit_mm_per_prompt, limit_per_prompt=self.limit_mm_per_prompt,
media_io_kwargs=self.media_io_kwargs, media_io_kwargs=self.media_io_kwargs,
......
...@@ -52,6 +52,12 @@ class SupportsMultiModal(Protocol): ...@@ -52,6 +52,12 @@ class SupportsMultiModal(Protocol):
MRO of your model class. MRO of your model class.
""" """
supports_encoder_tp_data: ClassVar[bool] = False
"""
A flag that indicates whether this model supports
`multimodal_config.mm_encoder_tp_mode="data"`.
"""
@classmethod @classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
""" """
...@@ -137,6 +143,11 @@ def supports_multimodal( ...@@ -137,6 +143,11 @@ def supports_multimodal(
return getattr(model, "supports_multimodal", False) return getattr(model, "supports_multimodal", False)
def supports_multimodal_encoder_tp_data(
        model: Union[type[object], object]) -> bool:
    """Report whether ``model`` opts in to ``mm_encoder_tp_mode="data"``.

    Args:
        model: A model class or a model instance.

    Returns:
        The value of the ``supports_encoder_tp_data`` attribute on
        ``model``, or ``False`` when the attribute is not defined.
    """
    # Attribute lookup works for both classes and instances, so callers can
    # query the flag without instantiating the model.
    return getattr(model, "supports_encoder_tp_data", False)
@runtime_checkable @runtime_checkable
class SupportsMultiModalWithRawInput(SupportsMultiModal, Protocol): class SupportsMultiModalWithRawInput(SupportsMultiModal, Protocol):
"""The interface required for all multi-modal models.""" """The interface required for all multi-modal models."""
......
...@@ -1521,6 +1521,8 @@ class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA): ...@@ -1521,6 +1521,8 @@ class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA):
], ],
} }
supports_encoder_tp_data = True
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix) super().__init__(vllm_config=vllm_config, prefix=prefix)
assert self.version == (4, 0) assert self.version == (4, 0)
......
...@@ -716,6 +716,8 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -716,6 +716,8 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
"gate_up_proj": ["gate_proj", "up_proj"], "gate_up_proj": ["gate_proj", "up_proj"],
} }
supports_encoder_tp_data = True
@classmethod @classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
if modality.startswith("image"): if modality.startswith("image"):
......
...@@ -868,6 +868,8 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -868,6 +868,8 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
"model.": "language_model.model.", "model.": "language_model.model.",
}) })
supports_encoder_tp_data = True
@classmethod @classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
if modality.startswith("image"): if modality.startswith("image"):
......
...@@ -27,8 +27,10 @@ from vllm.transformers_utils.dynamic_module import ( ...@@ -27,8 +27,10 @@ from vllm.transformers_utils.dynamic_module import (
from .interfaces import (has_inner_state, has_noops, is_attention_free, from .interfaces import (has_inner_state, has_noops, is_attention_free,
is_hybrid, supports_cross_encoding, is_hybrid, supports_cross_encoding,
supports_multimodal, supports_multimodal_raw_input, supports_multimodal,
supports_pp, supports_transcription, supports_v0_only) supports_multimodal_encoder_tp_data,
supports_multimodal_raw_input, supports_pp,
supports_transcription, supports_v0_only)
from .interfaces_base import (get_default_pooling_type, is_pooling_model, from .interfaces_base import (get_default_pooling_type, is_pooling_model,
is_text_generation_model) is_text_generation_model)
...@@ -324,6 +326,7 @@ class _ModelInfo: ...@@ -324,6 +326,7 @@ class _ModelInfo:
supports_cross_encoding: bool supports_cross_encoding: bool
supports_multimodal: bool supports_multimodal: bool
supports_multimodal_raw_input: bool supports_multimodal_raw_input: bool
supports_multimodal_encoder_tp_data: bool
supports_pp: bool supports_pp: bool
has_inner_state: bool has_inner_state: bool
is_attention_free: bool is_attention_free: bool
...@@ -343,6 +346,8 @@ class _ModelInfo: ...@@ -343,6 +346,8 @@ class _ModelInfo:
supports_cross_encoding=supports_cross_encoding(model), supports_cross_encoding=supports_cross_encoding(model),
supports_multimodal=supports_multimodal(model), supports_multimodal=supports_multimodal(model),
supports_multimodal_raw_input=supports_multimodal_raw_input(model), supports_multimodal_raw_input=supports_multimodal_raw_input(model),
supports_multimodal_encoder_tp_data=
supports_multimodal_encoder_tp_data(model),
supports_pp=supports_pp(model), supports_pp=supports_pp(model),
has_inner_state=has_inner_state(model), has_inner_state=has_inner_state(model),
is_attention_free=is_attention_free(model), is_attention_free=is_attention_free(model),
......
...@@ -867,6 +867,8 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -867,6 +867,8 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
"lm_head.": "language_model.lm_head.", "lm_head.": "language_model.lm_head.",
}) })
supports_encoder_tp_data = True
@classmethod @classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
if modality.startswith("image"): if modality.startswith("image"):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment