Commit b12c902b authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge remote-tracking branch 'cx/v0.11.0-dev' into v0.11.0-dev-omni

parents c16e075a f39afa4a
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Mapping, MutableSequence from collections.abc import Iterable, Mapping, MutableSequence, Callable
from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol, from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol,
Union, overload, runtime_checkable) Union, overload, runtime_checkable)
...@@ -20,7 +20,7 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -20,7 +20,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.utils import supports_kw from vllm.utils import supports_kw
from .interfaces_base import is_pooling_model from .interfaces_base import is_pooling_model, VllmModel
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import VllmConfig from vllm.config import VllmConfig
...@@ -81,10 +81,9 @@ class SupportsMultiModal(Protocol): ...@@ -81,10 +81,9 @@ class SupportsMultiModal(Protocol):
""" """
... ...
def get_multimodal_embeddings(self, def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
**kwargs: object) -> MultiModalEmbeddings:
""" """
Returns multimodal embeddings generated from multimodal kwargs Returns multimodal embeddings generated from multimodal kwargs
to be merged with text embeddings. to be merged with text embeddings.
Note: Note:
...@@ -94,11 +93,11 @@ class SupportsMultiModal(Protocol): ...@@ -94,11 +93,11 @@ class SupportsMultiModal(Protocol):
""" """
... ...
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> VllmModel:
""" """
Returns the underlying language model used for text generation. Returns the underlying language model used for text generation.
This is typically the `torch.nn.Module` instance responsible for This is typically the `torch.nn.Module` instance responsible for
processing the merged multimodal embeddings and producing hidden states processing the merged multimodal embeddings and producing hidden states
Returns: Returns:
...@@ -106,19 +105,83 @@ class SupportsMultiModal(Protocol): ...@@ -106,19 +105,83 @@ class SupportsMultiModal(Protocol):
""" """
... ...
@overload
def get_input_embeddings(self, input_ids: Tensor) -> Tensor: ...
@overload
def get_input_embeddings(
self,
input_ids: Tensor,
multimodal_embeddings: MultiModalEmbeddings,
*,
is_multimodal: torch.Tensor,
handle_oov_mm_token: bool = False,
) -> Tensor: ...
def _get_text_embeddings(
self,
input_ids: Tensor,
get_input_embeddings: Callable[[Tensor], Tensor],
*,
is_multimodal: Optional[Tensor],
handle_oov_mm_token: bool,
) -> Tensor:
if handle_oov_mm_token and is_multimodal is not None:
is_text = ~is_multimodal
text_embeds = get_input_embeddings(input_ids[is_text])
return torch.empty(
(input_ids.shape[0], text_embeds.shape[1]),
dtype=text_embeds.dtype,
device=text_embeds.device,
).masked_scatter_(is_text.unsqueeze_(-1), text_embeds)
return get_input_embeddings(input_ids)
def get_input_embeddings( def get_input_embeddings(
self, self,
input_ids: Tensor, input_ids: Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[Tensor] = None,
handle_oov_mm_token: bool = False,
) -> Tensor: ) -> Tensor:
""" """
Returns the input embeddings merged from the text embeddings from Apply token embeddings to `input_ids`.
input_ids and the multimodal embeddings generated from multimodal
kwargs. If `multimodal_embeddings` is passed, scatter them into
`input_ids` according to the mask `is_multimodal`.
In case the multi-modal token IDs exceed the vocabulary size of
the language model, you can set `handle_oov_mm_token=False`
to avoid calling the language model's `get_input_embeddings` method
on those tokens. Note however that doing so increases memory usage
as an additional buffer is needed to hold the input embeddings.
""" """
... from .utils import _merge_multimodal_embeddings
inputs_embeds = self._get_text_embeddings(
input_ids,
self.get_language_model().get_input_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
return inputs_embeds
if is_multimodal is None:
raise ValueError(
"`get_input_embeddings` now requires `is_multimodal` arg, "
"please update your model runner according to "
"https://github.com/vllm-project/vllm/pull/16229."
)
return _merge_multimodal_embeddings(
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
)
@runtime_checkable @runtime_checkable
class SupportsMultiModalPruning(Protocol): class SupportsMultiModalPruning(Protocol):
"""The interface required for models that support returning both input """The interface required for models that support returning both input
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 The Qwen team. # Copyright 2025 The Qwen team.
# Copyright 2023 The vLLM team. # Copyright 2023 The vLLM team.
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only Qwen3-Omni-Moe model (thinker part).""" """Inference-only Qwen3-Omni-Moe model (thinker part)."""
import os import os
import math import math
from collections.abc import Callable, Iterable, Mapping, Sequence from collections.abc import Callable, Iterable, Mapping, Sequence
...@@ -48,7 +49,9 @@ from transformers.models.qwen3_omni_moe.processing_qwen3_omni_moe import ( ...@@ -48,7 +49,9 @@ from transformers.models.qwen3_omni_moe.processing_qwen3_omni_moe import (
) )
from transformers.models.whisper import WhisperFeatureExtractor from transformers.models.whisper import WhisperFeatureExtractor
from vllm.attention.backends.registry import _Backend
# from vllm.attention.backends.registry import _Backend
from vllm.platforms import _Backend, current_platform
from vllm.attention.layer import check_upstream_fa_availability from vllm.attention.layer import check_upstream_fa_availability
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig from vllm.config import VllmConfig
...@@ -106,6 +109,7 @@ from .utils import ( ...@@ -106,6 +109,7 @@ from .utils import (
_merge_multimodal_embeddings, _merge_multimodal_embeddings,
maybe_prefix, maybe_prefix,
) )
from .vision import ( from .vision import (
conv3d_to_linear_weight, conv3d_to_linear_weight,
get_llm_pos_ids_for_vision, get_llm_pos_ids_for_vision,
...@@ -143,18 +147,28 @@ class Qwen3_VisionPatchEmbed(nn.Module): ...@@ -143,18 +147,28 @@ class Qwen3_VisionPatchEmbed(nn.Module):
self.hidden_size = hidden_size self.hidden_size = hidden_size
kernel_size = (temporal_patch_size, patch_size, patch_size) kernel_size = (temporal_patch_size, patch_size, patch_size)
self.proj = ReplicatedLinear(
in_channels * math.prod(kernel_size), # self.proj = ReplicatedLinear(
# in_channels * math.prod(kernel_size),
# hidden_size,
# bias=True,
# return_bias=False,
# )
self.proj = nn.Conv3d(
in_channels,
hidden_size, hidden_size,
kernel_size=kernel_size,
stride=kernel_size,
bias=True, bias=True,
return_bias=False,
) )
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
L, C = x.shape L, C = x.shape
if os.environ.get('PYTORCH_MIOPEN_SUGGEST_NDHWC') == '1': x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
x = x.to(memory_format=torch.channels_last_3d) # if os.environ.get('PYTORCH_MIOPEN_SUGGEST_NDHWC') == '1':
x = self.proj(x) # x = x.to(memory_format=torch.channels_last_3d)
x = self.proj(x).view(L, self.hidden_size)
return x return x
...@@ -308,7 +322,6 @@ class Qwen3Omni_VisionTransformer(nn.Module): ...@@ -308,7 +322,6 @@ class Qwen3Omni_VisionTransformer(nn.Module):
norm_eps: float = 1e-6, norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
attn_backend_override: _Backend | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = vision_config.hidden_size self.hidden_size = vision_config.hidden_size
...@@ -380,9 +393,7 @@ class Qwen3Omni_VisionTransformer(nn.Module): ...@@ -380,9 +393,7 @@ class Qwen3Omni_VisionTransformer(nn.Module):
) )
self.attn_backend = get_vit_attn_backend( self.attn_backend = get_vit_attn_backend(
head_size=head_dim, head_size=head_dim, dtype=torch.get_default_dtype()
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
) )
if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
torch.get_default_dtype() torch.get_default_dtype()
...@@ -571,8 +582,8 @@ class Qwen3Omni_VisionTransformer(nn.Module): ...@@ -571,8 +582,8 @@ class Qwen3Omni_VisionTransformer(nn.Module):
loaded_params: set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if name.endswith("patch_embed.proj.weight"): # if name.endswith("patch_embed.proj.weight"):
loaded_weight = conv3d_to_linear_weight(loaded_weight) # loaded_weight = conv3d_to_linear_weight(loaded_weight)
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
...@@ -811,7 +822,8 @@ class Qwen3OmniMoeThinkerMultiModalProcessor( ...@@ -811,7 +822,8 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
if is_update_applied: if is_update_applied:
prompt_ids = self._get_raw_input_ids(prompt_ids, use_audio_in_video) prompt_ids = self._get_raw_input_ids(prompt_ids, use_audio_in_video)
( (
prompt_ids, prompt_ids,
prompt,
mm_placeholders, mm_placeholders,
) = self._apply_prompt_updates( ) = self._apply_prompt_updates(
prompt_ids, prompt_ids,
...@@ -829,7 +841,7 @@ class Qwen3OmniMoeThinkerMultiModalProcessor( ...@@ -829,7 +841,7 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
mm_item_counts, mm_item_counts,
) )
else: else:
prompt_ids, mm_placeholders = self._apply_prompt_updates( prompt_ids, prompt, mm_placeholders = self._apply_prompt_updates(
prompt_ids, prompt_ids,
mm_prompt_updates, mm_prompt_updates,
) )
...@@ -837,8 +849,7 @@ class Qwen3OmniMoeThinkerMultiModalProcessor( ...@@ -837,8 +849,7 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
mm_placeholders, mm_placeholders,
mm_item_counts, mm_item_counts,
) )
return prompt_ids, prompt, mm_placeholders
return prompt_ids, mm_placeholders
def get_updates_use_audio_in_video( def get_updates_use_audio_in_video(
self, self,
...@@ -1160,18 +1171,11 @@ class Qwen3OmniMoeThinkerForConditionalGeneration( ...@@ -1160,18 +1171,11 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
) )
self.audio_tower = Qwen3OmniMoeAudioEncoder(thinker_config.audio_config) self.audio_tower = Qwen3OmniMoeAudioEncoder(thinker_config.audio_config)
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual = Qwen3Omni_VisionTransformer( self.visual = Qwen3Omni_VisionTransformer(
vision_config=thinker_config.vision_config, vision_config=thinker_config.vision_config,
norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6), norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6),
quant_config=quant_config, quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"), prefix=maybe_prefix(prefix, "visual"),
attn_backend_override=attn_backend_override,
) )
self.quant_config = quant_config self.quant_config = quant_config
...@@ -1375,7 +1379,6 @@ class Qwen3OmniMoeThinkerForConditionalGeneration( ...@@ -1375,7 +1379,6 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
.contiguous() .contiguous()
) )
self._set_deepstack_input_embeds(deepstack_input_embeds) self._set_deepstack_input_embeds(deepstack_input_embeds)
inputs_embeds = _merge_multimodal_embeddings( inputs_embeds = _merge_multimodal_embeddings(
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings, multimodal_embeddings=multimodal_embeddings,
...@@ -1434,7 +1437,9 @@ class Qwen3OmniMoeThinkerForConditionalGeneration( ...@@ -1434,7 +1437,9 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
loaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) loaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
return loaded_weights return loaded_weights
@classmethod
def get_mrope_input_positions( def get_mrope_input_positions(
self, self,
input_tokens: list[int], input_tokens: list[int],
......
...@@ -10,6 +10,7 @@ import torch ...@@ -10,6 +10,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from torch.func import functional_call from torch.func import functional_call
from transformers import PretrainedConfig from transformers import PretrainedConfig
from typing_extensions import deprecated
import vllm.envs as envs import vllm.envs as envs
from vllm.config import VllmConfig from vllm.config import VllmConfig
...@@ -391,92 +392,79 @@ def _embedding_count_expression(embeddings: NestedTensors) -> str: ...@@ -391,92 +392,79 @@ def _embedding_count_expression(embeddings: NestedTensors) -> str:
return " + ".join( return " + ".join(
_embedding_count_expression(inner) for inner in embeddings) _embedding_count_expression(inner) for inner in embeddings)
def split_list_into_ranges(lst: torch.Tensor, interval: int) -> list[list[int]]:
ranges: list[list[int]] = [[] for _ in range((max(lst) // interval) + 1)]
for num in lst:
index = num // interval
ranges[index].append(num)
return ranges
def _merge_multimodal_embeddings( def _merge_multimodal_embeddings(
inputs_embeds: torch.Tensor, inputs_embeds: torch.Tensor,
is_multimodal: torch.Tensor,
multimodal_embeddings: NestedTensors, multimodal_embeddings: NestedTensors,
is_multimodal: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the Merge `multimodal_embeddings` into `inputs_embeds` by overwriting the
positions in ``inputs_embeds`` corresponding to placeholder tokens in positions in `inputs_embeds` corresponding to placeholder tokens in
``input_ids``. `input_ids`.
Note: Note:
This updates ``inputs_embeds`` in place. This updates `inputs_embeds` in place.
""" """
flattened = _flatten_embeddings(multimodal_embeddings) if len(multimodal_embeddings) == 0:
return inputs_embeds
mm_embeds_flat = _flatten_embeddings(multimodal_embeddings)
input_dtype = inputs_embeds.dtype
try: try:
# This is equivalent to: inputs_embeds[is_multimodal] = flattened. # For debugging
inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1), # inputs_embeds[is_multimodal] = mm_embeds_flat.to(dtype=input_dtype)
flattened.to(dtype=inputs_embeds.dtype))
# NOTE: This can avoid D2H sync (#22105), but fails to
# raise an error if is_multimodal.sum() < len(mm_embeds_flat)
inputs_embeds.masked_scatter_(
is_multimodal.unsqueeze(-1), mm_embeds_flat.to(dtype=input_dtype)
)
except RuntimeError as e: except RuntimeError as e:
num_actual_tokens = len(mm_embeds_flat)
num_expected_tokens = is_multimodal.sum().item() num_expected_tokens = is_multimodal.sum().item()
assert isinstance(num_expected_tokens, int)
if flattened.shape[0] != num_expected_tokens: if num_actual_tokens != num_expected_tokens:
expr = _embedding_count_expression(multimodal_embeddings) expr = _embedding_count_expression(multimodal_embeddings)
raise ValueError( raise ValueError(
f"Attempted to assign {expr} = {flattened.shape[0]} " f"Attempted to assign {expr} = {num_actual_tokens} "
f"multimodal tokens to {num_expected_tokens} placeholders" f"multimodal tokens to {num_expected_tokens} placeholders"
) from e ) from e
else:
raise ValueError("Error during masked scatter operation") from e
return inputs_embeds
def embed_multimodal(
input_ids: torch.Tensor,
multimodal_token_id: int,
get_text_embeds: Callable[[torch.Tensor], torch.Tensor],
multimodal_embeds: NestedTensors,
) -> torch.Tensor:
"""
Embed token IDs and multimodal inputs and combine their embeddings.
``multimodal_token_id`` is used to determine whether a token ID should
be embedded using ``get_text_embeds`` or ``get_multimodal_embeds``.
Compared to ``merge_multimodal_embeddings`, this avoids running
``get_text_embeds`` on ``input_ids[input_ids == multimodal_token_id]``
which causes issues when the placeholder token ID exceeds the
vocabulary size of the language model.
"""
is_multimodal = input_ids == multimodal_token_id
is_text = ~is_multimodal
text_embeds = get_text_embeds(input_ids[is_text])
merged_embeds = torch.empty(
(input_ids.shape[0], text_embeds.shape[1]),
dtype=text_embeds.dtype,
device=text_embeds.device,
)
merged_embeds[is_text] = text_embeds raise ValueError("Error during masked scatter operation") from e
return _merge_multimodal_embeddings( return inputs_embeds
merged_embeds,
is_multimodal,
multimodal_embeds,
)
@deprecated(
"`merge_multimodal_embeddings` has been replaced with "
"`SupportsMultiModal.get_input_embeddings` and will be "
"removed in v0.12."
)
def merge_multimodal_embeddings( def merge_multimodal_embeddings(
input_ids: torch.Tensor, input_ids: torch.Tensor,
inputs_embeds: torch.Tensor, inputs_embeds: torch.Tensor,
multimodal_embeddings: NestedTensors, multimodal_embeddings: NestedTensors,
placeholder_token_id: Union[int, list[int]], placeholder_token_id: int | list[int],
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the Merge `multimodal_embeddings` into `inputs_embeds` by overwriting the
positions in ``inputs_embeds`` corresponding to placeholder tokens in positions in `inputs_embeds` corresponding to placeholder tokens in
``input_ids``. `input_ids`.
``placeholder_token_id`` can be a list of token ids (e.g, token ids `placeholder_token_id` can be a list of token ids (e.g, token ids
of img_start, img_break, and img_end tokens) when needed: This means of img_start, img_break, and img_end tokens) when needed: This means
the order of these tokens in the ``input_ids`` MUST MATCH the order of the order of these tokens in the `input_ids` MUST MATCH the order of
their embeddings in ``multimodal_embeddings`` since we need to their embeddings in `multimodal_embeddings` since we need to
slice-merge instead of individually scattering. slice-merge instead of individually scattering.
For example, if input_ids is "TTTTTSIIIBIIIBIIIETTT", where For example, if input_ids is "TTTTTSIIIBIIIBIIIETTT", where
...@@ -491,26 +479,32 @@ def merge_multimodal_embeddings( ...@@ -491,26 +479,32 @@ def merge_multimodal_embeddings(
input_ids for a correct embedding merge. input_ids for a correct embedding merge.
Note: Note:
This updates ``inputs_embeds`` in place. This updates `inputs_embeds` in place.
""" """
if isinstance(placeholder_token_id, list): if isinstance(placeholder_token_id, list):
placeholder_token_id = torch.tensor( is_multimodal = isin_list(input_ids, placeholder_token_id)
placeholder_token_id, else:
pin_memory=is_pin_memory_available()).to(device=input_ids.device, is_multimodal = input_ids == placeholder_token_id
non_blocking=True)
return _merge_multimodal_embeddings(
inputs_embeds,
torch.isin(input_ids, placeholder_token_id),
multimodal_embeddings,
)
return _merge_multimodal_embeddings( return _merge_multimodal_embeddings(
inputs_embeds, inputs_embeds,
(input_ids == placeholder_token_id), multimodal_embeddings=multimodal_embeddings,
multimodal_embeddings, is_multimodal=is_multimodal,
) )
def isin_list(
elements: torch.Tensor,
test_elements_list: list[int],
) -> torch.Tensor:
test_elements = torch.tensor(
test_elements_list,
pin_memory=is_pin_memory_available(),
).to(device=elements.device, non_blocking=True)
return torch.isin(elements, test_elements)
class LayerFn(Protocol): class LayerFn(Protocol):
def __call__(self, prefix: str) -> torch.nn.Module: def __call__(self, prefix: str) -> torch.nn.Module:
......
...@@ -368,6 +368,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -368,6 +368,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
dtype=torch.int32) dtype=torch.int32)
self.num_accepted_tokens = self._make_buffer(self.max_num_reqs, self.num_accepted_tokens = self._make_buffer(self.max_num_reqs,
dtype=torch.int64) dtype=torch.int64)
# Only relevant for multimodal models
if self.supports_mm_inputs:
self.is_mm_embed = self._make_buffer(self.max_num_tokens, dtype=torch.bool)
# Only relevant for models using M-RoPE (e.g, Qwen2-VL) # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if self.uses_mrope: if self.uses_mrope:
...@@ -1612,17 +1615,23 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1612,17 +1615,23 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self, self,
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
shift_computed_tokens: int = 0, shift_computed_tokens: int = 0,
) -> list[torch.Tensor]: ) -> tuple[list[torch.Tensor], torch.Tensor]:
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
mm_embeds = list[torch.Tensor]()
is_mm_embed = self.is_mm_embed.cpu
is_mm_embed[:total_num_scheduled_tokens] = False
req_start_idx = 0
should_sync_mrope_positions = False should_sync_mrope_positions = False
mm_embeds: list[torch.Tensor] = []
for req_id in self.input_batch.req_ids: for req_id in self.input_batch.req_ids:
mm_embeds_req: list[torch.Tensor] = [] mm_embeds_req: list[torch.Tensor] = []
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
req_id]
req_state = self.requests[req_id] req_state = self.requests[req_id]
num_computed_tokens = \ num_computed_tokens = req_state.num_computed_tokens + shift_computed_tokens
req_state.num_computed_tokens + shift_computed_tokens
for mm_feature in req_state.mm_features: for mm_feature in req_state.mm_features:
pos_info = mm_feature.mm_position pos_info = mm_feature.mm_position
start_pos = pos_info.offset start_pos = pos_info.offset
...@@ -1649,12 +1658,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1649,12 +1658,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
mm_hash = mm_feature.identifier mm_hash = mm_feature.identifier
encoder_output = self.encoder_cache.get(mm_hash, None) encoder_output = self.encoder_cache.get(mm_hash, None)
assert encoder_output is not None,\ assert encoder_output is not None, f"Encoder cache miss for {mm_hash}."
f"Encoder cache miss for {mm_hash}."
if (is_embed := pos_info.is_embed) is not None: if (is_embed := pos_info.is_embed) is not None:
is_embed = is_embed[start_idx:end_idx] is_embed = is_embed[start_idx:end_idx]
req_start_pos = req_start_idx + start_pos - num_computed_tokens
is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = (
True if is_embed is None else is_embed
)
mm_embeds_item = gather_mm_placeholders( mm_embeds_item = gather_mm_placeholders(
encoder_output[start_idx:end_idx], encoder_output[start_idx:end_idx],
is_embed=is_embed, is_embed=is_embed,
...@@ -1662,6 +1675,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1662,6 +1675,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
mm_embeds_req.append(mm_embeds_item) mm_embeds_req.append(mm_embeds_item)
if self.is_multimodal_pruning_enabled and self.uses_mrope: if self.is_multimodal_pruning_enabled and self.uses_mrope:
assert req_state.mrope_positions is not None
should_sync_mrope_positions = True should_sync_mrope_positions = True
mm_embeds_req, new_mrope_positions, new_delta = ( mm_embeds_req, new_mrope_positions, new_delta = (
self.model.recompute_mrope_positions( self.model.recompute_mrope_positions(
...@@ -1669,19 +1683,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1669,19 +1683,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
multimodal_embeddings=mm_embeds_req, multimodal_embeddings=mm_embeds_req,
mrope_positions=req_state.mrope_positions, mrope_positions=req_state.mrope_positions,
num_computed_tokens=req_state.num_computed_tokens, num_computed_tokens=req_state.num_computed_tokens,
)) )
assert req_state.mrope_positions is not None )
req_state.mrope_positions.copy_(new_mrope_positions) req_state.mrope_positions.copy_(new_mrope_positions)
req_state.mrope_position_delta = new_delta req_state.mrope_position_delta = new_delta
mm_embeds.extend(mm_embeds_req) mm_embeds.extend(mm_embeds_req)
req_start_idx += num_scheduled_tokens
is_mm_embed = self.is_mm_embed.copy_to_gpu(total_num_scheduled_tokens)
if should_sync_mrope_positions: if should_sync_mrope_positions:
self._calc_mrope_positions(scheduler_output) self._calc_mrope_positions(scheduler_output)
self.mrope_positions.copy_to_gpu( self.mrope_positions.copy_to_gpu(total_num_scheduled_tokens)
scheduler_output.total_num_scheduled_tokens)
return mm_embeds return mm_embeds, is_mm_embed
def _extract_encoder_inputs( def _extract_encoder_inputs(
self, self,
...@@ -1975,7 +1991,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1975,7 +1991,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
and not self.model_config.is_encoder_decoder): and not self.model_config.is_encoder_decoder):
# Run the multimodal encoder if any. # Run the multimodal encoder if any.
self._execute_mm_encoder(scheduler_output) self._execute_mm_encoder(scheduler_output)
mm_embeds = self._gather_mm_embeddings(scheduler_output) mm_embeds, is_mm_embed = self._gather_mm_embeddings(scheduler_output)
# NOTE(woosuk): To unify token ids and soft tokens (vision # NOTE(woosuk): To unify token ids and soft tokens (vision
# embeddings), we always use embeddings (rather than token ids) # embeddings), we always use embeddings (rather than token ids)
...@@ -1983,6 +1999,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1983,6 +1999,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
inputs_embeds_scheduled = self.model.get_input_embeddings( inputs_embeds_scheduled = self.model.get_input_embeddings(
input_ids=self.input_ids.gpu[:num_scheduled_tokens], input_ids=self.input_ids.gpu[:num_scheduled_tokens],
multimodal_embeddings=mm_embeds or None, multimodal_embeddings=mm_embeds or None,
is_multimodal=is_mm_embed,
) )
# TODO(woosuk): Avoid the copy. Optimize. # TODO(woosuk): Avoid the copy. Optimize.
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
try: try:
from ._version import __version__, __version_tuple__ __version__ = "0.11.0"
__version_tuple__ = (0, 11, 0)
__hcu_version__ = f'0.11.0+das.opt1.alpha.6c015e7.dtk25041'
from vllm.version import __version__, __version_tuple__, __hcu_version__
except Exception as e: except Exception as e:
import warnings import warnings
warnings.warn(f"Failed to read commit hash:\n{e}", warnings.warn(f"Failed to read commit hash:\n + str(e)",
RuntimeWarning, RuntimeWarning,
stacklevel=2) stacklevel=2)
__version__ = "dev" __version__ = "dev"
__version_tuple__ = (0, 0, __version__) __version_tuple__ = (0, 0, __version__)
def _prev_minor_version_was(version_str): def _prev_minor_version_was(version_str):
"""Check whether a given version matches the previous minor version. '''Check whether a given version matches the previous minor version.
Return True if version_str matches the previous minor version. Return True if version_str matches the previous minor version.
...@@ -23,19 +24,19 @@ def _prev_minor_version_was(version_str): ...@@ -23,19 +24,19 @@ def _prev_minor_version_was(version_str):
supplied version_str is '0.6'. supplied version_str is '0.6'.
Used for --show-hidden-metrics-for-version. Used for --show-hidden-metrics-for-version.
""" '''
# Match anything if this is a dev tree # Match anything if this is a dev tree
if __version_tuple__[0:2] == (0, 0): if __version_tuple__[0:2] == (0, 0):
return True return True
# Note - this won't do the right thing when we release 1.0! # Note - this won't do the right thing when we release 1.0!
assert __version_tuple__[0] == 0 # assert __version_tuple__[0] == 0
assert isinstance(__version_tuple__[1], int) assert isinstance(__version_tuple__[1], int)
return version_str == f"{__version_tuple__[0]}.{__version_tuple__[1] - 1}" return version_str == f"{__version_tuple__[0]}.{__version_tuple__[1] - 1}"
def _prev_minor_version(): def _prev_minor_version():
"""For the purpose of testing, return a previous minor version number.""" '''For the purpose of testing, return a previous minor version number.'''
# In dev tree, this will return "0.-1", but that will work fine" # In dev tree, this will return "0.-1", but that will work fine"
assert isinstance(__version_tuple__[1], int) assert isinstance(__version_tuple__[1], int)
return f"{__version_tuple__[0]}.{__version_tuple__[1] - 1}" return f"{__version_tuple__[0]}.{__version_tuple__[1] - 1}"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment