Commit 539aa992 authored by zhuwenwen
Browse files

Merge tag 'v0.6.2' into v0.6.2-dev

parents 93872128 7193774b
...@@ -2,11 +2,12 @@ from functools import lru_cache ...@@ -2,11 +2,12 @@ from functools import lru_cache
import torch import torch
from PIL import Image from PIL import Image
from transformers.image_processing_base import BatchFeature
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.inputs.registry import InputContext from vllm.inputs.registry import InputContext
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.image_processor import get_image_processor from vllm.transformers_utils.processor import get_image_processor
from vllm.utils import is_list_of from vllm.utils import is_list_of
from .base import MultiModalData, MultiModalInputs, MultiModalPlugin from .base import MultiModalData, MultiModalInputs, MultiModalPlugin
...@@ -23,9 +24,14 @@ class ImagePlugin(MultiModalPlugin): ...@@ -23,9 +24,14 @@ class ImagePlugin(MultiModalPlugin):
return "image" return "image"
def _get_hf_image_processor(self, model_config: ModelConfig): def _get_hf_image_processor(self, model_config: ModelConfig):
mm_processor_kwargs = ({} if model_config.mm_processor_kwargs is None
else model_config.mm_processor_kwargs)
# We don't explicitly check kwarg overrides to the HF class
# since the automodel just takes kwargs, so we can't inspect it
return cached_get_image_processor( return cached_get_image_processor(
model_config.model, model_config.model,
trust_remote_code=model_config.trust_remote_code) trust_remote_code=model_config.trust_remote_code,
**mm_processor_kwargs)
def _default_input_mapper( def _default_input_mapper(
self, self,
...@@ -34,9 +40,14 @@ class ImagePlugin(MultiModalPlugin): ...@@ -34,9 +40,14 @@ class ImagePlugin(MultiModalPlugin):
) -> MultiModalInputs: ) -> MultiModalInputs:
model_config = ctx.model_config model_config = ctx.model_config
# Processed by input processor
if isinstance(data, BatchFeature):
return MultiModalInputs(data.data)
# PIL image # PIL image
if isinstance(data, Image.Image) or is_list_of(data, Image.Image): if isinstance(data, Image.Image) or is_list_of(data, Image.Image):
image_processor = self._get_hf_image_processor(model_config) image_processor = self._get_hf_image_processor(model_config)
if image_processor is None: if image_processor is None:
raise RuntimeError("No HuggingFace processor is available " raise RuntimeError("No HuggingFace processor is available "
"to process the image object") "to process the image object")
......
...@@ -138,6 +138,15 @@ class MultiModalRegistry: ...@@ -138,6 +138,15 @@ class MultiModalRegistry:
""" """
Create an input mapper (see :meth:`map_input`) for a specific model. Create an input mapper (see :meth:`map_input`) for a specific model.
""" """
# NOTE - we currently make the assumption that if a model has multiple
# supported modalities, they take the same kwargs. For the default,
# this could be an issue in the future if it falls back to two HF
# resources and we can't inspect the signature easily since it's
# getting initialized through the autoclass.
#
# If this is a problem in the future, we should revisit it, but since
# it potentially introduces a lot of complexity for a currently
# uncommon case, we do not for simplicity of both use & implementation
return functools.partial(self.map_input, model_config) return functools.partial(self.map_input, model_config)
def register_max_multimodal_tokens( def register_max_multimodal_tokens(
......
...@@ -6,7 +6,7 @@ import numpy as np ...@@ -6,7 +6,7 @@ import numpy as np
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.inputs.registry import InputContext from vllm.inputs.registry import InputContext
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.image_processor import get_video_processor from vllm.transformers_utils.processor import get_video_processor
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import is_list_of from vllm.utils import is_list_of
...@@ -37,9 +37,14 @@ class VideoPlugin(ImagePlugin): ...@@ -37,9 +37,14 @@ class VideoPlugin(ImagePlugin):
return "video" return "video"
def _get_hf_video_processor(self, model_config: ModelConfig): def _get_hf_video_processor(self, model_config: ModelConfig):
mm_processor_kwargs = ({} if model_config.mm_processor_kwargs is None
else model_config.mm_processor_kwargs)
# We don't explicitly check kwarg overrides to the HF class
# since the automodel just takes kwargs, so we can't inspect it
return cached_get_video_processor( return cached_get_video_processor(
model_config.model, model_config.model,
trust_remote_code=model_config.trust_remote_code) trust_remote_code=model_config.trust_remote_code,
**mm_processor_kwargs)
def _default_input_mapper( def _default_input_mapper(
self, self,
......
...@@ -114,17 +114,28 @@ class RequestOutput: ...@@ -114,17 +114,28 @@ class RequestOutput:
self.encoder_prompt_token_ids = encoder_prompt_token_ids self.encoder_prompt_token_ids = encoder_prompt_token_ids
@classmethod @classmethod
def from_seq_group(cls, def from_seq_group(cls, seq_group: SequenceGroup,
seq_group: SequenceGroup) -> Optional["RequestOutput"]: use_cache: bool) -> Optional["RequestOutput"]:
sampling_params = seq_group.sampling_params sampling_params = seq_group.sampling_params
if sampling_params is None: if sampling_params is None:
raise ValueError( raise ValueError(
"Sampling parameters are missing for a CompletionRequest.") "Sampling parameters are missing for a CompletionRequest.")
finished = seq_group.is_finished() finished = seq_group.is_finished()
if sampling_params.output_kind == RequestOutputKind.FINAL_ONLY and ( if sampling_params.output_kind == RequestOutputKind.FINAL_ONLY and (
not finished): not finished):
return None return None
# Init cache (if needed)
if use_cache and seq_group.cached_request_output is None:
seq_group.cached_request_output = RequestOutput( # type: ignore
request_id="",
prompt=None,
prompt_token_ids=[],
prompt_logprobs=None,
outputs=[],
finished=False)
seqs = seq_group.get_seqs() seqs = seq_group.get_seqs()
if len(seqs) == 1: if len(seqs) == 1:
top_n_seqs = seqs top_n_seqs = seqs
...@@ -149,29 +160,66 @@ class RequestOutput: ...@@ -149,29 +160,66 @@ class RequestOutput:
outputs = [] outputs = []
include_prompt = True include_prompt = True
for seq in top_n_seqs: for i, seq in enumerate(top_n_seqs):
output_text = seq.get_output_text_to_return( output_text = seq.get_output_text_to_return(
text_buffer_length, delta) text_buffer_length, delta)
output_token_ids = seq.get_output_token_ids_to_return(delta) output_token_ids = seq.get_output_token_ids_to_return(delta)
num_output_tokens = 1 if isinstance(output_token_ids,
int) else len(output_token_ids)
output_logprobs = seq.output_logprobs if include_logprobs else None output_logprobs = seq.output_logprobs if include_logprobs else None
if delta: if delta:
# Slice logprobs delta if applicable # Slice logprobs delta if applicable
if output_logprobs: if output_logprobs:
output_logprobs = output_logprobs[-len(output_token_ids):] output_logprobs = output_logprobs[-num_output_tokens:]
# Don't include prompt if this is after the first output # Don't include prompt if this is after the first output
# containing decode token ids # containing decode token ids
if include_prompt and seq.get_output_len() > len( if include_prompt and seq.get_output_len() > num_output_tokens:
output_token_ids):
include_prompt = False include_prompt = False
outputs.append( if use_cache:
CompletionOutput( # Get cached output object
seqs.index(seq), output_text, output_token_ids, cached_outputs = seq_group.cached_request_output.outputs # type: ignore
if i >= len(cached_outputs):
cached_outputs.append(
CompletionOutput(index=i,
text="",
token_ids=[],
cumulative_logprob=None,
logprobs=None,
finish_reason=None,
stop_reason=None))
output = cached_outputs[i]
# Init cached output object
assert output.index == i
output.text = output_text
if isinstance(output_token_ids, int):
output.token_ids.clear()
output.token_ids.append(output_token_ids)
else:
output.token_ids = output_token_ids
output.cumulative_logprob = seq.get_cumulative_logprob() \
if include_logprobs else None
output.logprobs = output_logprobs
output.finish_reason = SequenceStatus.get_finished_reason(
seq.status)
output.stop_reason = seq.stop_reason
else:
output = CompletionOutput(
seqs.index(seq), output_text, [output_token_ids]
if isinstance(output_token_ids, int) else output_token_ids,
seq.get_cumulative_logprob() if include_logprobs else None, seq.get_cumulative_logprob() if include_logprobs else None,
output_logprobs, output_logprobs,
SequenceStatus.get_finished_reason(seq.status), SequenceStatus.get_finished_reason(seq.status),
seq.stop_reason)) seq.stop_reason)
outputs.append(output)
# Every sequence in the sequence group should have the same prompt. # Every sequence in the sequence group should have the same prompt.
if include_prompt: if include_prompt:
...@@ -188,16 +236,20 @@ class RequestOutput: ...@@ -188,16 +236,20 @@ class RequestOutput:
prompt_logprobs = None prompt_logprobs = None
finished_time = time.time() if finished else None finished_time = time.time() if finished else None
seq_group.set_finished_time(finished_time) seq_group.set_finished_time(finished_time)
return cls(seq_group.request_id,
prompt, init_args = (seq_group.request_id, prompt, prompt_token_ids,
prompt_token_ids, prompt_logprobs, outputs, finished, seq_group.metrics,
prompt_logprobs, seq_group.lora_request, encoder_prompt,
outputs, encoder_prompt_token_ids)
finished,
seq_group.metrics, if use_cache:
lora_request=seq_group.lora_request, request_output = seq_group.cached_request_output
encoder_prompt=encoder_prompt, request_output.__init__(*init_args) # type: ignore
encoder_prompt_token_ids=encoder_prompt_token_ids)
else:
request_output = cls(*init_args)
return request_output
def __repr__(self) -> str: def __repr__(self) -> str:
return (f"RequestOutput(request_id={self.request_id}, " return (f"RequestOutput(request_id={self.request_id}, "
...@@ -261,10 +313,10 @@ class EmbeddingRequestOutput: ...@@ -261,10 +313,10 @@ class EmbeddingRequestOutput:
class RequestOutputFactory: class RequestOutputFactory:
@staticmethod @staticmethod
def create(seq_group): def create(seq_group: SequenceGroup, use_cache: bool = False):
# Determine the type based on a condition, for example: # Determine the type based on a condition, for example:
if hasattr(seq_group, if hasattr(seq_group,
'embeddings') and seq_group.embeddings is not None: 'embeddings') and seq_group.embeddings is not None:
return EmbeddingRequestOutput.from_seq_group(seq_group) return EmbeddingRequestOutput.from_seq_group(seq_group)
else: else:
return RequestOutput.from_seq_group(seq_group) return RequestOutput.from_seq_group(seq_group, use_cache)
...@@ -6,10 +6,10 @@ from .interface import Platform, PlatformEnum ...@@ -6,10 +6,10 @@ from .interface import Platform, PlatformEnum
class CpuPlatform(Platform): class CpuPlatform(Platform):
_enum = PlatformEnum.CPU _enum = PlatformEnum.CPU
@staticmethod @classmethod
def get_device_name(device_id: int = 0) -> str: def get_device_name(cls, device_id: int = 0) -> str:
return "cpu" return "cpu"
@staticmethod @classmethod
def inference_mode(): def inference_mode(cls):
return torch.no_grad() return torch.no_grad()
...@@ -11,7 +11,7 @@ from typing_extensions import ParamSpec ...@@ -11,7 +11,7 @@ from typing_extensions import ParamSpec
from vllm.logger import init_logger from vllm.logger import init_logger
from .interface import Platform, PlatformEnum from .interface import DeviceCapability, Platform, PlatformEnum
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -96,19 +96,20 @@ def device_id_to_physical_device_id(device_id: int) -> int: ...@@ -96,19 +96,20 @@ def device_id_to_physical_device_id(device_id: int) -> int:
class CudaPlatform(Platform): class CudaPlatform(Platform):
_enum = PlatformEnum.CUDA _enum = PlatformEnum.CUDA
@staticmethod @classmethod
def get_device_capability(device_id: int = 0) -> Tuple[int, int]: def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
physical_device_id = device_id_to_physical_device_id(device_id) physical_device_id = device_id_to_physical_device_id(device_id)
return get_physical_device_capability(physical_device_id) major, minor = get_physical_device_capability(physical_device_id)
return DeviceCapability(major=major, minor=minor)
@staticmethod @classmethod
def get_device_name(device_id: int = 0) -> str: def get_device_name(cls, device_id: int = 0) -> str:
physical_device_id = device_id_to_physical_device_id(device_id) physical_device_id = device_id_to_physical_device_id(device_id)
return get_physical_device_name(physical_device_id) return get_physical_device_name(physical_device_id)
@staticmethod @classmethod
@with_nvml_context @with_nvml_context
def is_full_nvlink(physical_device_ids: List[int]) -> bool: def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
""" """
query if the set of gpus are fully connected by nvlink (1 hop) query if the set of gpus are fully connected by nvlink (1 hop)
""" """
......
import enum import enum
from typing import Optional, Tuple from typing import NamedTuple, Optional, Tuple, Union
import torch import torch
...@@ -12,6 +12,23 @@ class PlatformEnum(enum.Enum): ...@@ -12,6 +12,23 @@ class PlatformEnum(enum.Enum):
UNSPECIFIED = enum.auto() UNSPECIFIED = enum.auto()
class DeviceCapability(NamedTuple):
    """Device capability expressed as a ``(major, minor)`` version pair.

    Because this is a named tuple, instances compare element-wise with
    plain ``(major, minor)`` tuples.
    """
    major: int
    minor: int

    def as_version_str(self) -> str:
        """Return the capability formatted as ``"<major>.<minor>"``."""
        return f"{self.major}.{self.minor}"

    def to_int(self) -> int:
        """
        Express device capability as an integer ``<major><minor>``.

        It is assumed that the minor version is always a single digit.

        Raises:
            ValueError: If ``minor`` is outside ``0..9``; the result would
                otherwise collide with a different capability.
        """
        # Explicit check instead of ``assert``: assertions are removed under
        # ``python -O``, which would let an ambiguous value slip through.
        if not 0 <= self.minor < 10:
            raise ValueError(
                "minor version must be a single digit to be expressed as "
                f"an integer, got {self.minor}")
        return self.major * 10 + self.minor
class Platform: class Platform:
_enum: PlatformEnum _enum: PlatformEnum
...@@ -27,16 +44,47 @@ class Platform: ...@@ -27,16 +44,47 @@ class Platform:
def is_cpu(self) -> bool: def is_cpu(self) -> bool:
return self._enum == PlatformEnum.CPU return self._enum == PlatformEnum.CPU
@staticmethod def is_cuda_alike(self) -> bool:
def get_device_capability(device_id: int = 0) -> Optional[Tuple[int, int]]: """Stateless version of :func:`torch.cuda.is_available`."""
return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
@classmethod
def get_device_capability(
cls,
device_id: int = 0,
) -> Optional[DeviceCapability]:
"""Stateless version of :func:`torch.cuda.get_device_capability`."""
return None return None
@staticmethod @classmethod
def get_device_name(device_id: int = 0) -> str: def has_device_capability(
cls,
capability: Union[Tuple[int, int], int],
device_id: int = 0,
) -> bool:
"""
Test whether this platform is compatible with a device capability.
The ``capability`` argument can either be:
- A tuple ``(major, minor)``.
- An integer ``<major><minor>``. (See :meth:`DeviceCapability.to_int`)
"""
current_capability = cls.get_device_capability(device_id=device_id)
if current_capability is None:
return False
if isinstance(capability, tuple):
return current_capability >= capability
return current_capability.to_int() >= capability
@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
raise NotImplementedError raise NotImplementedError
@staticmethod @classmethod
def inference_mode(): def inference_mode(cls):
"""A device-specific wrapper of `torch.inference_mode`. """A device-specific wrapper of `torch.inference_mode`.
This wrapper is recommended because some hardware backends such as TPU This wrapper is recommended because some hardware backends such as TPU
......
import os import os
from functools import lru_cache from functools import lru_cache
from typing import Tuple
import torch import torch
from vllm.logger import init_logger from vllm.logger import init_logger
from .interface import Platform, PlatformEnum from .interface import DeviceCapability, Platform, PlatformEnum
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -20,12 +19,13 @@ if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", None) in ["fork", None]: ...@@ -20,12 +19,13 @@ if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", None) in ["fork", None]:
class RocmPlatform(Platform): class RocmPlatform(Platform):
_enum = PlatformEnum.ROCM _enum = PlatformEnum.ROCM
@staticmethod @classmethod
@lru_cache(maxsize=8) @lru_cache(maxsize=8)
def get_device_capability(device_id: int = 0) -> Tuple[int, int]: def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
return torch.cuda.get_device_capability(device_id) major, minor = torch.cuda.get_device_capability(device_id)
return DeviceCapability(major=major, minor=minor)
@staticmethod @classmethod
@lru_cache(maxsize=8) @lru_cache(maxsize=8)
def get_device_name(device_id: int = 0) -> str: def get_device_name(cls, device_id: int = 0) -> str:
return torch.cuda.get_device_name(device_id) return torch.cuda.get_device_name(device_id)
...@@ -6,6 +6,10 @@ from .interface import Platform, PlatformEnum ...@@ -6,6 +6,10 @@ from .interface import Platform, PlatformEnum
class TpuPlatform(Platform): class TpuPlatform(Platform):
_enum = PlatformEnum.TPU _enum = PlatformEnum.TPU
@staticmethod @classmethod
def inference_mode(): def get_device_name(cls, device_id: int = 0) -> str:
raise NotImplementedError
@classmethod
def inference_mode(cls):
return torch.no_grad() return torch.no_grad()
...@@ -8,13 +8,15 @@ from huggingface_hub import file_exists, hf_hub_download ...@@ -8,13 +8,15 @@ from huggingface_hub import file_exists, hf_hub_download
from huggingface_hub.utils import EntryNotFoundError from huggingface_hub.utils import EntryNotFoundError
from safetensors.torch import load_file as safe_load_file from safetensors.torch import load_file as safe_load_file
from vllm.platforms import current_platform
WEIGHTS_NAME = "adapter_model.bin" WEIGHTS_NAME = "adapter_model.bin"
SAFETENSORS_WEIGHTS_NAME = "adapter_model.safetensors" SAFETENSORS_WEIGHTS_NAME = "adapter_model.safetensors"
# Get current device name based on available devices # Get current device name based on available devices
def infer_device() -> str: def infer_device() -> str:
if torch.cuda.is_available(): if current_platform.is_cuda_alike():
return "cuda" return "cuda"
return "cpu" return "cpu"
......
...@@ -8,6 +8,7 @@ import msgspec ...@@ -8,6 +8,7 @@ import msgspec
import torch import torch
from typing_extensions import Annotated from typing_extensions import Annotated
import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -260,6 +261,10 @@ class SamplingParams( ...@@ -260,6 +261,10 @@ class SamplingParams(
self._verify_args() self._verify_args()
if self.use_beam_search: if self.use_beam_search:
if not envs.VLLM_ALLOW_DEPRECATED_BEAM_SEARCH:
raise ValueError(
"Using beam search as a sampling parameter is deprecated, and will be removed in the future release. Please use the `vllm.LLM.use_beam_search` method for dedicated beam search instead, or set the environment variable `VLLM_ALLOW_DEPRECATED_BEAM_SEARCH=1` to suppress this error. For more details, see https://github.com/vllm-project/vllm/issues/8306 ." # noqa
)
self._verify_beam_search() self._verify_beam_search()
else: else:
self._verify_non_beam_search() self._verify_non_beam_search()
...@@ -273,9 +278,14 @@ class SamplingParams( ...@@ -273,9 +278,14 @@ class SamplingParams(
self._all_stop_token_ids = set(self.stop_token_ids) self._all_stop_token_ids = set(self.stop_token_ids)
def _verify_args(self) -> None: def _verify_args(self) -> None:
if not isinstance(self.n, int):
raise ValueError(f"n must be an int, but is of "
f"type {type(self.n)}")
if self.n < 1: if self.n < 1:
raise ValueError(f"n must be at least 1, got {self.n}.") raise ValueError(f"n must be at least 1, got {self.n}.")
assert isinstance(self.best_of, int) if not isinstance(self.best_of, int):
raise ValueError(f'best_of must be an int, but is of '
f'type {type(self.best_of)}')
if self.best_of < self.n: if self.best_of < self.n:
raise ValueError(f"best_of must be greater than or equal to n, " raise ValueError(f"best_of must be greater than or equal to n, "
f"got n={self.n} and best_of={self.best_of}.") f"got n={self.n} and best_of={self.best_of}.")
......
# The CLI entrypoint to vLLM. # The CLI entrypoint to vLLM.
import argparse import argparse
import asyncio
import os import os
import signal import signal
import sys import sys
from typing import List, Optional from typing import List, Optional
import uvloop
from openai import OpenAI from openai import OpenAI
from openai.types.chat import ChatCompletionMessageParam from openai.types.chat import ChatCompletionMessageParam
...@@ -34,7 +34,7 @@ def serve(args: argparse.Namespace) -> None: ...@@ -34,7 +34,7 @@ def serve(args: argparse.Namespace) -> None:
# EngineArgs expects the model name to be passed as --model. # EngineArgs expects the model name to be passed as --model.
args.model = args.model_tag args.model = args.model_tag
asyncio.run(run_server(args)) uvloop.run(run_server(args))
def interactive_cli(args: argparse.Namespace) -> None: def interactive_cli(args: argparse.Namespace) -> None:
......
...@@ -5,6 +5,7 @@ from abc import ABC, abstractmethod ...@@ -5,6 +5,7 @@ from abc import ABC, abstractmethod
from array import array from array import array
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from functools import cached_property, reduce
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional
from typing import Sequence as GenericSequence from typing import Sequence as GenericSequence
from typing import Set, Tuple, Union, cast from typing import Set, Tuple, Union, cast
...@@ -12,6 +13,7 @@ from typing import Set, Tuple, Union, cast ...@@ -12,6 +13,7 @@ from typing import Set, Tuple, Union, cast
import msgspec import msgspec
import torch import torch
from vllm.inputs import EncoderDecoderLLMInputs, LLMInputs
from vllm.inputs.parse import is_valid_encoder_decoder_llm_inputs from vllm.inputs.parse import is_valid_encoder_decoder_llm_inputs
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
...@@ -20,11 +22,12 @@ from vllm.sampling_params import SamplingParams ...@@ -20,11 +22,12 @@ from vllm.sampling_params import SamplingParams
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.inputs import LLMInputs
from vllm.multimodal.base import MultiModalDataDict from vllm.multimodal.base import MultiModalDataDict
VLLM_TOKEN_ID_ARRAY_TYPE = "l" VLLM_TOKEN_ID_ARRAY_TYPE = "l"
VLLM_INVALID_TOKEN_ID = -1
# We use dataclass for now because it is used for # We use dataclass for now because it is used for
# openai server output, and msgspec is not serializable. # openai server output, and msgspec is not serializable.
...@@ -169,6 +172,35 @@ class SequenceData(msgspec.Struct, ...@@ -169,6 +172,35 @@ class SequenceData(msgspec.Struct,
# It is used to compute mrope_position_ids. # It is used to compute mrope_position_ids.
_mrope_position_delta: Optional[int] = None _mrope_position_delta: Optional[int] = None
@staticmethod
def from_token_counts(*token_counts: Tuple[int, int]) -> "SequenceData":
    """Build a ``SequenceData`` from ``(token_id, count)`` pairs.

    Each pair contributes ``count`` consecutive copies of ``token_id``, in
    the order the pairs are given. With no pairs, the resulting token
    array is empty.
    """
    # Grow a single array in place instead of chaining ``array.__add__``,
    # which allocates a new array and re-copies the prefix for every pair.
    token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE)
    for token_id, count in token_counts:
        token_ids.extend(array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id]) * count)

    return SequenceData(token_ids)
@staticmethod
def from_seqs(
    prompt_token_ids: GenericSequence[int],
    output_token_ids: Optional[GenericSequence[int]] = None,
) -> "SequenceData":
    """Build a ``SequenceData`` from plain sequences of token ids.

    Both sequences are converted to arrays of typecode
    ``VLLM_TOKEN_ID_ARRAY_TYPE``. When ``output_token_ids`` is ``None``,
    only the prompt array is passed to the constructor.
    """
    prompt_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE, prompt_token_ids)

    if output_token_ids is not None:
        output_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE, output_token_ids)
        return SequenceData(prompt_arr, _output_token_ids=output_arr)

    return SequenceData(prompt_arr)
def __post_init__(self) -> None: def __post_init__(self) -> None:
assert self._prompt_token_ids.typecode == "l" assert self._prompt_token_ids.typecode == "l"
assert self._output_token_ids.typecode == "l" assert self._output_token_ids.typecode == "l"
...@@ -370,8 +402,6 @@ class Sequence: ...@@ -370,8 +402,6 @@ class Sequence:
self.lora_request = lora_request self.lora_request = lora_request
self.prompt_adapter_request = prompt_adapter_request self.prompt_adapter_request = prompt_adapter_request
self.from_decoder_prompt = from_decoder_prompt self.from_decoder_prompt = from_decoder_prompt
self._prompt: Optional[str] = None
self._prompt_token_ids: Optional[List[int]] = None
# For decoder-only models, a Sequence is constructed # For decoder-only models, a Sequence is constructed
# from an LLMInputs instance (the `inputs` arg.) # from an LLMInputs instance (the `inputs` arg.)
...@@ -400,8 +430,7 @@ class Sequence: ...@@ -400,8 +430,7 @@ class Sequence:
f"invalid input {inputs}; did you forget the " f"invalid input {inputs}; did you forget the "
"encoder input prompt fields?") "encoder input prompt fields?")
self.data = SequenceData( self.data = SequenceData.from_seqs(self.prompt_token_ids)
array(VLLM_TOKEN_ID_ARRAY_TYPE, self.prompt_token_ids))
self.output_logprobs: SampleLogprobs = [] self.output_logprobs: SampleLogprobs = []
self.output_text = "" self.output_text = ""
...@@ -409,7 +438,7 @@ class Sequence: ...@@ -409,7 +438,7 @@ class Sequence:
self.stop_reason: Union[int, str, None] = None self.stop_reason: Union[int, str, None] = None
# These are used to keep track of delta outputs # These are used to keep track of delta outputs
self._last_token_ids_offset: int = 0 self._last_output_token_ids_offset: int = 0
self._last_output_text_offset: int = 0 self._last_output_text_offset: int = 0
# Used for incremental detokenization # Used for incremental detokenization
...@@ -422,41 +451,35 @@ class Sequence: ...@@ -422,41 +451,35 @@ class Sequence:
def n_blocks(self) -> int: def n_blocks(self) -> int:
return (self.get_len() + self.block_size - 1) // self.block_size return (self.get_len() + self.block_size - 1) // self.block_size
@property @cached_property
def prompt(self) -> Optional[str]: def prompt(self) -> Optional[str]:
if self._prompt is not None: # Select decoder or encoder input prompt str, as appropriate
# Reuse precomputed prompt string
return self._prompt
# Select decoder or encoder input prompt str,
# as appropriate
prompt_key: str = ("prompt" prompt_key: str = ("prompt"
if self.from_decoder_prompt else "encoder_prompt") if self.from_decoder_prompt else "encoder_prompt")
# Cache prompt return cast(Optional[str], self.inputs.get(prompt_key))
self._prompt = cast(Optional[str], self.inputs.get(prompt_key))
return self._prompt
@property @cached_property
def prompt_token_ids(self) -> List[int]: def prompt_token_ids(self) -> List[int]:
if self._prompt_token_ids is not None: # Select decoder or encoder input prompt token ids, as appropriate
# Reuse precomputed prompt token ids
return self._prompt_token_ids
# Select decoder or encoder input prompt
# token ids, as appropriate
prompt_token_ids_key: str = ("prompt_token_ids" prompt_token_ids_key: str = ("prompt_token_ids"
if self.from_decoder_prompt else if self.from_decoder_prompt else
"encoder_prompt_token_ids") "encoder_prompt_token_ids")
# Cache computed prompt token ids # Cache computed prompt token ids
self._prompt_token_ids = cast(List[int], return cast(List[int], self.inputs.get(prompt_token_ids_key))
self.inputs.get(prompt_token_ids_key))
return self._prompt_token_ids
@property @property
def multi_modal_data(self) -> "MultiModalDataDict": def multi_modal_data(self) -> "MultiModalDataDict":
return self.inputs.get("multi_modal_data") or {} if self.inputs.get("multi_modal_data") and self.inputs.get(
"encoder_multi_modal_data"):
raise ValueError(
"Multi-modal data in both encoder and decoder is not supported."
)
inputs = self.inputs
return self.inputs.get("multi_modal_data") or (cast(
EncoderDecoderLLMInputs,
inputs).get("encoder_multi_modal_data")) or {}
@property @property
def lora_int_id(self) -> int: def lora_int_id(self) -> int:
...@@ -486,18 +509,26 @@ class Sequence: ...@@ -486,18 +509,26 @@ class Sequence:
return self.output_text[last_offset:length] return self.output_text[last_offset:length]
return "" return ""
def get_output_token_ids_to_return(self, def get_output_token_ids_to_return(
delta: bool) -> GenericSequence[int]: self, delta: bool) -> Union[GenericSequence[int], int]:
"""If delta is True, only new tokens since the last call to """If delta is True, only new tokens since the last call to
this method are returned""" this method are returned"""
if not delta: if not delta:
return self.get_output_token_ids() return self.get_output_token_ids()
length = self.get_output_len()
last_offset = self._last_token_ids_offset output_len = self.get_output_len()
if last_offset < length:
self._last_token_ids_offset = length # Get the number of new tokens
return self.data._output_token_ids[last_offset:] num_new_tokens = output_len - self._last_output_token_ids_offset
return () self._last_output_token_ids_offset = output_len
# Return new tokens
if num_new_tokens == 1:
# Optimization for single decode token case
# (which is what we have most of the time)
return self.data._cached_all_token_ids[-1]
return self.data._cached_all_token_ids[-num_new_tokens:]
def hash_of_block(self, logical_idx: int) -> int: def hash_of_block(self, logical_idx: int) -> int:
# TODO This can produce incorrect hash when block size > prompt size # TODO This can produce incorrect hash when block size > prompt size
...@@ -623,6 +654,7 @@ class SequenceGroup: ...@@ -623,6 +654,7 @@ class SequenceGroup:
unless you are working with an encoder/decoder model. unless you are working with an encoder/decoder model.
trace_headers: OpenTelemetry trace headers. trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request. prompt_adapter_request: Prompt Adapter request.
priority: User-defined priority of the request.
""" """
def __init__( def __init__(
...@@ -637,9 +669,11 @@ class SequenceGroup: ...@@ -637,9 +669,11 @@ class SequenceGroup:
encoder_seq: Optional[Sequence] = None, encoder_seq: Optional[Sequence] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None: ) -> None:
self.request_id = request_id self.request_id = request_id
self.seqs = seqs self.seqs = seqs
self.arrival_time = arrival_time
self.is_single_seq = len(seqs) == 1 self.is_single_seq = len(seqs) == 1
self.seqs_dict = {seq.seq_id: seq for seq in seqs} self.seqs_dict = {seq.seq_id: seq for seq in seqs}
...@@ -657,6 +691,9 @@ class SequenceGroup: ...@@ -657,6 +691,9 @@ class SequenceGroup:
self.prompt_adapter_request = prompt_adapter_request self.prompt_adapter_request = prompt_adapter_request
self.encoder_seq = encoder_seq self.encoder_seq = encoder_seq
self.trace_headers = trace_headers self.trace_headers = trace_headers
self.priority = priority
self.cached_request_output = None
@property @property
def prompt(self) -> Optional[str]: def prompt(self) -> Optional[str]:
......
...@@ -6,9 +6,9 @@ import torch ...@@ -6,9 +6,9 @@ import torch
from vllm import SamplingParams from vllm import SamplingParams
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, ExecuteModelRequest, from vllm.sequence import (VLLM_INVALID_TOKEN_ID, VLLM_TOKEN_ID_ARRAY_TYPE,
SequenceData, SequenceGroupMetadata, ExecuteModelRequest, SequenceData,
get_all_seq_ids) SequenceGroupMetadata, get_all_seq_ids)
from vllm.spec_decode.interfaces import (SpeculativeProposals, from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores) SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.util import nvtx_range, split_batch_by_proposal_len from vllm.spec_decode.util import nvtx_range, split_batch_by_proposal_len
...@@ -69,10 +69,10 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): ...@@ -69,10 +69,10 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
proposal_lens_list = proposals.proposal_lens.tolist() proposal_lens_list = proposals.proposal_lens.tolist()
proposal_token_ids_list = proposals.proposal_token_ids.tolist() proposal_token_ids_list = proposals.proposal_token_ids.tolist()
# Filter the list to ignore -1 proposals. # Filter the list to ignore invalid proposals.
proposal_token_ids_list_without_skips = [ proposal_token_ids_list_without_skips = [
proposals for proposals in proposal_token_ids_list proposals for proposals in proposal_token_ids_list
if -1 not in proposals if VLLM_INVALID_TOKEN_ID not in proposals
] ]
(spec_indices, non_spec_indices, target_seq_group_metadata_list, (spec_indices, non_spec_indices, target_seq_group_metadata_list,
......
...@@ -183,10 +183,7 @@ class TP1DraftModelRunner(ModelRunner): ...@@ -183,10 +183,7 @@ class TP1DraftModelRunner(ModelRunner):
return False return False
# TODO: Add soft-tuning prompt adapter support # TODO: Add soft-tuning prompt adapter support
if self.prompt_adapter_config: return not self.prompt_adapter_config
return False
return True
@torch.inference_mode() @torch.inference_mode()
def execute_model( def execute_model(
......
...@@ -104,13 +104,10 @@ class AsyncMetricsCollector: ...@@ -104,13 +104,10 @@ class AsyncMetricsCollector:
if self._rank != 0: if self._rank != 0:
return False return False
if (now - self._last_metrics_collect_time < return now - self._last_metrics_collect_time >= self._rejsample_metrics_collect_interval_s # noqa: E501
self._rejsample_metrics_collect_interval_s):
return False
return True
def _copy_rejsample_metrics_async(self) -> torch.cuda.Event: def _copy_rejsample_metrics_async(self) -> torch.cuda.Event:
"""Copy rejection/typical-acceptance sampling metrics """Copy rejection/typical-acceptance sampling metrics
(number of accepted tokens, etc) to CPU asynchronously. (number of accepted tokens, etc) to CPU asynchronously.
Returns a CUDA event recording when the copy is complete. Returns a CUDA event recording when the copy is complete.
......
...@@ -13,9 +13,10 @@ from vllm.model_executor.layers.spec_decode_base_sampler import ( ...@@ -13,9 +13,10 @@ from vllm.model_executor.layers.spec_decode_base_sampler import (
SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler) SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler)
from vllm.model_executor.layers.typical_acceptance_sampler import ( from vllm.model_executor.layers.typical_acceptance_sampler import (
TypicalAcceptanceSampler) TypicalAcceptanceSampler)
from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest, from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
CompletionSequenceGroupOutput, ExecuteModelRequest,
HiddenStates, SequenceGroupMetadata, HiddenStates, SequenceGroupMetadata,
get_all_seq_ids, get_all_seq_ids_and_request_ids) get_all_seq_ids_and_request_ids)
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.interfaces import (SpeculativeProposals, from vllm.spec_decode.interfaces import (SpeculativeProposals,
...@@ -28,7 +29,8 @@ from vllm.spec_decode.ngram_worker import NGramWorker ...@@ -28,7 +29,8 @@ from vllm.spec_decode.ngram_worker import NGramWorker
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker
from vllm.spec_decode.target_model_runner import TargetModelRunner from vllm.spec_decode.target_model_runner import TargetModelRunner
from vllm.spec_decode.util import (Timer, create_sequence_group_output, from vllm.spec_decode.util import (Timer, create_logprobs_output,
create_sequence_group_output,
get_all_num_logprobs, get_all_num_logprobs,
get_sampled_token_logprobs, nvtx_range, get_sampled_token_logprobs, nvtx_range,
split_batch_by_proposal_len) split_batch_by_proposal_len)
...@@ -164,11 +166,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -164,11 +166,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
spec_decode_sampler: SpecDecodeBaseSampler = None spec_decode_sampler: SpecDecodeBaseSampler = None
if draft_token_acceptance_method == "rejection_sampler": if draft_token_acceptance_method == "rejection_sampler":
spec_decode_sampler = RejectionSampler( spec_decode_sampler = RejectionSampler()
disable_bonus_tokens=False, )
elif draft_token_acceptance_method == "typical_acceptance_sampler": elif draft_token_acceptance_method == "typical_acceptance_sampler":
spec_decode_sampler = TypicalAcceptanceSampler( spec_decode_sampler = TypicalAcceptanceSampler(
disable_bonus_tokens=False,
posterior_threshold=\ posterior_threshold=\
typical_acceptance_sampler_posterior_threshold, typical_acceptance_sampler_posterior_threshold,
posterior_alpha=typical_acceptance_sampler_posterior_alpha, posterior_alpha=typical_acceptance_sampler_posterior_alpha,
...@@ -438,8 +438,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -438,8 +438,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self, execute_model_req: ExecuteModelRequest, self, execute_model_req: ExecuteModelRequest,
sampler_output: SamplerOutput) -> SamplerOutput: sampler_output: SamplerOutput) -> SamplerOutput:
""" """
Creates and returns a `SamplerOutput` with only the sampled token IDs Creates and returns a `SamplerOutput` with only the token IDs being
being serialized to CPU & populated in `CompletionSequenceGroupOutput`. serialized to CPU and populated in `CompletionSequenceGroupOutput`.
All other parameters in `CompletionSequenceGroupOutput` related to log All other parameters in `CompletionSequenceGroupOutput` related to log
probabilities are skipped. probabilities are skipped.
...@@ -451,14 +451,46 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -451,14 +451,46 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
Returns: Returns:
SamplerOutput: A new `SamplerOutput` instance containing a list of SamplerOutput: A new `SamplerOutput` instance containing a list of
`CompletionSequenceGroupOutput` objects with only sampled token `CompletionSequenceGroupOutput` objects with only token IDs
IDs populated. populated.
""" """
seq_ids = get_all_seq_ids(execute_model_req.seq_group_metadata_list) seq_output_prompt_logprobs = [
sampled_token_ids_list = sampler_output.sampled_token_ids.tolist() seq.is_prompt and seq.sampling_params.prompt_logprobs is not None
and seq.sampling_params.prompt_logprobs > 0
for seq in execute_model_req.seq_group_metadata_list
]
# ignore slots for prompt tokens that are filled with INVALID_TOKEN_ID
sampled_token_ids_list = (sampler_output.sampled_token_ids[torch.where(
# subtracting is faster than testing for equality
sampler_output.sampled_token_ids - VLLM_INVALID_TOKEN_ID)[0]] \
if any(seq_output_prompt_logprobs) else \
sampler_output.sampled_token_ids).tolist()
seq_data_entries = (
(seq_id, seq_data) for sg in \
execute_model_req.seq_group_metadata_list \
for seq_id, seq_data in sg.seq_data.items()
)
completion_seq_group_output_list: List[ completion_seq_group_output_list: List[
CompletionSequenceGroupOutput] = [] CompletionSequenceGroupOutput] = []
for index, seq_id in enumerate(seq_ids): for index, ((seq_id, seq_data), needs_prompt_logprobs) in \
enumerate(zip(seq_data_entries, seq_output_prompt_logprobs)):
if needs_prompt_logprobs:
prompt_token_ids = seq_data.get_prompt_token_ids()
prompt_logprobs = [
create_logprobs_output(
token_id=p_token_id,
token_id_logprob_rank=-1,
token_id_logprob=0.0,
topk_token_ids=[],
topk_logprobs=[],
)
# no prompt logprobs for the first token
for p_token_id in prompt_token_ids[1:]
]
else:
prompt_logprobs = None
completion_seq_group_output_list.append( completion_seq_group_output_list.append(
create_sequence_group_output( create_sequence_group_output(
token_id=sampled_token_ids_list[index][0], token_id=sampled_token_ids_list[index][0],
...@@ -467,7 +499,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -467,7 +499,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
seq_id=seq_id, seq_id=seq_id,
topk_token_ids=[], topk_token_ids=[],
topk_logprobs=[], topk_logprobs=[],
)) prompt_logprobs=prompt_logprobs))
return SamplerOutput(outputs=completion_seq_group_output_list) return SamplerOutput(outputs=completion_seq_group_output_list)
@nvtx_range("spec_decode_worker._run_no_spec") @nvtx_range("spec_decode_worker._run_no_spec")
...@@ -487,6 +519,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -487,6 +519,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# Store hidden states from target model execution. # Store hidden states from target model execution.
hidden_states = sampler_output.hidden_states hidden_states = sampler_output.hidden_states
if hidden_states is not None: if hidden_states is not None:
# remove hidden_states for prompt tokens
if any(seq.is_prompt
for seq in execute_model_req.seq_group_metadata_list):
hidden_states = hidden_states[
torch.where(sampler_output.sampled_token_ids -
VLLM_INVALID_TOKEN_ID)[0]]
if self.previous_hidden_states is None: if self.previous_hidden_states is None:
self.previous_hidden_states = HiddenStates( self.previous_hidden_states = HiddenStates(
hidden_states, execute_model_req.seq_group_metadata_list) hidden_states, execute_model_req.seq_group_metadata_list)
......
...@@ -6,7 +6,8 @@ import torch ...@@ -6,7 +6,8 @@ import torch
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
SequenceGroupMetadata, SequenceOutput) PromptLogprobs, SequenceGroupMetadata,
SequenceOutput)
SeqId = int SeqId = int
...@@ -49,21 +50,19 @@ def get_sampled_token_logprobs( ...@@ -49,21 +50,19 @@ def get_sampled_token_logprobs(
return sampled_token_ids_ranks, selected_logprobs return sampled_token_ids_ranks, selected_logprobs
def create_sequence_group_output( def create_logprobs_output(
token_id: int, token_id: int,
token_id_logprob_rank: int, token_id_logprob_rank: int,
token_id_logprob: float, token_id_logprob: float,
seq_id: SeqId,
topk_token_ids: List[Optional[int]], topk_token_ids: List[Optional[int]],
topk_logprobs: List[Optional[float]], topk_logprobs: List[Optional[float]],
) -> CompletionSequenceGroupOutput: ) -> Dict[int, Logprob]:
"""Create a SequenceGroupOutput given the sampling results. """Create a Logprob Dict for a token given the sampling results.
Args: Args:
token_id (int): The sampled token for the sequence. token_id (int): The sampled token for the sequence.
token_id_logprob_rank (int): The logprob rank of the sampled token. token_id_logprob_rank (int): The logprob rank of the sampled token.
token_id_logprob (float): The logprob value of the sampled token. token_id_logprob (float): The logprob value of the sampled token.
seq_id (int): The sequence id.
topk_token_ids (List[Optional[int]]): The list of top-k token ids. topk_token_ids (List[Optional[int]]): The list of top-k token ids.
topk_logprobs (List[Optional[float]]): The list of top-k logprobs. topk_logprobs (List[Optional[float]]): The list of top-k logprobs.
""" """
...@@ -85,14 +84,44 @@ def create_sequence_group_output( ...@@ -85,14 +84,44 @@ def create_sequence_group_output(
if topk_token_id is not None if topk_token_id is not None
}) })
return logprobs
def create_sequence_group_output(
token_id: int,
token_id_logprob_rank: int,
token_id_logprob: float,
seq_id: SeqId,
topk_token_ids: List[Optional[int]],
topk_logprobs: List[Optional[float]],
prompt_logprobs: Optional[PromptLogprobs] = None,
) -> CompletionSequenceGroupOutput:
"""Create a SequenceGroupOutput given the sampling results.
Args:
token_id (int): The sampled token for the sequence.
token_id_logprob_rank (int): The logprob rank of the sampled token.
token_id_logprob (float): The logprob value of the sampled token.
seq_id (int): The sequence id.
topk_token_ids (List[Optional[int]]): The list of top-k token ids.
topk_logprobs (List[Optional[float]]): The list of top-k logprobs.
"""
logprobs = create_logprobs_output(
token_id,
token_id_logprob_rank,
token_id_logprob,
topk_token_ids,
topk_logprobs,
)
return CompletionSequenceGroupOutput( return CompletionSequenceGroupOutput(
samples=[ samples=[
SequenceOutput(parent_seq_id=seq_id, SequenceOutput(parent_seq_id=seq_id,
output_token=token_id, output_token=token_id,
logprobs=logprobs) logprobs=logprobs)
], ],
# TODO add prompt logprobs support. prompt_logprobs=prompt_logprobs,
prompt_logprobs=None,
) )
......
...@@ -22,8 +22,9 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig, ...@@ -22,8 +22,9 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
EAGLEConfig, ExaoneConfig, EAGLEConfig, ExaoneConfig,
GraniteConfig, InternVLChatConfig, GraniteConfig, InternVLChatConfig,
JAISConfig, MedusaConfig, JAISConfig, MedusaConfig,
MLPSpeculatorConfig, MPTConfig, MllamaConfig, MLPSpeculatorConfig,
NemotronConfig, RWConfig, MPTConfig, NemotronConfig,
RWConfig, SolarConfig,
UltravoxConfig) UltravoxConfig)
# yapf: enable # yapf: enable
from vllm.transformers_utils.utils import check_gguf_file from vllm.transformers_utils.utils import check_gguf_file
...@@ -37,6 +38,10 @@ MISTRAL_CONFIG_NAME = "params.json" ...@@ -37,6 +38,10 @@ MISTRAL_CONFIG_NAME = "params.json"
logger = init_logger(__name__) logger = init_logger(__name__)
_CONFIG_REGISTRY_OVERRIDE_HF: Dict[str, Type[PretrainedConfig]] = {
"mllama": MllamaConfig
}
_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = { _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
"chatglm": ChatGLMConfig, "chatglm": ChatGLMConfig,
"dbrx": DbrxConfig, "dbrx": DbrxConfig,
...@@ -50,15 +55,20 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = { ...@@ -50,15 +55,20 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
"exaone": ExaoneConfig, "exaone": ExaoneConfig,
"internvl_chat": InternVLChatConfig, "internvl_chat": InternVLChatConfig,
"nemotron": NemotronConfig, "nemotron": NemotronConfig,
"solar": SolarConfig,
"ultravox": UltravoxConfig, "ultravox": UltravoxConfig,
# Granite can be removed from here once we have upgraded to # Granite can be removed from here once we have upgraded to
# transformers 4.45+ # transformers 4.45+
"granite": GraniteConfig, "granite": GraniteConfig,
**_CONFIG_REGISTRY_OVERRIDE_HF
} }
for name, cls in _CONFIG_REGISTRY.items(): for name, cls in _CONFIG_REGISTRY.items():
with contextlib.suppress(ValueError): with contextlib.suppress(ValueError):
AutoConfig.register(name, cls) if name in _CONFIG_REGISTRY_OVERRIDE_HF:
AutoConfig.register(name, cls, exist_ok=True)
else:
AutoConfig.register(name, cls)
class ConfigFormat(str, enum.Enum): class ConfigFormat(str, enum.Enum):
......
...@@ -10,9 +10,11 @@ from vllm.transformers_utils.configs.granite import GraniteConfig ...@@ -10,9 +10,11 @@ from vllm.transformers_utils.configs.granite import GraniteConfig
from vllm.transformers_utils.configs.internvl import InternVLChatConfig from vllm.transformers_utils.configs.internvl import InternVLChatConfig
from vllm.transformers_utils.configs.jais import JAISConfig from vllm.transformers_utils.configs.jais import JAISConfig
from vllm.transformers_utils.configs.medusa import MedusaConfig from vllm.transformers_utils.configs.medusa import MedusaConfig
from vllm.transformers_utils.configs.mllama import MllamaConfig
from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig
from vllm.transformers_utils.configs.mpt import MPTConfig from vllm.transformers_utils.configs.mpt import MPTConfig
from vllm.transformers_utils.configs.nemotron import NemotronConfig from vllm.transformers_utils.configs.nemotron import NemotronConfig
from vllm.transformers_utils.configs.solar import SolarConfig
from vllm.transformers_utils.configs.ultravox import UltravoxConfig from vllm.transformers_utils.configs.ultravox import UltravoxConfig
__all__ = [ __all__ = [
...@@ -25,8 +27,10 @@ __all__ = [ ...@@ -25,8 +27,10 @@ __all__ = [
"MedusaConfig", "MedusaConfig",
"EAGLEConfig", "EAGLEConfig",
"ExaoneConfig", "ExaoneConfig",
"MllamaConfig",
"MLPSpeculatorConfig", "MLPSpeculatorConfig",
"NemotronConfig", "NemotronConfig",
"SolarConfig",
"UltravoxConfig", "UltravoxConfig",
# Granite can be removed from here once we have upgraded to # Granite can be removed from here once we have upgraded to
# transformers 4.45+ # transformers 4.45+
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment