Commit 4eabe123 authored by zhuwenwen
Browse files

Merge remote-tracking branch 'mirror/releases/v0.9.0' into v0.9.0-ori

parents 45840cd2 58738772
# SPDX-License-Identifier: Apache-2.0
import json
import re
import sys
from abc import ABC, abstractmethod
from collections import defaultdict
......@@ -12,6 +11,7 @@ from functools import lru_cache
from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol,
TypeVar, Union, cast)
import regex as re
import torch
from typing_extensions import assert_never
......@@ -114,13 +114,14 @@ class PromptUpdateDetails(Generic[_S]):
is_embed: Optional[Callable[["_BoundPromptSequence"], torch.Tensor]] = None
"""
Given {attr}`full`, return a boolean mask of shape `(len(full),)`
indicating which positions of `full` to assign embeddings to.
Given [`full`][vllm.multimodal.processing.PromptUpdateDetails.full],
return a boolean mask of shape `(len(full),)` indicating which positions
of `full` to assign embeddings to.
`None` (default) means to assign embeddings to all positions of `full`.
The embeddings are obtained by calling
{class}`SupportsMultiModal.get_multimodal_embeddings`.
[`SupportsMultiModal.get_multimodal_embeddings`][vllm.model_executor.models.interfaces.SupportsMultiModal.get_multimodal_embeddings].
"""
@staticmethod
......@@ -159,13 +160,15 @@ PromptUpdateInfo = Union[PromptSeq, PromptUpdateDetails]
The token sequence or text that are part of the update.
If only part of the content corresponds to feature placeholders, you can
use {class}`PromptUpdateDetails` to specify which part.
use [`PromptUpdateDetails`][vllm.multimodal.processing.PromptUpdateDetails] to
specify which part.
"""
PromptUpdateContent = Union[Callable[[int], PromptUpdateInfo],
PromptUpdateInfo]
"""
Given the index of the processed item within {attr}`modality`,
Given the index of the processed item within
[`modality`][vllm.multimodal.processing.PromptUpdate.modality],
output the corresponding token sequence (or text).
For convenience, you can directly pass in the token sequence (or text)
......@@ -260,8 +263,10 @@ class PromptInsertion(PromptUpdate):
insertion: PromptUpdateContent = field(repr=False)
"""
Given the index of the processed item within {attr}`modality`,
output the token sequence (or text) to insert right after {attr}`target`.
Given the index of the processed item within
[`modality`][vllm.multimodal.processing.PromptUpdate.modality],
output the token sequence (or text) to insert right after
[`target`][vllm.multimodal.processing.PromptUpdate.target].
For convenience, you can directly pass in the token sequence (or text)
instead of a function if it does not depend on the input.
......@@ -332,8 +337,10 @@ class PromptReplacement(PromptUpdate):
replacement: PromptUpdateContent = field(repr=False)
"""
Given the index of the processed item within {attr}`modality`,
output the token sequence (or text) to replace {attr}`target`.
Given the index of the processed item within
[`modality`][vllm.multimodal.processing.PromptUpdate.modality],
output the token sequence (or text) to replace
[`target`][vllm.multimodal.processing.PromptUpdate.target].
For convenience, you can directly pass in the token sequence (or text)
instead of a function if it does not depend on the input.
......@@ -387,14 +394,16 @@ _M = TypeVar("_M", bound=Union[_HasModalityAttr, _HasModalityProp])
def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
"""Convenience function to apply {func}`full_groupby` based on modality."""
"""Convenience function to apply [`full_groupby`][vllm.utils.full_groupby]
based on modality."""
return full_groupby(values, key=lambda x: x.modality)
@dataclass
class _BoundPromptSequence:
"""
A {data}`_PromptSeq` bound to a tokenizer to automatically
A [`_PromptSeq`][vllm.multimodal.processing.PromptSeq] bound
to a tokenizer to automatically
convert between token sequence and text representations.
"""
tokenizer: AnyTokenizer = field(repr=False)
......@@ -446,9 +455,11 @@ class _BoundPromptContent:
@dataclass
class BoundPromptUpdate:
"""
A {class}`PromptUpdate` bound to a tokenizer to automatically convert
{attr}`target` and the result of {meth}`get_content` between
token sequence and text representations.
A [`PromptUpdate`][vllm.multimodal.processing.PromptUpdate] bound
to a tokenizer to automatically convert
[`target`][vllm.multimodal.processing.PromptUpdate.target] and the result of
[`get_content`][vllm.multimodal.processing.BoundPromptUpdate.get_content]
between token sequence and text representations.
"""
_origin: PromptUpdate
tokenizer: AnyTokenizer = field(repr=False)
......@@ -482,7 +493,8 @@ class BoundPromptUpdate:
def get_content(self, item_idx: int) -> _BoundPromptContent:
"""
Given the index of the processed item within {attr}`modality`,
Given the index of the processed item within
[`modality`][vllm.multimodal.processing.PromptUpdate.modality],
output the token sequence (or text) to update.
"""
content = self.content
......@@ -1019,7 +1031,8 @@ class ProcessingCache:
) -> None:
"""
Put a processed multi-modal item into the cache
according to its dependencies (see {meth}`get`).
according to its dependencies
(see [`get`][vllm.multimodal.processing.ProcessingCache.get]).
"""
cache_key = MultiModalHasher.hash_kwargs(model_id=model_id,
**{modality: input_item},
......@@ -1091,7 +1104,8 @@ _I = TypeVar("_I", bound=BaseProcessingInfo)
MultiModalHashes = dict[str, list[str]]
"""
A collection of hashes with a similar structure as {class}`MultiModalKwargs`.
A collection of hashes with a similar structure as
[`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs].
"""
......@@ -1099,7 +1113,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
"""
Abstract base class to process multi-modal inputs to be used in vLLM.
Not to be confused with {class}`transformers.ProcessorMixin`.
Not to be confused with `transformers.ProcessorMixin`.
"""
def __init__(self,
......@@ -1126,10 +1140,12 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
def _get_data_parser(self) -> MultiModalDataParser:
"""
Construct a parser to preprocess multi-modal data items
before passing them to {meth}`_get_hf_mm_data`.
before passing them to
[`_get_hf_mm_data`][vllm.multimodal.processing.BaseMultiModalProcessor._get_hf_mm_data].
You can support additional modalities by creating a subclass
of {class}`MultiModalDataParser` that has additional subparsers.
of [`MultiModalDataParser`][vllm.multimodal.parse.MultiModalDataParser]
that has additional subparsers.
"""
return MultiModalDataParser()
......@@ -1138,8 +1154,11 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_data: MultiModalDataDict,
) -> MultiModalDataItems:
"""
Normalize {class}`MultiModalDataDict` to {class}`MultiModalDataItems`
before passing them to {meth}`_get_hf_mm_data`.
Normalize
[`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict]
to [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems]
before passing them to
[`_get_hf_mm_data`][vllm.multimodal.processing.BaseMultiModalProcessor._get_hf_mm_data].
"""
mm_items = self.data_parser.parse_mm_data(mm_data)
supported_mm_limits = self.info.get_supported_mm_limits()
......@@ -1191,7 +1210,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
inputs.
Moreover, this information is critical to determine the token positions
in order to construct {class}`~vllm-multimodal.input.PlaceholderRange`
in order to construct
[`PlaceholderRange`][vllm.multimodal.inputs.PlaceholderRange]
for each multi-modal item.
"""
raise NotImplementedError
......@@ -1315,7 +1335,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
Most HF processors accept prompt text but not prompt tokens.
If the HF processor adds or removes tokens that are not related to
multi-modal data, you should override this method so it is consistent
with the output of {meth}`_apply_hf_processor_text_only` on the
with the output of
[`_apply_hf_processor_text_only`][vllm.multimodal.processing.BaseMultiModalProcessor._apply_hf_processor_text_only]
on the
corresponding text.
"""
return prompt_tokens
......@@ -1330,7 +1352,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
Since HF processor requires that text and multi-modal items
correspond to each other, we generate dummy text using
{class}`DummyInputsBuilder` to go along with the multi-modal data.
[`DummyInputsBuilder`][vllm.multimodal.profiling.BaseDummyInputsBuilder]
to go along with the multi-modal data.
"""
mm_counts = mm_items.get_all_counts()
......
......@@ -3,7 +3,7 @@
from abc import ABC
from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import Generic, NamedTuple, Optional, TypeVar, cast
from typing import Generic, NamedTuple, Optional, TypeVar, Union, cast
import numpy as np
import numpy.typing as npt
......@@ -25,9 +25,9 @@ logger = init_logger(__name__)
class ProcessorInputs:
"""
Represents the keyword arguments to
{meth}`vllm.multimodal.processing.BaseMultiModalProcessor.apply`.
[`vllm.multimodal.processing.BaseMultiModalProcessor.apply`][].
"""
prompt_text: str
prompt: Union[str, list[int]]
mm_data: MultiModalDataDict
hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)
......@@ -75,7 +75,12 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
"in an upcoming release.")
seq_len = self.info.ctx.model_config.max_model_len
return self.get_dummy_processor_inputs(seq_len, mm_counts).prompt_text
prompt = self.get_dummy_processor_inputs(seq_len, mm_counts).prompt
if not isinstance(prompt, str):
prompt = self.info.get_tokenizer().decode(prompt)
return prompt
# TODO: @abstractmethod after transition
def get_dummy_mm_data(
......@@ -101,7 +106,7 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
dummy_text = self.get_dummy_text(mm_counts)
dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts)
return ProcessorInputs(prompt_text=dummy_text, mm_data=dummy_mm_data)
return ProcessorInputs(prompt=dummy_text, mm_data=dummy_mm_data)
def _get_dummy_audios(
self,
......@@ -177,7 +182,7 @@ class MultiModalProfiler(Generic[_I]):
seq_len, mm_counts)
return self.processor.apply(
prompt=processor_inputs.prompt_text,
prompt=processor_inputs.prompt,
mm_data=processor_inputs.mm_data,
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
)
......
......@@ -29,7 +29,11 @@ _I_co = TypeVar("_I_co", bound=BaseProcessingInfo, covariant=True)
class ProcessingInfoFactory(Protocol[_I_co]):
"""Constructs a {class}`MultiModalProcessor` instance from the context."""
"""
Constructs a
[`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor]
instance from the context.
"""
def __call__(
self,
......@@ -40,7 +44,9 @@ class ProcessingInfoFactory(Protocol[_I_co]):
class DummyInputsBuilderFactory(Protocol[_I]):
"""
Constructs a {class}`BaseDummyInputsBuilder` instance from the context.
Constructs a
[`BaseDummyInputsBuilder`][vllm.multimodal.profiling.BaseDummyInputsBuilder]
instance from the context.
"""
def __call__(self, info: _I) -> BaseDummyInputsBuilder[_I]:
......@@ -48,7 +54,11 @@ class DummyInputsBuilderFactory(Protocol[_I]):
class MultiModalProcessorFactory(Protocol[_I]):
"""Constructs a {class}`MultiModalProcessor` instance from the context."""
"""
Constructs a
[`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor]
instance from the context.
"""
def __call__(
self,
......@@ -155,8 +165,6 @@ class MultiModalRegistry:
"""
Get the maximum number of tokens from each modality
for profiling the memory usage of a model.
See {meth}`MultiModalPlugin.get_max_multimodal_tokens` for more details.
"""
mm_limits = self.get_mm_limits_per_prompt(model_config)
......@@ -170,8 +178,6 @@ class MultiModalRegistry:
"""
Get the maximum number of multi-modal tokens
for profiling the memory usage of a model.
See {meth}`MultiModalPlugin.get_max_multimodal_tokens` for more details.
"""
return sum(self.get_max_tokens_by_modality(model_config).values())
......@@ -213,10 +219,6 @@ class MultiModalRegistry:
When the model receives multi-modal data, the provided function is
invoked to transform the data into a dictionary of model inputs.
:::{seealso}
{ref}`mm-processing`
:::
"""
def wrapper(model_cls: N) -> N:
......@@ -259,10 +261,6 @@ class MultiModalRegistry:
) -> BaseMultiModalProcessor[BaseProcessingInfo]:
"""
Create a multi-modal processor for a specific model and tokenizer.
:::{seealso}
{ref}`mm-processing`
:::
"""
if not model_config.is_multimodal_model:
raise ValueError(f"{model_config.model} is not a multimodal model")
......
......@@ -259,7 +259,8 @@ class MediaConnector:
global_media_connector = MediaConnector()
"""The global {class}`MediaConnector` instance used by vLLM."""
"""The global [`MediaConnector`][vllm.multimodal.utils.MediaConnector]
instance used by vLLM."""
fetch_audio = global_media_connector.fetch_audio
fetch_image = global_media_connector.fetch_image
......
......@@ -164,7 +164,7 @@ class VideoMediaIO(MediaIO[npt.NDArray]):
)
return np.stack([
np.array(load_frame(frame_data))
np.asarray(load_frame(frame_data))
for frame_data in data.split(",")
])
......
......@@ -9,12 +9,15 @@ from typing import Any, Generic, Optional, Union
import torch
from typing_extensions import TypeVar, deprecated
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalPlaceholderDict
from vllm.sampling_params import RequestOutputKind
from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
SequenceGroup, SequenceGroupBase, SequenceStatus)
logger = init_logger(__name__)
@dataclass
class CompletionOutput:
......@@ -122,7 +125,13 @@ class RequestOutput:
*,
multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None,
kv_transfer_params: Optional[dict[str, Any]] = None,
# Forward compatibility, code that uses args added in new release can
# still run with older versions of vLLM without breaking.
**kwargs: Any,
) -> None:
if kwargs:
logger.warning_once("RequestOutput: Ignoring extra arguments: %s",
str(kwargs))
self.request_id = request_id
self.prompt = prompt
self.prompt_token_ids = prompt_token_ids
......@@ -382,15 +391,6 @@ class PoolingRequestOutput(Generic[_O]):
prompt_token_ids, finished)
def __repr__(self):
"""
Returns a string representation of an PoolingRequestOutput instance.
The representation includes the request_id and the number of outputs,
providing a quick overview of the pooling request's results.
Returns:
str: A string representation of the PoolingRequestOutput instance.
"""
return (f"{type(self).__name__}(request_id={self.request_id!r}, "
f"outputs={self.outputs!r}, "
f"prompt_token_ids={self.prompt_token_ids}, "
......
......@@ -42,7 +42,6 @@ def tpu_platform_plugin() -> Optional[str]:
logger.debug("Confirmed TPU platform is available.")
except Exception as e:
logger.debug("TPU platform is not available because: %s", str(e))
pass
return "vllm.platforms.tpu.TpuPlatform" if is_tpu else None
......@@ -112,7 +111,6 @@ def rocm_platform_plugin() -> Optional[str]:
amdsmi.amdsmi_shut_down()
except Exception as e:
logger.debug("ROCm platform is not available because: %s", str(e))
pass
return "vllm.platforms.rocm.RocmPlatform" if is_rocm else None
......@@ -130,7 +128,6 @@ def hpu_platform_plugin() -> Optional[str]:
"habana_frameworks is not found.")
except Exception as e:
logger.debug("HPU platform is not available because: %s", str(e))
pass
return "vllm.platforms.hpu.HpuPlatform" if is_hpu else None
......@@ -148,7 +145,6 @@ def xpu_platform_plugin() -> Optional[str]:
logger.debug("Confirmed XPU platform is available.")
except Exception as e:
logger.debug("XPU platform is not available because: %s", str(e))
pass
return "vllm.platforms.xpu.XPUPlatform" if is_xpu else None
......@@ -170,7 +166,6 @@ def cpu_platform_plugin() -> Optional[str]:
except Exception as e:
logger.debug("CPU platform is not available because: %s", str(e))
pass
return "vllm.platforms.cpu.CpuPlatform" if is_cpu else None
......
......@@ -9,6 +9,7 @@ import psutil
import torch
from vllm.logger import init_logger
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
from .interface import CpuArchEnum, Platform, PlatformEnum, _Backend
......@@ -74,7 +75,7 @@ class CpuPlatform(Platform):
import vllm.envs as envs
from vllm.utils import GiB_bytes
model_config = vllm_config.model_config
# Reminder: Please update docs/source/features/compatibility_matrix.md
# Reminder: Please update docs/features/compatibility_matrix.md
# If the feature combo become valid
if not model_config.enforce_eager:
model_config.enforce_eager = True
......@@ -177,6 +178,16 @@ class CpuPlatform(Platform):
" set VLLM_WORKER_MULTIPROC_METHOD to fork explicitly.")
os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
if vllm_config.model_config and vllm_config.model_config.use_mla:
logger.info(
"MLA is enabled on a non-GPU platform; forcing chunked "
"prefill and prefix caching to be disabled.")
vllm_config.scheduler_config.enable_chunked_prefill = False
vllm_config.scheduler_config.chunked_prefill_enabled = False
vllm_config.scheduler_config.max_num_batched_tokens = max(
vllm_config.scheduler_config.max_model_len,
DEFAULT_MAX_NUM_BATCHED_TOKENS)
@classmethod
def is_pin_memory_available(cls) -> bool:
logger.warning("Pin memory is not supported on CPU.")
......
......@@ -158,6 +158,7 @@ class CudaPlatformBase(Platform):
"currently not supported with CUDA Graphs.")
vllm_config.model_config.enforce_eager = True
compilation_config.use_cudagraph = False
# FIXME: inductor breaks cudagraph (from @bnell)
compilation_config.use_inductor = False
@classmethod
......@@ -311,6 +312,10 @@ class CudaPlatformBase(Platform):
def use_custom_allreduce(cls) -> bool:
return True
@classmethod
def get_piecewise_backend_cls(cls) -> str:
return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend" # noqa
# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
......
......@@ -7,6 +7,7 @@ import torch
from vllm import envs
from vllm.logger import init_logger
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
from .interface import Platform, PlatformEnum, _Backend
......@@ -38,8 +39,8 @@ class HpuPlatform(Platform):
def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
return True
@staticmethod
def inference_mode():
@classmethod
def inference_mode(cls):
return torch.no_grad()
@classmethod
......@@ -80,6 +81,16 @@ class HpuPlatform(Platform):
"VLLM_WORKER_MULTIPROC_METHOD=fork explicitly.")
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
if vllm_config.model_config and vllm_config.model_config.use_mla:
logger.info(
"MLA is enabled on a non-GPU platform; forcing chunked "
"prefill and prefix caching to be disabled.")
vllm_config.scheduler_config.enable_chunked_prefill = False
vllm_config.scheduler_config.chunked_prefill_enabled = False
vllm_config.scheduler_config.max_num_batched_tokens = max(
vllm_config.scheduler_config.max_model_len,
DEFAULT_MAX_NUM_BATCHED_TOKENS)
@classmethod
def is_pin_memory_available(cls):
logger.warning("Pin memory is not supported on HPU.")
......
......@@ -84,7 +84,7 @@ class DeviceCapability(NamedTuple):
def to_int(self) -> int:
"""
Express device capability as an integer ``<major><minor>``.
Express device capability as an integer `<major><minor>`.
It is assumed that the minor version is always a single digit.
"""
......@@ -157,7 +157,7 @@ class Platform:
return self._enum == PlatformEnum.OOT
def is_cuda_alike(self) -> bool:
"""Stateless version of {func}`torch.cuda.is_available`."""
"""Stateless version of [torch.cuda.is_available][]."""
return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
def is_sleep_mode_available(self) -> bool:
......@@ -194,7 +194,7 @@ class Platform:
cls,
device_id: int = 0,
) -> Optional[DeviceCapability]:
"""Stateless version of {func}`torch.cuda.get_device_capability`."""
"""Stateless version of [torch.cuda.get_device_capability][]."""
return None
@classmethod
......@@ -206,10 +206,11 @@ class Platform:
"""
Test whether this platform is compatible with a device capability.
The ``capability`` argument can either be:
The `capability` argument can either be:
- A tuple ``(major, minor)``.
- An integer ``<major><minor>``. (See {meth}`DeviceCapability.to_int`)
- A tuple `(major, minor)`.
- An integer `<major><minor>`. (See
[`DeviceCapability.to_int`][vllm.platforms.interface.DeviceCapability.to_int])
"""
current_capability = cls.get_device_capability(device_id=device_id)
if current_capability is None:
......@@ -478,6 +479,13 @@ class Platform:
"""
raise NotImplementedError
@classmethod
def get_piecewise_backend_cls(cls) -> str:
"""
Get piecewise backend class for piecewise graph.
"""
return "vllm.compilation.base_piecewise_backend.AbstractPiecewiseBackend" # noqa
class UnspecifiedPlatform(Platform):
_enum = PlatformEnum.UNSPECIFIED
......
......@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Optional
from vllm import envs
from vllm.logger import init_logger
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
from .interface import Platform, PlatformEnum
......@@ -56,6 +57,16 @@ class NeuronPlatform(Platform):
vllm_config.cache_config.block_size = \
vllm_config.model_config.max_model_len # type: ignore
if vllm_config.model_config and vllm_config.model_config.use_mla:
logger.info(
"MLA is enabled on a non-GPU platform; forcing chunked "
"prefill and prefix caching to be disabled.")
vllm_config.scheduler_config.enable_chunked_prefill = False
vllm_config.scheduler_config.chunked_prefill_enabled = False
vllm_config.scheduler_config.max_num_batched_tokens = max(
vllm_config.scheduler_config.max_model_len,
DEFAULT_MAX_NUM_BATCHED_TOKENS)
@classmethod
def is_pin_memory_available(cls) -> bool:
logger.warning("Pin memory is not supported on Neuron.")
......
......@@ -102,27 +102,43 @@ def on_mi250_mi300() -> bool:
@cache
def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
block_size: int, gqa_ratio: int,
def use_rocm_custom_paged_attention(
qtype: torch.dtype,
head_size: int,
block_size: int,
gqa_ratio: int,
max_seq_len: int,
sliding_window: int) -> bool:
sliding_window: int,
kv_cache_dtype: str,
alibi_slopes: Optional[torch.Tensor] = None) -> bool:
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])
# rocm custom page attention not support on gfx1*
# custom paged attn always supported on V0. On V1, requires sliding window
# disabled due to observed numerical discrepancy.
return (ON_GFX9 and (not envs.VLLM_USE_V1 or sliding_window == 0
if ON_GFX9:
return ((not envs.VLLM_USE_V1 or sliding_window == 0
or sliding_window == (-1, -1))
and (qtype == torch.half or qtype == torch.bfloat16)
and (head_size == 64 or head_size == 128)
and (block_size == 16 or block_size == 32)
and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768
and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
and (gqa_ratio >= 1 and gqa_ratio <= 16)
and max_seq_len <= 32768 and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
and envs.VLLM_ROCM_USE_AITER))
else:
return (ON_GFX11_GFX12 and (not envs.VLLM_USE_V1 or sliding_window == 0
or sliding_window == (-1, -1))
and (qtype == torch.half or qtype == torch.bfloat16)
and head_size == 128 and block_size == 16
and (gqa_ratio >= 3 and gqa_ratio <= 16)
and max_seq_len <= 32768 and alibi_slopes is None
and kv_cache_dtype == "auto"
and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
class RocmPlatform(Platform):
_enum = PlatformEnum.ROCM
......@@ -201,9 +217,9 @@ class RocmPlatform(Platform):
major, minor = torch.cuda.get_device_capability(device_id)
return DeviceCapability(major=major, minor=minor)
@staticmethod
@classmethod
@with_amdsmi_context
def is_fully_connected(physical_device_ids: list[int]) -> bool:
def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
"""
Query if the set of gpus are fully connected by xgmi (1 hop)
"""
......@@ -363,3 +379,11 @@ class RocmPlatform(Platform):
def get_cu_count(cls, device_id: int = 0) -> int:
return torch.cuda.get_device_properties(
device_id).multi_processor_count
@classmethod
def is_navi(cls) -> bool:
return 'gfx1' in torch.cuda.get_device_properties(0).gcnArchName
@classmethod
def get_piecewise_backend_cls(cls) -> str:
return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend" # noqa
......@@ -9,6 +9,7 @@ import vllm.envs as envs
from vllm.inputs import ProcessorInputs, PromptType
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
from .interface import Platform, PlatformEnum, _Backend
......@@ -161,6 +162,16 @@ class TpuPlatform(Platform):
"Forcing --disable_chunked_mm_input.")
scheduler_config.disable_chunked_mm_input = True
if vllm_config.model_config and vllm_config.model_config.use_mla:
logger.info(
"MLA is enabled on a non-GPU platform; forcing chunked "
"prefill and prefix caching to be disabled.")
vllm_config.scheduler_config.enable_chunked_prefill = False
vllm_config.scheduler_config.chunked_prefill_enabled = False
vllm_config.scheduler_config.max_num_batched_tokens = max(
vllm_config.scheduler_config.max_model_len,
DEFAULT_MAX_NUM_BATCHED_TOKENS)
@classmethod
def is_pin_memory_available(cls):
logger.warning("Pin memory is not supported on TPU.")
......
......@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Optional
import torch
from vllm.logger import init_logger
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
......@@ -36,15 +37,17 @@ class XPUPlatform(Platform):
logger.info("Using IPEX attention backend.")
return "vllm.attention.backends.ipex_attn.IpexAttnBackend"
@staticmethod
@classmethod
def get_device_capability(
device_id: int = 0) -> Optional[DeviceCapability]:
cls,
device_id: int = 0,
) -> Optional[DeviceCapability]:
# capacity format differs from cuda's and will cause unexpected
# failure, so use None directly
return None
@staticmethod
def get_device_name(device_id: int = 0) -> str:
@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
return torch.xpu.get_device_name(device_id)
@classmethod
......@@ -56,8 +59,8 @@ class XPUPlatform(Platform):
def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
return True
@staticmethod
def inference_mode():
@classmethod
def inference_mode(cls):
return torch.no_grad()
@classmethod
......@@ -113,6 +116,16 @@ class XPUPlatform(Platform):
parallel_config.distributed_executor_backend)
parallel_config.distributed_executor_backend = "ray"
if vllm_config.model_config and vllm_config.model_config.use_mla:
logger.info(
"MLA is enabled on a non-GPU platform; forcing chunked "
"prefill and prefix caching to be disabled.")
vllm_config.scheduler_config.enable_chunked_prefill = False
vllm_config.scheduler_config.chunked_prefill_enabled = False
vllm_config.scheduler_config.max_num_batched_tokens = max(
vllm_config.scheduler_config.max_model_len,
DEFAULT_MAX_NUM_BATCHED_TOKENS)
@classmethod
def is_pin_memory_available(cls):
logger.warning("Pin memory is not supported on XPU.")
......
......@@ -2,7 +2,7 @@
import logging
import os
from typing import Callable
from typing import Any, Callable
import torch
......@@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
plugins_loaded = False
def load_plugins_by_group(group: str) -> dict[str, Callable]:
def load_plugins_by_group(group: str) -> dict[str, Callable[[], Any]]:
import sys
if sys.version_info < (3, 10):
from importlib_metadata import entry_points
......@@ -27,23 +27,27 @@ def load_plugins_by_group(group: str) -> dict[str, Callable]:
if len(discovered_plugins) == 0:
logger.debug("No plugins for group %s found.", group)
return {}
logger.info("Available plugins for group %s:", group)
for plugin in discovered_plugins:
logger.info("name=%s, value=%s", plugin.name, plugin.value)
logger.info("- %s -> %s", plugin.name, plugin.value)
if allowed_plugins is None:
logger.info("all available plugins for group %s will be loaded.",
group)
logger.info("set environment variable VLLM_PLUGINS to control"
" which plugins to load.")
plugins = {}
logger.info("All plugins in this group will be loaded. "
"Set `VLLM_PLUGINS` to control which plugins to load.")
plugins = dict[str, Callable[[], Any]]()
for plugin in discovered_plugins:
if allowed_plugins is None or plugin.name in allowed_plugins:
if allowed_plugins is not None:
logger.info("Loading plugin %s", plugin.name)
try:
func = plugin.load()
plugins[plugin.name] = func
logger.info("plugin %s loaded.", plugin.name)
except Exception:
logger.exception("Failed to load plugin %s", plugin.name)
return plugins
......
# SPDX-License-Identifier: Apache-2.0
import re
from collections.abc import Sequence
from typing import Optional, Union
import regex as re
from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
......
......@@ -27,7 +27,7 @@ VLLM_INVALID_TOKEN_ID = -1
def array_full(token_id: int, count: int):
"""{class}`array` equivalent of {func}`numpy.full`."""
"""[`array`][] equivalent of [numpy.full][]."""
return array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id]) * count
......@@ -192,8 +192,8 @@ class SequenceData(msgspec.Struct,
def from_prompt_token_counts(
*token_counts: tuple[int, int]) -> "SequenceData":
"""
Construct a {class}`SequenceData` instance by concatenating
prompt token sequences.
Construct a [`SequenceData`][vllm.sequence.SequenceData] instance
by concatenating prompt token sequences.
Each tuple represents one token sequence, expressed in the form
`(token_id, count)`.
......@@ -216,8 +216,8 @@ class SequenceData(msgspec.Struct,
prompt_embeds: Optional[torch.Tensor] = None,
) -> "SequenceData":
"""
Construct a {class}`SequenceData` instance from prompt and output
token sequences.
Construct a [`SequenceData`][vllm.sequence.SequenceData] instance
from prompt and output token sequences.
"""
prompt_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
prompt_token_ids)
......@@ -452,9 +452,11 @@ class SequenceData(msgspec.Struct,
class Sequence:
"""Stores the data, status, and block information of a sequence.
The sequence is constructed from the {data}`DecoderOnlyInputs`
(for decoder-only) or {data}`EncoderDecoderInputs` (for encoder-decoder)
instance passed in through the `inputs` constructor argument.
The sequence is constructed from the
[`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] (for decoder-only)
or [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
(for encoder-decoder) instance passed in through the `inputs`
constructor argument.
Args:
seq_id: The ID of the sequence.
......@@ -1123,7 +1125,7 @@ class SequenceOutput(
self.output_embed.shape if self.output_embed is not None else None
return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
f"output_token={self.output_token}, "
f"output_embed.shape={output_embed_shape}"
f"output_embed.shape={output_embed_shape}, "
f"logprobs={self.logprobs})")
def __eq__(self, other: object) -> bool:
......@@ -1494,7 +1496,7 @@ class ParallelSampleSequenceGroup(SequenceGroupBase):
for i in range(original_params.n):
request_id_i = f"{request_id}_parallel_sample_{i}"
group.seq_id_to_index[request_id_i] = i
params = copy.deepcopy(original_params)
params = original_params.clone()
params.n = 1
if params.seed is not None:
params.seed += i
......
......@@ -294,8 +294,11 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
inputs_embeds=None,
positions=model_input.input_positions,
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
device=self.device),
**MultiModalKwargs.as_kwargs(
multi_modal_kwargs,
dtype=self.model_runner.model_config.dtype,
device=self.device,
),
**model_execute_kwargs,
)
......
......@@ -126,12 +126,12 @@ class AsyncMetricsCollector:
"""Copy rejection/typical-acceptance sampling metrics
(number of accepted tokens, etc) to CPU asynchronously.
Returns a CUDA event recording when the copy is complete.
Returns a device event recording when the copy is complete.
"""
assert self._copy_stream is not None
self._copy_stream.wait_stream(torch.cuda.current_stream())
self._copy_stream.wait_stream(current_platform.current_stream())
with torch.cuda.stream(self._copy_stream):
with current_platform.stream(self._copy_stream):
self._aggregate_num_accepted_tokens.copy_(
self.spec_decode_sampler.num_accepted_tokens,
non_blocking=True)
......@@ -142,7 +142,7 @@ class AsyncMetricsCollector:
self._aggregate_num_draft_tokens = (
self.spec_decode_sampler.num_draft_tokens)
aggregate_metrics_ready = torch.cuda.Event()
aggregate_metrics_ready = current_platform.Event()
aggregate_metrics_ready.record(self._copy_stream)
return aggregate_metrics_ready
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment