Commit 4eabe123 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge remote-tracking branch 'mirror/releases/v0.9.0' into v0.9.0-ori

parents 45840cd2 58738772
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import json import json
import re
import sys import sys
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import defaultdict from collections import defaultdict
...@@ -12,6 +11,7 @@ from functools import lru_cache ...@@ -12,6 +11,7 @@ from functools import lru_cache
from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol, from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol,
TypeVar, Union, cast) TypeVar, Union, cast)
import regex as re
import torch import torch
from typing_extensions import assert_never from typing_extensions import assert_never
...@@ -114,13 +114,14 @@ class PromptUpdateDetails(Generic[_S]): ...@@ -114,13 +114,14 @@ class PromptUpdateDetails(Generic[_S]):
is_embed: Optional[Callable[["_BoundPromptSequence"], torch.Tensor]] = None is_embed: Optional[Callable[["_BoundPromptSequence"], torch.Tensor]] = None
""" """
Given {attr}`full`, return a boolean mask of shape `(len(full),)` Given [`full`][vllm.multimodal.processing.PromptUpdateDetails.full],
indicating which positions of `full` to assign embeddings to. return a boolean mask of shape `(len(full),)` indicating which positions
of `full` to assign embeddings to.
`None` (default) means to assign embeddings to all positions of `full`. `None` (default) means to assign embeddings to all positions of `full`.
The embeddings are obtained by calling The embeddings are obtained by calling
{class}`SupportsMultiModal.get_multimodal_embeddings`. [`SupportsMultiModal.get_multimodal_embeddings`][vllm.model_executor.models.interfaces.SupportsMultiModal.get_multimodal_embeddings].
""" """
@staticmethod @staticmethod
...@@ -159,13 +160,15 @@ PromptUpdateInfo = Union[PromptSeq, PromptUpdateDetails] ...@@ -159,13 +160,15 @@ PromptUpdateInfo = Union[PromptSeq, PromptUpdateDetails]
The token sequence or text that are part of the update. The token sequence or text that are part of the update.
If only part of the content corresponds to feature placeholders, you can If only part of the content corresponds to feature placeholders, you can
use {class}`PromptUpdateDetails` to specify which part. use [`PromptUpdateDetails`][vllm.multimodal.processing.PromptUpdateDetails] to
specify which part.
""" """
PromptUpdateContent = Union[Callable[[int], PromptUpdateInfo], PromptUpdateContent = Union[Callable[[int], PromptUpdateInfo],
PromptUpdateInfo] PromptUpdateInfo]
""" """
Given the index of the processed item within {attr}`modality`, Given the index of the processed item within
[`modality`][vllm.multimodal.processing.PromptUpdate.modality],
output the corresponding token sequence (or text). output the corresponding token sequence (or text).
For convenience, you can directly pass in the token sequence (or text) For convenience, you can directly pass in the token sequence (or text)
...@@ -260,8 +263,10 @@ class PromptInsertion(PromptUpdate): ...@@ -260,8 +263,10 @@ class PromptInsertion(PromptUpdate):
insertion: PromptUpdateContent = field(repr=False) insertion: PromptUpdateContent = field(repr=False)
""" """
Given the index of the processed item within {attr}`modality`, Given the index of the processed item within
output the token sequence (or text) to insert right after {attr}`target`. [`modality`][vllm.multimodal.processing.PromptUpdate.modality],
output the token sequence (or text) to insert right after
[`target`][vllm.multimodal.processing.PromptUpdate.target].
For convenience, you can directly pass in the token sequence (or text) For convenience, you can directly pass in the token sequence (or text)
instead of a function if it does not depend on the input. instead of a function if it does not depend on the input.
...@@ -332,8 +337,10 @@ class PromptReplacement(PromptUpdate): ...@@ -332,8 +337,10 @@ class PromptReplacement(PromptUpdate):
replacement: PromptUpdateContent = field(repr=False) replacement: PromptUpdateContent = field(repr=False)
""" """
Given the index of the processed item within {attr}`modality`, Given the index of the processed item within
output the token sequence (or text) to replace {attr}`target`. [`modality`][vllm.multimodal.processing.PromptUpdate.modality],
output the token sequence (or text) to replace
[`target`][vllm.multimodal.processing.PromptUpdate.target].
For convenience, you can directly pass in the token sequence (or text) For convenience, you can directly pass in the token sequence (or text)
instead of a function if it does not depend on the input. instead of a function if it does not depend on the input.
...@@ -387,14 +394,16 @@ _M = TypeVar("_M", bound=Union[_HasModalityAttr, _HasModalityProp]) ...@@ -387,14 +394,16 @@ _M = TypeVar("_M", bound=Union[_HasModalityAttr, _HasModalityProp])
def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]: def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
"""Convenience function to apply {func}`full_groupby` based on modality.""" """Convenience function to apply [`full_groupby`][vllm.utils.full_groupby]
based on modality."""
return full_groupby(values, key=lambda x: x.modality) return full_groupby(values, key=lambda x: x.modality)
@dataclass @dataclass
class _BoundPromptSequence: class _BoundPromptSequence:
""" """
A {data}`_PromptSeq` bound to a tokenizer to automatically A [`_PromptSeq`][vllm.multimodal.processing.PromptSeq] bound
to a tokenizer to automatically
convert between token sequence and text representations. convert between token sequence and text representations.
""" """
tokenizer: AnyTokenizer = field(repr=False) tokenizer: AnyTokenizer = field(repr=False)
...@@ -446,9 +455,11 @@ class _BoundPromptContent: ...@@ -446,9 +455,11 @@ class _BoundPromptContent:
@dataclass @dataclass
class BoundPromptUpdate: class BoundPromptUpdate:
""" """
A {class}`PromptUpdate` bound to a tokenizer to automatically convert A [`PromptUpdate`][vllm.multimodal.processing.PromptUpdate] bound
{attr}`target` and the result of {meth}`get_content` between to a tokenizer to automatically convert
token sequence and text representations. [`target`][vllm.multimodal.processing.PromptUpdate.target] and the result of
[`get_content`][vllm.multimodal.processing.BoundPromptUpdate.get_content]
between token sequence and text representations.
""" """
_origin: PromptUpdate _origin: PromptUpdate
tokenizer: AnyTokenizer = field(repr=False) tokenizer: AnyTokenizer = field(repr=False)
...@@ -482,7 +493,8 @@ class BoundPromptUpdate: ...@@ -482,7 +493,8 @@ class BoundPromptUpdate:
def get_content(self, item_idx: int) -> _BoundPromptContent: def get_content(self, item_idx: int) -> _BoundPromptContent:
""" """
Given the index of the processed item within {attr}`modality`, Given the index of the processed item within
[`modality`][vllm.multimodal.processing.PromptUpdate.modality],
output the token sequence (or text) to update. output the token sequence (or text) to update.
""" """
content = self.content content = self.content
...@@ -1019,7 +1031,8 @@ class ProcessingCache: ...@@ -1019,7 +1031,8 @@ class ProcessingCache:
) -> None: ) -> None:
""" """
Put a processed multi-modal item into the cache Put a processed multi-modal item into the cache
according to its dependencies (see {meth}`get`). according to its dependencies
(see [`get`][vllm.multimodal.processing.ProcessingCache.get]).
""" """
cache_key = MultiModalHasher.hash_kwargs(model_id=model_id, cache_key = MultiModalHasher.hash_kwargs(model_id=model_id,
**{modality: input_item}, **{modality: input_item},
...@@ -1091,7 +1104,8 @@ _I = TypeVar("_I", bound=BaseProcessingInfo) ...@@ -1091,7 +1104,8 @@ _I = TypeVar("_I", bound=BaseProcessingInfo)
MultiModalHashes = dict[str, list[str]] MultiModalHashes = dict[str, list[str]]
""" """
A collection of hashes with a similar structure as {class}`MultiModalKwargs`. A collection of hashes with a similar structure as
[`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs].
""" """
...@@ -1099,7 +1113,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1099,7 +1113,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
""" """
Abstract base class to process multi-modal inputs to be used in vLLM. Abstract base class to process multi-modal inputs to be used in vLLM.
Not to be confused with {class}`transformers.ProcessorMixin`. Not to be confused with `transformers.ProcessorMixin`.
""" """
def __init__(self, def __init__(self,
...@@ -1126,10 +1140,12 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1126,10 +1140,12 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
def _get_data_parser(self) -> MultiModalDataParser: def _get_data_parser(self) -> MultiModalDataParser:
""" """
Construct a parser to preprocess multi-modal data items Construct a parser to preprocess multi-modal data items
before passing them to {meth}`_get_hf_mm_data`. before passing them to
[`_get_hf_mm_data`][vllm.multimodal.processing.BaseMultiModalProcessor._get_hf_mm_data].
You can support additional modalities by creating a subclass You can support additional modalities by creating a subclass
of {class}`MultiModalDataParser` that has additional subparsers. of [`MultiModalDataParser`][vllm.multimodal.parse.MultiModalDataParser]
that has additional subparsers.
""" """
return MultiModalDataParser() return MultiModalDataParser()
...@@ -1138,8 +1154,11 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1138,8 +1154,11 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_data: MultiModalDataDict, mm_data: MultiModalDataDict,
) -> MultiModalDataItems: ) -> MultiModalDataItems:
""" """
Normalize {class}`MultiModalDataDict` to {class}`MultiModalDataItems` Normalize
before passing them to {meth}`_get_hf_mm_data`. [`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict]
to [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems]
before passing them to
[`_get_hf_mm_data`][vllm.multimodal.processing.BaseMultiModalProcessor._get_hf_mm_data].
""" """
mm_items = self.data_parser.parse_mm_data(mm_data) mm_items = self.data_parser.parse_mm_data(mm_data)
supported_mm_limits = self.info.get_supported_mm_limits() supported_mm_limits = self.info.get_supported_mm_limits()
...@@ -1191,7 +1210,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1191,7 +1210,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
inputs. inputs.
Moreover, this information is critical to determine the token positions Moreover, this information is critical to determine the token positions
in order to construct {class}`~vllm-multimodal.input.PlaceholderRange` in order to construct
[`PlaceholderRange`][vllm.multimodal.inputs.PlaceholderRange]
for each multi-modal item. for each multi-modal item.
""" """
raise NotImplementedError raise NotImplementedError
...@@ -1315,7 +1335,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1315,7 +1335,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
Most HF processors accept prompt text but not prompt tokens. Most HF processors accept prompt text but not prompt tokens.
If the HF processor adds or removes tokens that are not related to If the HF processor adds or removes tokens that are not related to
multi-modal data, you should override this method so it is consistent multi-modal data, you should override this method so it is consistent
with the output of {meth}`_apply_hf_processor_text_only` on the with the output of
[`_apply_hf_processor_text_only`][vllm.multimodal.processing.BaseMultiModalProcessor._apply_hf_processor_text_only]
on the
corresponding text. corresponding text.
""" """
return prompt_tokens return prompt_tokens
...@@ -1330,7 +1352,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1330,7 +1352,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
Since HF processor requires that text and multi-modal items Since HF processor requires that text and multi-modal items
correspond to each other, we generate dummy text using correspond to each other, we generate dummy text using
{class}`DummyInputsBuilder` to go along with the multi-modal data. [`DummyInputsBuilder`][vllm.multimodal.profiling.BaseDummyInputsBuilder]
to go along with the multi-modal data.
""" """
mm_counts = mm_items.get_all_counts() mm_counts = mm_items.get_all_counts()
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
from abc import ABC from abc import ABC
from collections.abc import Mapping from collections.abc import Mapping
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Generic, NamedTuple, Optional, TypeVar, cast from typing import Generic, NamedTuple, Optional, TypeVar, Union, cast
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
...@@ -25,9 +25,9 @@ logger = init_logger(__name__) ...@@ -25,9 +25,9 @@ logger = init_logger(__name__)
class ProcessorInputs: class ProcessorInputs:
""" """
Represents the keyword arguments to Represents the keyword arguments to
{meth}`vllm.multimodal.processing.BaseMultiModalProcessor.apply`. [`vllm.multimodal.processing.BaseMultiModalProcessor.apply`][].
""" """
prompt_text: str prompt: Union[str, list[int]]
mm_data: MultiModalDataDict mm_data: MultiModalDataDict
hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict) hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)
...@@ -75,7 +75,12 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]): ...@@ -75,7 +75,12 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
"in an upcoming release.") "in an upcoming release.")
seq_len = self.info.ctx.model_config.max_model_len seq_len = self.info.ctx.model_config.max_model_len
return self.get_dummy_processor_inputs(seq_len, mm_counts).prompt_text
prompt = self.get_dummy_processor_inputs(seq_len, mm_counts).prompt
if not isinstance(prompt, str):
prompt = self.info.get_tokenizer().decode(prompt)
return prompt
# TODO: @abstractmethod after transition # TODO: @abstractmethod after transition
def get_dummy_mm_data( def get_dummy_mm_data(
...@@ -101,7 +106,7 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]): ...@@ -101,7 +106,7 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
dummy_text = self.get_dummy_text(mm_counts) dummy_text = self.get_dummy_text(mm_counts)
dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts) dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts)
return ProcessorInputs(prompt_text=dummy_text, mm_data=dummy_mm_data) return ProcessorInputs(prompt=dummy_text, mm_data=dummy_mm_data)
def _get_dummy_audios( def _get_dummy_audios(
self, self,
...@@ -177,7 +182,7 @@ class MultiModalProfiler(Generic[_I]): ...@@ -177,7 +182,7 @@ class MultiModalProfiler(Generic[_I]):
seq_len, mm_counts) seq_len, mm_counts)
return self.processor.apply( return self.processor.apply(
prompt=processor_inputs.prompt_text, prompt=processor_inputs.prompt,
mm_data=processor_inputs.mm_data, mm_data=processor_inputs.mm_data,
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs, hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
) )
......
...@@ -29,7 +29,11 @@ _I_co = TypeVar("_I_co", bound=BaseProcessingInfo, covariant=True) ...@@ -29,7 +29,11 @@ _I_co = TypeVar("_I_co", bound=BaseProcessingInfo, covariant=True)
class ProcessingInfoFactory(Protocol[_I_co]): class ProcessingInfoFactory(Protocol[_I_co]):
"""Constructs a {class}`MultiModalProcessor` instance from the context.""" """
Constructs a
[`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor]
instance from the context.
"""
def __call__( def __call__(
self, self,
...@@ -40,7 +44,9 @@ class ProcessingInfoFactory(Protocol[_I_co]): ...@@ -40,7 +44,9 @@ class ProcessingInfoFactory(Protocol[_I_co]):
class DummyInputsBuilderFactory(Protocol[_I]): class DummyInputsBuilderFactory(Protocol[_I]):
""" """
Constructs a {class}`BaseDummyInputsBuilder` instance from the context. Constructs a
[`BaseDummyInputsBuilder`][vllm.multimodal.profiling.BaseDummyInputsBuilder]
instance from the context.
""" """
def __call__(self, info: _I) -> BaseDummyInputsBuilder[_I]: def __call__(self, info: _I) -> BaseDummyInputsBuilder[_I]:
...@@ -48,7 +54,11 @@ class DummyInputsBuilderFactory(Protocol[_I]): ...@@ -48,7 +54,11 @@ class DummyInputsBuilderFactory(Protocol[_I]):
class MultiModalProcessorFactory(Protocol[_I]): class MultiModalProcessorFactory(Protocol[_I]):
"""Constructs a {class}`MultiModalProcessor` instance from the context.""" """
Constructs a
[`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor]
instance from the context.
"""
def __call__( def __call__(
self, self,
...@@ -155,8 +165,6 @@ class MultiModalRegistry: ...@@ -155,8 +165,6 @@ class MultiModalRegistry:
""" """
Get the maximum number of tokens from each modality Get the maximum number of tokens from each modality
for profiling the memory usage of a model. for profiling the memory usage of a model.
See {meth}`MultiModalPlugin.get_max_multimodal_tokens` for more details.
""" """
mm_limits = self.get_mm_limits_per_prompt(model_config) mm_limits = self.get_mm_limits_per_prompt(model_config)
...@@ -170,8 +178,6 @@ class MultiModalRegistry: ...@@ -170,8 +178,6 @@ class MultiModalRegistry:
""" """
Get the maximum number of multi-modal tokens Get the maximum number of multi-modal tokens
for profiling the memory usage of a model. for profiling the memory usage of a model.
See {meth}`MultiModalPlugin.get_max_multimodal_tokens` for more details.
""" """
return sum(self.get_max_tokens_by_modality(model_config).values()) return sum(self.get_max_tokens_by_modality(model_config).values())
...@@ -213,10 +219,6 @@ class MultiModalRegistry: ...@@ -213,10 +219,6 @@ class MultiModalRegistry:
When the model receives multi-modal data, the provided function is When the model receives multi-modal data, the provided function is
invoked to transform the data into a dictionary of model inputs. invoked to transform the data into a dictionary of model inputs.
:::{seealso}
{ref}`mm-processing`
:::
""" """
def wrapper(model_cls: N) -> N: def wrapper(model_cls: N) -> N:
...@@ -259,10 +261,6 @@ class MultiModalRegistry: ...@@ -259,10 +261,6 @@ class MultiModalRegistry:
) -> BaseMultiModalProcessor[BaseProcessingInfo]: ) -> BaseMultiModalProcessor[BaseProcessingInfo]:
""" """
Create a multi-modal processor for a specific model and tokenizer. Create a multi-modal processor for a specific model and tokenizer.
:::{seealso}
{ref}`mm-processing`
:::
""" """
if not model_config.is_multimodal_model: if not model_config.is_multimodal_model:
raise ValueError(f"{model_config.model} is not a multimodal model") raise ValueError(f"{model_config.model} is not a multimodal model")
......
...@@ -259,7 +259,8 @@ class MediaConnector: ...@@ -259,7 +259,8 @@ class MediaConnector:
global_media_connector = MediaConnector() global_media_connector = MediaConnector()
"""The global {class}`MediaConnector` instance used by vLLM.""" """The global [`MediaConnector`][vllm.multimodal.utils.MediaConnector]
instance used by vLLM."""
fetch_audio = global_media_connector.fetch_audio fetch_audio = global_media_connector.fetch_audio
fetch_image = global_media_connector.fetch_image fetch_image = global_media_connector.fetch_image
......
...@@ -164,7 +164,7 @@ class VideoMediaIO(MediaIO[npt.NDArray]): ...@@ -164,7 +164,7 @@ class VideoMediaIO(MediaIO[npt.NDArray]):
) )
return np.stack([ return np.stack([
np.array(load_frame(frame_data)) np.asarray(load_frame(frame_data))
for frame_data in data.split(",") for frame_data in data.split(",")
]) ])
......
...@@ -9,12 +9,15 @@ from typing import Any, Generic, Optional, Union ...@@ -9,12 +9,15 @@ from typing import Any, Generic, Optional, Union
import torch import torch
from typing_extensions import TypeVar, deprecated from typing_extensions import TypeVar, deprecated
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalPlaceholderDict from vllm.multimodal.inputs import MultiModalPlaceholderDict
from vllm.sampling_params import RequestOutputKind from vllm.sampling_params import RequestOutputKind
from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs, from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
SequenceGroup, SequenceGroupBase, SequenceStatus) SequenceGroup, SequenceGroupBase, SequenceStatus)
logger = init_logger(__name__)
@dataclass @dataclass
class CompletionOutput: class CompletionOutput:
...@@ -122,7 +125,13 @@ class RequestOutput: ...@@ -122,7 +125,13 @@ class RequestOutput:
*, *,
multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None, multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None,
kv_transfer_params: Optional[dict[str, Any]] = None, kv_transfer_params: Optional[dict[str, Any]] = None,
# Forward compatibility, code that uses args added in new release can
# still run with older versions of vLLM without breaking.
**kwargs: Any,
) -> None: ) -> None:
if kwargs:
logger.warning_once("RequestOutput: Ignoring extra arguments: %s",
str(kwargs))
self.request_id = request_id self.request_id = request_id
self.prompt = prompt self.prompt = prompt
self.prompt_token_ids = prompt_token_ids self.prompt_token_ids = prompt_token_ids
...@@ -382,15 +391,6 @@ class PoolingRequestOutput(Generic[_O]): ...@@ -382,15 +391,6 @@ class PoolingRequestOutput(Generic[_O]):
prompt_token_ids, finished) prompt_token_ids, finished)
def __repr__(self): def __repr__(self):
"""
Returns a string representation of an PoolingRequestOutput instance.
The representation includes the request_id and the number of outputs,
providing a quick overview of the pooling request's results.
Returns:
str: A string representation of the PoolingRequestOutput instance.
"""
return (f"{type(self).__name__}(request_id={self.request_id!r}, " return (f"{type(self).__name__}(request_id={self.request_id!r}, "
f"outputs={self.outputs!r}, " f"outputs={self.outputs!r}, "
f"prompt_token_ids={self.prompt_token_ids}, " f"prompt_token_ids={self.prompt_token_ids}, "
......
...@@ -42,7 +42,6 @@ def tpu_platform_plugin() -> Optional[str]: ...@@ -42,7 +42,6 @@ def tpu_platform_plugin() -> Optional[str]:
logger.debug("Confirmed TPU platform is available.") logger.debug("Confirmed TPU platform is available.")
except Exception as e: except Exception as e:
logger.debug("TPU platform is not available because: %s", str(e)) logger.debug("TPU platform is not available because: %s", str(e))
pass
return "vllm.platforms.tpu.TpuPlatform" if is_tpu else None return "vllm.platforms.tpu.TpuPlatform" if is_tpu else None
...@@ -112,7 +111,6 @@ def rocm_platform_plugin() -> Optional[str]: ...@@ -112,7 +111,6 @@ def rocm_platform_plugin() -> Optional[str]:
amdsmi.amdsmi_shut_down() amdsmi.amdsmi_shut_down()
except Exception as e: except Exception as e:
logger.debug("ROCm platform is not available because: %s", str(e)) logger.debug("ROCm platform is not available because: %s", str(e))
pass
return "vllm.platforms.rocm.RocmPlatform" if is_rocm else None return "vllm.platforms.rocm.RocmPlatform" if is_rocm else None
...@@ -130,7 +128,6 @@ def hpu_platform_plugin() -> Optional[str]: ...@@ -130,7 +128,6 @@ def hpu_platform_plugin() -> Optional[str]:
"habana_frameworks is not found.") "habana_frameworks is not found.")
except Exception as e: except Exception as e:
logger.debug("HPU platform is not available because: %s", str(e)) logger.debug("HPU platform is not available because: %s", str(e))
pass
return "vllm.platforms.hpu.HpuPlatform" if is_hpu else None return "vllm.platforms.hpu.HpuPlatform" if is_hpu else None
...@@ -148,7 +145,6 @@ def xpu_platform_plugin() -> Optional[str]: ...@@ -148,7 +145,6 @@ def xpu_platform_plugin() -> Optional[str]:
logger.debug("Confirmed XPU platform is available.") logger.debug("Confirmed XPU platform is available.")
except Exception as e: except Exception as e:
logger.debug("XPU platform is not available because: %s", str(e)) logger.debug("XPU platform is not available because: %s", str(e))
pass
return "vllm.platforms.xpu.XPUPlatform" if is_xpu else None return "vllm.platforms.xpu.XPUPlatform" if is_xpu else None
...@@ -170,7 +166,6 @@ def cpu_platform_plugin() -> Optional[str]: ...@@ -170,7 +166,6 @@ def cpu_platform_plugin() -> Optional[str]:
except Exception as e: except Exception as e:
logger.debug("CPU platform is not available because: %s", str(e)) logger.debug("CPU platform is not available because: %s", str(e))
pass
return "vllm.platforms.cpu.CpuPlatform" if is_cpu else None return "vllm.platforms.cpu.CpuPlatform" if is_cpu else None
......
...@@ -9,6 +9,7 @@ import psutil ...@@ -9,6 +9,7 @@ import psutil
import torch import torch
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
from .interface import CpuArchEnum, Platform, PlatformEnum, _Backend from .interface import CpuArchEnum, Platform, PlatformEnum, _Backend
...@@ -74,7 +75,7 @@ class CpuPlatform(Platform): ...@@ -74,7 +75,7 @@ class CpuPlatform(Platform):
import vllm.envs as envs import vllm.envs as envs
from vllm.utils import GiB_bytes from vllm.utils import GiB_bytes
model_config = vllm_config.model_config model_config = vllm_config.model_config
# Reminder: Please update docs/source/features/compatibility_matrix.md # Reminder: Please update docs/features/compatibility_matrix.md
# If the feature combo become valid # If the feature combo become valid
if not model_config.enforce_eager: if not model_config.enforce_eager:
model_config.enforce_eager = True model_config.enforce_eager = True
...@@ -177,6 +178,16 @@ class CpuPlatform(Platform): ...@@ -177,6 +178,16 @@ class CpuPlatform(Platform):
" set VLLM_WORKER_MULTIPROC_METHOD to fork explicitly.") " set VLLM_WORKER_MULTIPROC_METHOD to fork explicitly.")
os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn' os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
if vllm_config.model_config and vllm_config.model_config.use_mla:
logger.info(
"MLA is enabled on a non-GPU platform; forcing chunked "
"prefill and prefix caching to be disabled.")
vllm_config.scheduler_config.enable_chunked_prefill = False
vllm_config.scheduler_config.chunked_prefill_enabled = False
vllm_config.scheduler_config.max_num_batched_tokens = max(
vllm_config.scheduler_config.max_model_len,
DEFAULT_MAX_NUM_BATCHED_TOKENS)
@classmethod @classmethod
def is_pin_memory_available(cls) -> bool: def is_pin_memory_available(cls) -> bool:
logger.warning("Pin memory is not supported on CPU.") logger.warning("Pin memory is not supported on CPU.")
......
...@@ -158,6 +158,7 @@ class CudaPlatformBase(Platform): ...@@ -158,6 +158,7 @@ class CudaPlatformBase(Platform):
"currently not supported with CUDA Graphs.") "currently not supported with CUDA Graphs.")
vllm_config.model_config.enforce_eager = True vllm_config.model_config.enforce_eager = True
compilation_config.use_cudagraph = False compilation_config.use_cudagraph = False
# FIXME: inductor breaks cudagraph (from @bnell)
compilation_config.use_inductor = False compilation_config.use_inductor = False
@classmethod @classmethod
...@@ -311,6 +312,10 @@ class CudaPlatformBase(Platform): ...@@ -311,6 +312,10 @@ class CudaPlatformBase(Platform):
def use_custom_allreduce(cls) -> bool: def use_custom_allreduce(cls) -> bool:
return True return True
@classmethod
def get_piecewise_backend_cls(cls) -> str:
return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend" # noqa
# NVML utils # NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`, # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
......
...@@ -7,6 +7,7 @@ import torch ...@@ -7,6 +7,7 @@ import torch
from vllm import envs from vllm import envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
from .interface import Platform, PlatformEnum, _Backend from .interface import Platform, PlatformEnum, _Backend
...@@ -38,8 +39,8 @@ class HpuPlatform(Platform): ...@@ -38,8 +39,8 @@ class HpuPlatform(Platform):
def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
return True return True
@staticmethod @classmethod
def inference_mode(): def inference_mode(cls):
return torch.no_grad() return torch.no_grad()
@classmethod @classmethod
...@@ -80,6 +81,16 @@ class HpuPlatform(Platform): ...@@ -80,6 +81,16 @@ class HpuPlatform(Platform):
"VLLM_WORKER_MULTIPROC_METHOD=fork explicitly.") "VLLM_WORKER_MULTIPROC_METHOD=fork explicitly.")
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
if vllm_config.model_config and vllm_config.model_config.use_mla:
logger.info(
"MLA is enabled on a non-GPU platform; forcing chunked "
"prefill and prefix caching to be disabled.")
vllm_config.scheduler_config.enable_chunked_prefill = False
vllm_config.scheduler_config.chunked_prefill_enabled = False
vllm_config.scheduler_config.max_num_batched_tokens = max(
vllm_config.scheduler_config.max_model_len,
DEFAULT_MAX_NUM_BATCHED_TOKENS)
@classmethod @classmethod
def is_pin_memory_available(cls): def is_pin_memory_available(cls):
logger.warning("Pin memory is not supported on HPU.") logger.warning("Pin memory is not supported on HPU.")
......
...@@ -84,7 +84,7 @@ class DeviceCapability(NamedTuple): ...@@ -84,7 +84,7 @@ class DeviceCapability(NamedTuple):
def to_int(self) -> int: def to_int(self) -> int:
""" """
Express device capability as an integer ``<major><minor>``. Express device capability as an integer `<major><minor>`.
It is assumed that the minor version is always a single digit. It is assumed that the minor version is always a single digit.
""" """
...@@ -157,7 +157,7 @@ class Platform: ...@@ -157,7 +157,7 @@ class Platform:
return self._enum == PlatformEnum.OOT return self._enum == PlatformEnum.OOT
def is_cuda_alike(self) -> bool: def is_cuda_alike(self) -> bool:
"""Stateless version of {func}`torch.cuda.is_available`.""" """Stateless version of [torch.cuda.is_available][]."""
return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM) return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
def is_sleep_mode_available(self) -> bool: def is_sleep_mode_available(self) -> bool:
...@@ -194,7 +194,7 @@ class Platform: ...@@ -194,7 +194,7 @@ class Platform:
cls, cls,
device_id: int = 0, device_id: int = 0,
) -> Optional[DeviceCapability]: ) -> Optional[DeviceCapability]:
"""Stateless version of {func}`torch.cuda.get_device_capability`.""" """Stateless version of [torch.cuda.get_device_capability][]."""
return None return None
@classmethod @classmethod
...@@ -206,10 +206,11 @@ class Platform: ...@@ -206,10 +206,11 @@ class Platform:
""" """
Test whether this platform is compatible with a device capability. Test whether this platform is compatible with a device capability.
The ``capability`` argument can either be: The `capability` argument can either be:
- A tuple ``(major, minor)``. - A tuple `(major, minor)`.
- An integer ``<major><minor>``. (See {meth}`DeviceCapability.to_int`) - An integer `<major><minor>`. (See
[`DeviceCapability.to_int`][vllm.platforms.interface.DeviceCapability.to_int])
""" """
current_capability = cls.get_device_capability(device_id=device_id) current_capability = cls.get_device_capability(device_id=device_id)
if current_capability is None: if current_capability is None:
...@@ -478,6 +479,13 @@ class Platform: ...@@ -478,6 +479,13 @@ class Platform:
""" """
raise NotImplementedError raise NotImplementedError
@classmethod
def get_piecewise_backend_cls(cls) -> str:
"""
Get piecewise backend class for piecewise graph.
"""
return "vllm.compilation.base_piecewise_backend.AbstractPiecewiseBackend" # noqa
class UnspecifiedPlatform(Platform): class UnspecifiedPlatform(Platform):
_enum = PlatformEnum.UNSPECIFIED _enum = PlatformEnum.UNSPECIFIED
......
...@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Optional ...@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Optional
from vllm import envs from vllm import envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
from .interface import Platform, PlatformEnum from .interface import Platform, PlatformEnum
...@@ -56,6 +57,16 @@ class NeuronPlatform(Platform): ...@@ -56,6 +57,16 @@ class NeuronPlatform(Platform):
vllm_config.cache_config.block_size = \ vllm_config.cache_config.block_size = \
vllm_config.model_config.max_model_len # type: ignore vllm_config.model_config.max_model_len # type: ignore
if vllm_config.model_config and vllm_config.model_config.use_mla:
logger.info(
"MLA is enabled on a non-GPU platform; forcing chunked "
"prefill and prefix caching to be disabled.")
vllm_config.scheduler_config.enable_chunked_prefill = False
vllm_config.scheduler_config.chunked_prefill_enabled = False
vllm_config.scheduler_config.max_num_batched_tokens = max(
vllm_config.scheduler_config.max_model_len,
DEFAULT_MAX_NUM_BATCHED_TOKENS)
@classmethod @classmethod
def is_pin_memory_available(cls) -> bool: def is_pin_memory_available(cls) -> bool:
logger.warning("Pin memory is not supported on Neuron.") logger.warning("Pin memory is not supported on Neuron.")
......
...@@ -102,26 +102,42 @@ def on_mi250_mi300() -> bool: ...@@ -102,26 +102,42 @@ def on_mi250_mi300() -> bool:
@cache @cache
def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int, def use_rocm_custom_paged_attention(
block_size: int, gqa_ratio: int, qtype: torch.dtype,
max_seq_len: int, head_size: int,
sliding_window: int) -> bool: block_size: int,
gqa_ratio: int,
max_seq_len: int,
sliding_window: int,
kv_cache_dtype: str,
alibi_slopes: Optional[torch.Tensor] = None) -> bool:
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"]) ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])
# rocm custom page attention not support on gfx1*
# custom paged attn always supported on V0. On V1, requires sliding window # custom paged attn always supported on V0. On V1, requires sliding window
# disabled due to observed numerical discrepancy. # disabled due to observed numerical discrepancy.
return (ON_GFX9 and (not envs.VLLM_USE_V1 or sliding_window == 0 if ON_GFX9:
or sliding_window == (-1, -1)) return ((not envs.VLLM_USE_V1 or sliding_window == 0
and (qtype == torch.half or qtype == torch.bfloat16) or sliding_window == (-1, -1))
and (head_size == 64 or head_size == 128) and (qtype == torch.half or qtype == torch.bfloat16)
and (block_size == 16 or block_size == 32) and (head_size == 64 or head_size == 128)
and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768 and (block_size == 16 or block_size == 32)
and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) and (gqa_ratio >= 1 and gqa_ratio <= 16)
and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN and max_seq_len <= 32768 and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
and envs.VLLM_ROCM_USE_AITER)) and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
and envs.VLLM_ROCM_USE_AITER))
else:
return (ON_GFX11_GFX12 and (not envs.VLLM_USE_V1 or sliding_window == 0
or sliding_window == (-1, -1))
and (qtype == torch.half or qtype == torch.bfloat16)
and head_size == 128 and block_size == 16
and (gqa_ratio >= 3 and gqa_ratio <= 16)
and max_seq_len <= 32768 and alibi_slopes is None
and kv_cache_dtype == "auto"
and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
class RocmPlatform(Platform): class RocmPlatform(Platform):
...@@ -201,9 +217,9 @@ class RocmPlatform(Platform): ...@@ -201,9 +217,9 @@ class RocmPlatform(Platform):
major, minor = torch.cuda.get_device_capability(device_id) major, minor = torch.cuda.get_device_capability(device_id)
return DeviceCapability(major=major, minor=minor) return DeviceCapability(major=major, minor=minor)
@staticmethod @classmethod
@with_amdsmi_context @with_amdsmi_context
def is_fully_connected(physical_device_ids: list[int]) -> bool: def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
""" """
Query if the set of gpus are fully connected by xgmi (1 hop) Query if the set of gpus are fully connected by xgmi (1 hop)
""" """
...@@ -363,3 +379,11 @@ class RocmPlatform(Platform): ...@@ -363,3 +379,11 @@ class RocmPlatform(Platform):
def get_cu_count(cls, device_id: int = 0) -> int: def get_cu_count(cls, device_id: int = 0) -> int:
return torch.cuda.get_device_properties( return torch.cuda.get_device_properties(
device_id).multi_processor_count device_id).multi_processor_count
@classmethod
def is_navi(cls) -> bool:
return 'gfx1' in torch.cuda.get_device_properties(0).gcnArchName
@classmethod
def get_piecewise_backend_cls(cls) -> str:
return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend" # noqa
...@@ -9,6 +9,7 @@ import vllm.envs as envs ...@@ -9,6 +9,7 @@ import vllm.envs as envs
from vllm.inputs import ProcessorInputs, PromptType from vllm.inputs import ProcessorInputs, PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams, SamplingType from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
from .interface import Platform, PlatformEnum, _Backend from .interface import Platform, PlatformEnum, _Backend
...@@ -161,6 +162,16 @@ class TpuPlatform(Platform): ...@@ -161,6 +162,16 @@ class TpuPlatform(Platform):
"Forcing --disable_chunked_mm_input.") "Forcing --disable_chunked_mm_input.")
scheduler_config.disable_chunked_mm_input = True scheduler_config.disable_chunked_mm_input = True
if vllm_config.model_config and vllm_config.model_config.use_mla:
logger.info(
"MLA is enabled on a non-GPU platform; forcing chunked "
"prefill and prefix caching to be disabled.")
vllm_config.scheduler_config.enable_chunked_prefill = False
vllm_config.scheduler_config.chunked_prefill_enabled = False
vllm_config.scheduler_config.max_num_batched_tokens = max(
vllm_config.scheduler_config.max_model_len,
DEFAULT_MAX_NUM_BATCHED_TOKENS)
@classmethod @classmethod
def is_pin_memory_available(cls): def is_pin_memory_available(cls):
logger.warning("Pin memory is not supported on TPU.") logger.warning("Pin memory is not supported on TPU.")
......
...@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Optional ...@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Optional
import torch import torch
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
...@@ -36,15 +37,17 @@ class XPUPlatform(Platform): ...@@ -36,15 +37,17 @@ class XPUPlatform(Platform):
logger.info("Using IPEX attention backend.") logger.info("Using IPEX attention backend.")
return "vllm.attention.backends.ipex_attn.IpexAttnBackend" return "vllm.attention.backends.ipex_attn.IpexAttnBackend"
@staticmethod @classmethod
def get_device_capability( def get_device_capability(
device_id: int = 0) -> Optional[DeviceCapability]: cls,
device_id: int = 0,
) -> Optional[DeviceCapability]:
# capacity format differs from cuda's and will cause unexpected # capacity format differs from cuda's and will cause unexpected
# failure, so use None directly # failure, so use None directly
return None return None
@staticmethod @classmethod
def get_device_name(device_id: int = 0) -> str: def get_device_name(cls, device_id: int = 0) -> str:
return torch.xpu.get_device_name(device_id) return torch.xpu.get_device_name(device_id)
@classmethod @classmethod
...@@ -56,8 +59,8 @@ class XPUPlatform(Platform): ...@@ -56,8 +59,8 @@ class XPUPlatform(Platform):
def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
return True return True
@staticmethod @classmethod
def inference_mode(): def inference_mode(cls):
return torch.no_grad() return torch.no_grad()
@classmethod @classmethod
...@@ -113,6 +116,16 @@ class XPUPlatform(Platform): ...@@ -113,6 +116,16 @@ class XPUPlatform(Platform):
parallel_config.distributed_executor_backend) parallel_config.distributed_executor_backend)
parallel_config.distributed_executor_backend = "ray" parallel_config.distributed_executor_backend = "ray"
if vllm_config.model_config and vllm_config.model_config.use_mla:
logger.info(
"MLA is enabled on a non-GPU platform; forcing chunked "
"prefill and prefix caching to be disabled.")
vllm_config.scheduler_config.enable_chunked_prefill = False
vllm_config.scheduler_config.chunked_prefill_enabled = False
vllm_config.scheduler_config.max_num_batched_tokens = max(
vllm_config.scheduler_config.max_model_len,
DEFAULT_MAX_NUM_BATCHED_TOKENS)
@classmethod @classmethod
def is_pin_memory_available(cls): def is_pin_memory_available(cls):
logger.warning("Pin memory is not supported on XPU.") logger.warning("Pin memory is not supported on XPU.")
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
import logging import logging
import os import os
from typing import Callable from typing import Any, Callable
import torch import torch
...@@ -14,7 +14,7 @@ logger = logging.getLogger(__name__) ...@@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
plugins_loaded = False plugins_loaded = False
def load_plugins_by_group(group: str) -> dict[str, Callable]: def load_plugins_by_group(group: str) -> dict[str, Callable[[], Any]]:
import sys import sys
if sys.version_info < (3, 10): if sys.version_info < (3, 10):
from importlib_metadata import entry_points from importlib_metadata import entry_points
...@@ -27,23 +27,27 @@ def load_plugins_by_group(group: str) -> dict[str, Callable]: ...@@ -27,23 +27,27 @@ def load_plugins_by_group(group: str) -> dict[str, Callable]:
if len(discovered_plugins) == 0: if len(discovered_plugins) == 0:
logger.debug("No plugins for group %s found.", group) logger.debug("No plugins for group %s found.", group)
return {} return {}
logger.info("Available plugins for group %s:", group) logger.info("Available plugins for group %s:", group)
for plugin in discovered_plugins: for plugin in discovered_plugins:
logger.info("name=%s, value=%s", plugin.name, plugin.value) logger.info("- %s -> %s", plugin.name, plugin.value)
if allowed_plugins is None: if allowed_plugins is None:
logger.info("all available plugins for group %s will be loaded.", logger.info("All plugins in this group will be loaded. "
group) "Set `VLLM_PLUGINS` to control which plugins to load.")
logger.info("set environment variable VLLM_PLUGINS to control"
" which plugins to load.") plugins = dict[str, Callable[[], Any]]()
plugins = {}
for plugin in discovered_plugins: for plugin in discovered_plugins:
if allowed_plugins is None or plugin.name in allowed_plugins: if allowed_plugins is None or plugin.name in allowed_plugins:
if allowed_plugins is not None:
logger.info("Loading plugin %s", plugin.name)
try: try:
func = plugin.load() func = plugin.load()
plugins[plugin.name] = func plugins[plugin.name] = func
logger.info("plugin %s loaded.", plugin.name)
except Exception: except Exception:
logger.exception("Failed to load plugin %s", plugin.name) logger.exception("Failed to load plugin %s", plugin.name)
return plugins return plugins
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import re
from collections.abc import Sequence from collections.abc import Sequence
from typing import Optional, Union from typing import Optional, Union
import regex as re
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
......
...@@ -27,7 +27,7 @@ VLLM_INVALID_TOKEN_ID = -1 ...@@ -27,7 +27,7 @@ VLLM_INVALID_TOKEN_ID = -1
def array_full(token_id: int, count: int): def array_full(token_id: int, count: int):
"""{class}`array` equivalent of {func}`numpy.full`.""" """[`array`][] equivalent of [numpy.full][]."""
return array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id]) * count return array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id]) * count
...@@ -112,12 +112,12 @@ class RequestMetrics: ...@@ -112,12 +112,12 @@ class RequestMetrics:
will include model forward, block/sync across will include model forward, block/sync across
workers, cpu-gpu sync time and sampling time. workers, cpu-gpu sync time and sampling time.
spec_token_acceptance_counts: number of accepted speculative tokens at spec_token_acceptance_counts: number of accepted speculative tokens at
each position; the first token is from each position; the first token is from
the target model and is always accepted; the target model and is always accepted;
e.g., when it's [10, 8, 4, 2] for a req, e.g., when it's [10, 8, 4, 2] for a req,
it means there were 10 forward passes in it means there were 10 forward passes in
total, and there were 8, 4, 2 accepted total, and there were 8, 4, 2 accepted
tokens at 1st, 2nd, 3rd speculation step. tokens at 1st, 2nd, 3rd speculation step.
""" """
arrival_time: float arrival_time: float
last_token_time: float last_token_time: float
...@@ -192,8 +192,8 @@ class SequenceData(msgspec.Struct, ...@@ -192,8 +192,8 @@ class SequenceData(msgspec.Struct,
def from_prompt_token_counts( def from_prompt_token_counts(
*token_counts: tuple[int, int]) -> "SequenceData": *token_counts: tuple[int, int]) -> "SequenceData":
""" """
Construct a {class}`SequenceData` instance by concatenating Construct a [`SequenceData`][vllm.sequence.SequenceData] instance
prompt token sequences. by concatenating prompt token sequences.
Each tuple represents one token sequence, expressed in the form Each tuple represents one token sequence, expressed in the form
`(token_id, count)`. `(token_id, count)`.
...@@ -216,8 +216,8 @@ class SequenceData(msgspec.Struct, ...@@ -216,8 +216,8 @@ class SequenceData(msgspec.Struct,
prompt_embeds: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None,
) -> "SequenceData": ) -> "SequenceData":
""" """
Construct a {class}`SequenceData` instance from prompt and output Construct a [`SequenceData`][vllm.sequence.SequenceData] instance
token sequences. from prompt and output token sequences.
""" """
prompt_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE, prompt_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
prompt_token_ids) prompt_token_ids)
...@@ -452,9 +452,11 @@ class SequenceData(msgspec.Struct, ...@@ -452,9 +452,11 @@ class SequenceData(msgspec.Struct,
class Sequence: class Sequence:
"""Stores the data, status, and block information of a sequence. """Stores the data, status, and block information of a sequence.
The sequence is constructed from the {data}`DecoderOnlyInputs` The sequence is constructed from the
(for decoder-only) or {data}`EncoderDecoderInputs` (for encoder-decoder) [`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] (for decoder-only)
instance passed in through the `inputs` constructor argument. or [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
(for encoder-decoder) instance passed in through the `inputs`
constructor argument.
Args: Args:
seq_id: The ID of the sequence. seq_id: The ID of the sequence.
...@@ -714,9 +716,9 @@ class SequenceGroup: ...@@ -714,9 +716,9 @@ class SequenceGroup:
trace_headers: OpenTelemetry trace headers. trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request. prompt_adapter_request: Prompt Adapter request.
priority: User-defined priority of the request. priority: User-defined priority of the request.
draft_size: The number of speculative tokens plus one from the target draft_size: The number of speculative tokens plus one from the target
model; equal to max number of tokens a step can generate model; equal to max number of tokens a step can generate
for single-draft speculative decoding but larger than for single-draft speculative decoding but larger than
that for multi-draft SD (currently not supported). that for multi-draft SD (currently not supported).
""" """
...@@ -1123,7 +1125,7 @@ class SequenceOutput( ...@@ -1123,7 +1125,7 @@ class SequenceOutput(
self.output_embed.shape if self.output_embed is not None else None self.output_embed.shape if self.output_embed is not None else None
return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, " return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
f"output_token={self.output_token}, " f"output_token={self.output_token}, "
f"output_embed.shape={output_embed_shape}" f"output_embed.shape={output_embed_shape}, "
f"logprobs={self.logprobs})") f"logprobs={self.logprobs})")
def __eq__(self, other: object) -> bool: def __eq__(self, other: object) -> bool:
...@@ -1494,7 +1496,7 @@ class ParallelSampleSequenceGroup(SequenceGroupBase): ...@@ -1494,7 +1496,7 @@ class ParallelSampleSequenceGroup(SequenceGroupBase):
for i in range(original_params.n): for i in range(original_params.n):
request_id_i = f"{request_id}_parallel_sample_{i}" request_id_i = f"{request_id}_parallel_sample_{i}"
group.seq_id_to_index[request_id_i] = i group.seq_id_to_index[request_id_i] = i
params = copy.deepcopy(original_params) params = original_params.clone()
params.n = 1 params.n = 1
if params.seed is not None: if params.seed is not None:
params.seed += i params.seed += i
......
...@@ -294,8 +294,11 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase): ...@@ -294,8 +294,11 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
inputs_embeds=None, inputs_embeds=None,
positions=model_input.input_positions, positions=model_input.input_positions,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs, **MultiModalKwargs.as_kwargs(
device=self.device), multi_modal_kwargs,
dtype=self.model_runner.model_config.dtype,
device=self.device,
),
**model_execute_kwargs, **model_execute_kwargs,
) )
......
...@@ -126,12 +126,12 @@ class AsyncMetricsCollector: ...@@ -126,12 +126,12 @@ class AsyncMetricsCollector:
"""Copy rejection/typical-acceptance sampling metrics """Copy rejection/typical-acceptance sampling metrics
(number of accepted tokens, etc) to CPU asynchronously. (number of accepted tokens, etc) to CPU asynchronously.
Returns a CUDA event recording when the copy is complete. Returns a device event recording when the copy is complete.
""" """
assert self._copy_stream is not None assert self._copy_stream is not None
self._copy_stream.wait_stream(torch.cuda.current_stream()) self._copy_stream.wait_stream(current_platform.current_stream())
with torch.cuda.stream(self._copy_stream): with current_platform.stream(self._copy_stream):
self._aggregate_num_accepted_tokens.copy_( self._aggregate_num_accepted_tokens.copy_(
self.spec_decode_sampler.num_accepted_tokens, self.spec_decode_sampler.num_accepted_tokens,
non_blocking=True) non_blocking=True)
...@@ -142,7 +142,7 @@ class AsyncMetricsCollector: ...@@ -142,7 +142,7 @@ class AsyncMetricsCollector:
self._aggregate_num_draft_tokens = ( self._aggregate_num_draft_tokens = (
self.spec_decode_sampler.num_draft_tokens) self.spec_decode_sampler.num_draft_tokens)
aggregate_metrics_ready = torch.cuda.Event() aggregate_metrics_ready = current_platform.Event()
aggregate_metrics_ready.record(self._copy_stream) aggregate_metrics_ready.record(self._copy_stream)
return aggregate_metrics_ready return aggregate_metrics_ready
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment