Unverified Commit 0f6d7a9a authored by Murali Andoorveedu's avatar Murali Andoorveedu Committed by GitHub
Browse files

[Models] Add remaining model PP support (#7168)


Signed-off-by: default avatarMuralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Signed-off-by: default avatarMurali Andoorveedu <muralidhar.andoorveedu@centml.ai>
Co-authored-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 303d4479
This diff is collapsed.
...@@ -55,7 +55,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig ...@@ -55,7 +55,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.qwen2 import Qwen2Model from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict, from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
MultiModalInputs) MultiModalInputs)
...@@ -68,6 +67,7 @@ from vllm.transformers_utils.configs.qwen2vl import (Qwen2VLConfig, ...@@ -68,6 +67,7 @@ from vllm.transformers_utils.configs.qwen2vl import (Qwen2VLConfig,
from vllm.transformers_utils.processor import get_processor from vllm.transformers_utils.processor import get_processor
from vllm.utils import is_cpu from vllm.utils import is_cpu
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter, from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory) make_empty_intermediate_tensors_factory)
...@@ -883,7 +883,8 @@ def input_processor_for_qwen2_vl(ctx: InputContext, ...@@ -883,7 +883,8 @@ def input_processor_for_qwen2_vl(ctx: InputContext,
"video", get_max_qwen2_vl_video_tokens) "video", get_max_qwen2_vl_video_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen2_vl) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen2_vl)
@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen2_vl) @INPUT_REGISTRY.register_input_processor(input_processor_for_qwen2_vl)
class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal): class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
def __init__(self, def __init__(self,
config: Qwen2VLConfig, config: Qwen2VLConfig,
...@@ -1027,7 +1028,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -1027,7 +1028,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object, **kwargs: object,
) -> SamplerOutput: ) -> Union[torch.Tensor, IntermediateTensors]:
"""Run forward pass for Qwen2-VL. """Run forward pass for Qwen2-VL.
Args: Args:
...@@ -1047,16 +1048,18 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -1047,16 +1048,18 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM. video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM.
`None` if no videos are passed. `None` if no videos are passed.
""" """
if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
else:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
video_input = self._parse_and_validate_video_input(**kwargs) video_input = self._parse_and_validate_video_input(**kwargs)
if (image_input is None if image_input is None and video_input is None:
and video_input is None) or not get_pp_group().is_first_rank:
inputs_embeds = None inputs_embeds = None
else: else:
if getattr(self.config, "rope_scaling", {}).get("type", rope_scaling = getattr(self.config, "rope_scaling", {})
None) == "mrope": if rope_scaling.get("type", None) == "mrope":
assert positions.ndim == 2 and positions.size(0) == 3, ( assert positions.ndim == 2 and positions.size(0) == 3, (
"multimodal section rotary embedding requires " "multimodal section rotary embedding requires "
f"(3, seq_len) positions, but got {positions.size()}") f"(3, seq_len) positions, but got {positions.size()}")
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment