Commit 081057de authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.5' into v0.8.5-ori

parents 7cf5d5c4 ba41cc90
......@@ -40,7 +40,6 @@ from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
RequestOutputKind, SamplingParams)
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.usage.usage_lib import UsageContext
from vllm.utils import (Counter, Device, deprecate_args, deprecate_kwargs,
is_list_of)
......@@ -118,7 +117,7 @@ class LLM:
disable_async_output_proc: Disable async output processing.
This may result in lower performance.
hf_token: The token to use as HTTP bearer authorization for remote files
. If `True`, will use the token generated when running
. If `True`, will use the token generated when running
`huggingface-cli login` (stored in `~/.huggingface`).
hf_overrides: If a dictionary, contains arguments to be forwarded to the
HuggingFace config. If a callable, it is called to update the
......@@ -252,11 +251,15 @@ class LLM:
self.request_counter = Counter()
self.default_sampling_params: Union[dict[str, Any], None] = None
def get_tokenizer(self) -> AnyTokenizer:
return self.llm_engine.get_tokenizer_group(TokenizerGroup).tokenizer
def get_tokenizer(
    self,
    lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
    """Look up the tokenizer for ``lora_request`` in the engine's
    tokenizer group.

    Args:
        lora_request: Optional LoRA request passed through to the
            tokenizer group's ``get_lora_tokenizer``. Defaults to ``None``.

    Returns:
        Whatever ``get_lora_tokenizer`` returns for ``lora_request``.
    """
    tokenizer_group = self.llm_engine.get_tokenizer_group()
    return tokenizer_group.get_lora_tokenizer(lora_request)
def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
tokenizer_group = self.llm_engine.get_tokenizer_group(TokenizerGroup)
tokenizer_group = self.llm_engine.get_tokenizer_group()
# While CachedTokenizer is dynamic, have no choice but
# compare class name. Misjudgment will arise from
......@@ -520,11 +523,9 @@ class LLM:
prompts: A list of prompts. Each prompt can be a string or a list
of token IDs.
params: The beam search parameters.
TODO: how does beam search work together with length penalty, frequency
penalty, and stopping criteria, etc.?
"""
# TODO: how does beam search work together with length penalty,
# frequency penalty, and stopping criteria, etc.?
beam_width = params.beam_width
max_tokens = params.max_tokens
temperature = params.temperature
......@@ -536,15 +537,18 @@ class LLM:
tokenizer.eos_token_id,
length_penalty)
# TODO - fix handling of multimodal data for beam search; we pass it
# through in the async version on the abstract EngineClient, but not
# here.
if any("multi_modal_data" in prompt
and prompt["multi_modal_data"] is not None
for prompt in prompts):
logger.warning(
"Multimodal data appears to have been provided, but is not"
" currently being passed through in LLM.beam_search()!")
def create_tokens_prompt_from_beam(
        beam: BeamSearchSequence) -> TokensPrompt:
    """Build the ``TokensPrompt`` for a single beam.

    The prompt always carries the beam's tokens. The multimodal data and
    the multimodal processor kwargs are copied over only when the beam has
    them set (i.e. they are not ``None``).
    """
    prompt_fields: TokensPrompt = {"prompt_token_ids": beam.tokens}

    mm_data = beam.multi_modal_data
    if mm_data is not None:
        prompt_fields["multi_modal_data"] = mm_data

    mm_kwargs = beam.mm_processor_kwargs
    if mm_kwargs is not None:
        prompt_fields["mm_processor_kwargs"] = mm_kwargs

    return TokensPrompt(**prompt_fields)
tokenizer = self.get_tokenizer()
# generate 2 * beam_width candidates at each step
......@@ -556,11 +560,20 @@ class LLM:
instances: list[BeamSearchInstance] = []
for prompt in prompts:
# Add multimodal processor kwargs & data
mm_kwargs = {}
if "multi_modal_data" in prompt:
mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
if "mm_processor_kwargs" in prompt:
mm_kwargs["mm_processor_kwargs"] = prompt[
"mm_processor_kwargs"]
if is_token_prompt(prompt):
prompt_tokens = prompt["prompt_token_ids"]
else:
prompt_tokens = tokenizer.encode(prompt["prompt"])
instances.append(BeamSearchInstance(prompt_tokens))
instances.append(
BeamSearchInstance(prompt_tokens, logprobs=None, **mm_kwargs))
for _ in range(max_tokens):
all_beams: list[BeamSearchSequence] = list(
......@@ -575,8 +588,7 @@ class LLM:
break
prompts_batch = [
TokensPrompt(prompt_token_ids=beam.tokens)
for beam in all_beams
create_tokens_prompt_from_beam(beam) for beam in all_beams
]
# only runs for one step
......@@ -602,7 +614,10 @@ class LLM:
tokens=current_beam.tokens + [token_id],
logprobs=current_beam.logprobs + [logprobs],
cum_logprob=current_beam.cum_logprob +
logprob_obj.logprob)
logprob_obj.logprob,
multi_modal_data=current_beam.multi_modal_data,
mm_processor_kwargs=current_beam.
mm_processor_kwargs)
if token_id == tokenizer.eos_token_id and \
not ignore_eos:
......@@ -701,7 +716,7 @@ class LLM:
cast(list[ChatCompletionMessageParam], messages)
]
tokenizer = self.get_tokenizer()
tokenizer = self.get_tokenizer(lora_request)
model_config = self.llm_engine.get_model_config()
resolved_content_format = resolve_chat_template_content_format(
chat_template,
......@@ -724,9 +739,8 @@ class LLM:
content_format=resolved_content_format,
)
prompt_data: Union[str, list[int]]
if isinstance(tokenizer, MistralTokenizer):
prompt_data = apply_mistral_chat_template(
prompt_token_ids = apply_mistral_chat_template(
tokenizer,
messages=msgs,
chat_template=chat_template,
......@@ -735,7 +749,7 @@ class LLM:
continue_final_message=continue_final_message,
)
else:
prompt_data = apply_hf_chat_template(
prompt_str = apply_hf_chat_template(
tokenizer,
trust_remote_code=model_config.trust_remote_code,
conversation=conversation,
......@@ -744,12 +758,12 @@ class LLM:
add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message,
)
# Special tokens are already included in chat templates so
# should not be added by the tokenizer in this case.
prompt_token_ids = tokenizer.encode(prompt_str,
add_special_tokens=False)
prompt: Union[TokensPrompt, TextPrompt]
if is_list_of(prompt_data, int):
prompt = TokensPrompt(prompt_token_ids=prompt_data)
else:
prompt = TextPrompt(prompt=prompt_data)
prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
......@@ -1048,8 +1062,6 @@ class LLM:
if len(encoded_output_1) == 1:
encoded_output_1 = encoded_output_1 * len(encoded_output_2)
scores: list[PoolingRequestOutput] = []
scores = _cosine_similarity(tokenizer=tokenizer,
embed_1=encoded_output_1,
embed_2=encoded_output_2)
......@@ -1384,7 +1396,9 @@ class LLM:
grammar=guided_options.guided_grammar,
json_object=guided_options.guided_json_object,
backend=guided_options.guided_decoding_backend,
whitespace_pattern=guided_options.guided_whitespace_pattern)
whitespace_pattern=guided_options.guided_whitespace_pattern,
structural_tag=guided_options.structural_tag,
)
return params
def _run_engine(
......
......@@ -30,7 +30,7 @@ from starlette.routing import Mount
from typing_extensions import assert_never
import vllm.envs as envs
from vllm.config import ModelConfig
from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine # type: ignore
from vllm.engine.multiprocessing.client import MQLLMEngineClient
......@@ -310,32 +310,33 @@ def mount_metrics(app: FastAPI):
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
# before prometheus_client is imported.
# See https://prometheus.github.io/client_python/multiprocess/
from prometheus_client import (CollectorRegistry, make_asgi_app,
from prometheus_client import (REGISTRY, CollectorRegistry, make_asgi_app,
multiprocess)
from prometheus_fastapi_instrumentator import Instrumentator
registry = REGISTRY
prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
if prometheus_multiproc_dir_path is not None:
logger.debug("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR",
prometheus_multiproc_dir_path)
registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry)
Instrumentator(
excluded_handlers=[
"/metrics",
"/health",
"/load",
"/ping",
"/version",
],
registry=registry,
).add().instrument(app).expose(app)
# Add prometheus asgi middleware to route /metrics requests
metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
else:
# Add prometheus asgi middleware to route /metrics requests
metrics_route = Mount("/metrics", make_asgi_app())
Instrumentator(
excluded_handlers=[
"/metrics",
"/health",
"/load",
"/ping",
"/version",
"/server_info",
],
registry=registry,
).add().instrument(app).expose(app)
# Add prometheus asgi middleware to route /metrics requests
metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
# Workaround for 307 Redirect for /metrics
metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
......@@ -687,6 +688,11 @@ TASK_HANDLERS: dict[str, dict[str, tuple]] = {
if envs.VLLM_SERVER_DEV_MODE:
@router.get("/server_info")
async def show_server_info(raw_request: Request):
    """Return the stringified vLLM config stored on the app state as JSON."""
    vllm_config = raw_request.app.state.vllm_config
    return JSONResponse(content={"vllm_config": str(vllm_config)})
@router.post("/reset_prefix_cache")
async def reset_prefix_cache(raw_request: Request):
"""
......@@ -875,7 +881,8 @@ def build_app(args: Namespace) -> FastAPI:
section async for section in response.body_iterator
]
response.body_iterator = iterate_in_threadpool(iter(response_body))
logger.info("response_body={%s}", response_body[0].decode())
logger.info("response_body={%s}",
response_body[0].decode() if response_body else None)
return response
for middleware in args.middleware:
......@@ -894,7 +901,7 @@ def build_app(args: Namespace) -> FastAPI:
async def init_app_state(
engine_client: EngineClient,
model_config: ModelConfig,
vllm_config: VllmConfig,
state: State,
args: Namespace,
) -> None:
......@@ -915,6 +922,8 @@ async def init_app_state(
state.engine_client = engine_client
state.log_stats = not args.disable_log_stats
state.vllm_config = vllm_config
model_config = vllm_config.model_config
resolved_chat_template = load_chat_template(args.chat_template)
if resolved_chat_template is not None:
......@@ -1069,8 +1078,8 @@ async def run_server(args, **uvicorn_kwargs) -> None:
async with build_async_engine_client(args) as engine_client:
app = build_app(args)
model_config = await engine_client.get_model_config()
await init_app_state(engine_client, model_config, app.state, args)
vllm_config = await engine_client.get_vllm_config()
await init_app_state(engine_client, vllm_config, app.state, args)
def _listen_addr(a: str) -> str:
if is_valid_ipv6_address(a):
......
......@@ -11,7 +11,7 @@ import ssl
from collections.abc import Sequence
from typing import Optional, Union, get_args
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
validate_chat_template)
from vllm.entrypoints.openai.serving_models import (LoRAModulePath,
......@@ -79,7 +79,7 @@ class PromptAdapterParserAction(argparse.Action):
def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser.add_argument("--host",
type=nullable_str,
type=optional_type(str),
default=None,
help="Host name.")
parser.add_argument("--port", type=int, default=8000, help="Port number.")
......@@ -108,13 +108,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
default=["*"],
help="Allowed headers.")
parser.add_argument("--api-key",
type=nullable_str,
type=optional_type(str),
default=None,
help="If provided, the server will require this key "
"to be presented in the header.")
parser.add_argument(
"--lora-modules",
type=nullable_str,
type=optional_type(str),
default=None,
nargs='+',
action=LoRAParserAction,
......@@ -126,14 +126,14 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"\"base_model_name\": \"id\"}``")
parser.add_argument(
"--prompt-adapters",
type=nullable_str,
type=optional_type(str),
default=None,
nargs='+',
action=PromptAdapterParserAction,
help="Prompt adapter configurations in the format name=path. "
"Multiple adapters can be specified.")
parser.add_argument("--chat-template",
type=nullable_str,
type=optional_type(str),
default=None,
help="The file path to the chat template, "
"or the template in single-line form "
......@@ -151,20 +151,20 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
'similar to OpenAI schema. '
'Example: ``[{"type": "text", "text": "Hello world!"}]``')
parser.add_argument("--response-role",
type=nullable_str,
type=optional_type(str),
default="assistant",
help="The role name to return if "
"``request.add_generation_prompt=true``.")
parser.add_argument("--ssl-keyfile",
type=nullable_str,
type=optional_type(str),
default=None,
help="The file path to the SSL key file.")
parser.add_argument("--ssl-certfile",
type=nullable_str,
type=optional_type(str),
default=None,
help="The file path to the SSL cert file.")
parser.add_argument("--ssl-ca-certs",
type=nullable_str,
type=optional_type(str),
default=None,
help="The CA certificates file.")
parser.add_argument(
......@@ -180,13 +180,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
)
parser.add_argument(
"--root-path",
type=nullable_str,
type=optional_type(str),
default=None,
help="FastAPI root_path when app is behind a path based routing proxy."
)
parser.add_argument(
"--middleware",
type=nullable_str,
type=optional_type(str),
action="append",
default=[],
help="Additional ASGI middleware to apply to the app. "
......
......@@ -2,6 +2,7 @@
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import json
import re
import time
from argparse import Namespace
......@@ -139,12 +140,30 @@ class JsonSchemaResponseFormat(OpenAIBaseModel):
strict: Optional[bool] = None
# One structure entry of a structural-tag response format: a `begin`
# string, an optional schema dict, and an `end` string.
class StructuralTag(OpenAIBaseModel):
begin: str
# The external name of this field is `schema`, but a field with that name
# conflicts with pydantic, so the attribute is called
# `structural_tag_schema` and is exposed through the `schema` alias.
structural_tag_schema: Optional[dict[str, Any]] = Field(default=None,
alias="schema")
end: str
# Response format whose `type` is the literal "structural_tag"; it carries
# a list of StructuralTag structures and a list of trigger strings.
class StructuralTagResponseFormat(OpenAIBaseModel):
type: Literal["structural_tag"]
structures: list[StructuralTag]
triggers: list[str]
class ResponseFormat(OpenAIBaseModel):
# type must be "json_schema", "json_object" or "text"
# type must be "json_schema", "json_object", or "text"
type: Literal["text", "json_object", "json_schema"]
json_schema: Optional[JsonSchemaResponseFormat] = None
AnyResponseFormat = Union[ResponseFormat, StructuralTagResponseFormat]
class StreamOptions(OpenAIBaseModel):
include_usage: Optional[bool] = True
continuous_usage_stats: Optional[bool] = False
......@@ -227,7 +246,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
max_completion_tokens: Optional[int] = None
n: Optional[int] = 1
presence_penalty: Optional[float] = 0.0
response_format: Optional[ResponseFormat] = None
response_format: Optional[AnyResponseFormat] = None
seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
stop: Optional[Union[str, list[str]]] = Field(default_factory=list)
stream: Optional[bool] = False
......@@ -340,6 +359,11 @@ class ChatCompletionRequest(OpenAIBaseModel):
description=(
"If specified, the output will follow the context free grammar."),
)
structural_tag: Optional[str] = Field(
default=None,
description=(
"If specified, the output will follow the structural tag schema."),
)
guided_decoding_backend: Optional[str] = Field(
default=None,
description=(
......@@ -476,6 +500,12 @@ class ChatCompletionRequest(OpenAIBaseModel):
json_schema = self.response_format.json_schema
assert json_schema is not None
self.guided_json = json_schema.json_schema
elif self.response_format.type == "structural_tag":
structural_tag = self.response_format
assert structural_tag is not None and isinstance(
structural_tag, StructuralTagResponseFormat)
s_tag_obj = structural_tag.model_dump(by_alias=True)
self.structural_tag = json.dumps(s_tag_obj)
guided_decoding = GuidedDecodingParams.from_optional(
json=self._get_guided_json_from_tool() or self.guided_json,
......@@ -485,6 +515,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
json_object=guided_json_object,
backend=self.guided_decoding_backend,
whitespace_pattern=self.guided_whitespace_pattern,
structural_tag=self.structural_tag,
)
return SamplingParams.from_optional(
......@@ -742,12 +773,13 @@ class CompletionRequest(OpenAIBaseModel):
"If true (the default), special tokens (e.g. BOS) will be added to "
"the prompt."),
)
response_format: Optional[ResponseFormat] = Field(
response_format: Optional[AnyResponseFormat] = Field(
default=None,
description=
("Similar to chat completion, this parameter specifies the format of "
"output. Only {'type': 'json_object'}, {'type': 'json_schema'} or "
"{'type': 'text' } is supported."),
description=(
"Similar to chat completion, this parameter specifies the format "
"of output. Only {'type': 'json_object'}, {'type': 'json_schema'}"
", {'type': 'structural_tag'}, or {'type': 'text' } is supported."
),
)
guided_json: Optional[Union[str, dict, BaseModel]] = Field(
default=None,
......@@ -1577,14 +1609,6 @@ class TranscriptionRequest(OpenAIBaseModel):
"""
## TODO (varun) : Support if set to 0, certain thresholds are met !!
temperature: float = Field(default=0.0)
"""The sampling temperature, between 0 and 1.
Higher values like 0.8 will make the output more random, while lower values
like 0.2 will make it more focused / deterministic. If set to 0, the model
will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
to automatically increase the temperature until certain thresholds are hit.
"""
timestamp_granularities: list[Literal["word", "segment"]] = Field(
alias="timestamp_granularities[]", default=[])
......@@ -1596,6 +1620,7 @@ class TranscriptionRequest(OpenAIBaseModel):
timestamps incurs additional latency.
"""
# doc: begin-transcription-extra-params
stream: Optional[bool] = False
"""Custom field not present in the original OpenAI definition. When set,
it will enable output to be streamed in a similar fashion as the Chat
......@@ -1604,10 +1629,51 @@ class TranscriptionRequest(OpenAIBaseModel):
# Flattened stream option to simplify form data.
stream_include_usage: Optional[bool] = False
stream_continuous_usage_stats: Optional[bool] = False
# doc: end-transcription-extra-params
# doc: begin-transcription-sampling-params
temperature: float = Field(default=0.0)
"""The sampling temperature, between 0 and 1.
Higher values like 0.8 will make the output more random, while lower values
like 0.2 will make it more focused / deterministic. If set to 0, the model
will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
to automatically increase the temperature until certain thresholds are hit.
"""
top_p: Optional[float] = None
"""Enables nucleus (top-p) sampling, where tokens are selected from the
smallest possible set whose cumulative probability exceeds `p`.
"""
top_k: Optional[int] = None
"""Limits sampling to the `k` most probable tokens at each step."""
min_p: Optional[float] = None
"""Filters out tokens with a probability lower than `min_p`, ensuring a
minimum likelihood threshold during sampling.
"""
seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
"""The seed to use for sampling."""
frequency_penalty: Optional[float] = 0.0
"""The frequency penalty to use for sampling."""
repetition_penalty: Optional[float] = None
"""The repetition penalty to use for sampling."""
presence_penalty: Optional[float] = 0.0
"""The presence penalty to use for sampling."""
# doc: end-transcription-sampling-params
# Default sampling parameters for transcription requests.
_DEFAULT_SAMPLING_PARAMS: dict = {
"temperature": 0,
"repetition_penalty": 1.0,
"temperature": 1.0,
"top_p": 1.0,
"top_k": -1,
"min_p": 0.0,
}
def to_sampling_params(
......@@ -1619,13 +1685,35 @@ class TranscriptionRequest(OpenAIBaseModel):
if default_sampling_params is None:
default_sampling_params = {}
# Default parameters
if (temperature := self.temperature) is None:
temperature = default_sampling_params.get(
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
if (top_p := self.top_p) is None:
top_p = default_sampling_params.get(
"top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"])
if (top_k := self.top_k) is None:
top_k = default_sampling_params.get(
"top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"])
if (min_p := self.min_p) is None:
min_p = default_sampling_params.get(
"min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"])
if (repetition_penalty := self.repetition_penalty) is None:
repetition_penalty = default_sampling_params.get(
"repetition_penalty",
self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"])
return SamplingParams.from_optional(temperature=temperature,
max_tokens=max_tokens,
seed=self.seed,
top_p=top_p,
top_k=top_k,
min_p=min_p,
frequency_penalty=self.frequency_penalty,
repetition_penalty=repetition_penalty,
presence_penalty=self.presence_penalty,
output_kind=RequestOutputKind.DELTA
if self.stream \
else RequestOutputKind.FINAL_ONLY)
......
......@@ -12,7 +12,7 @@ import torch
from prometheus_client import start_http_server
from tqdm import tqdm
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.logger import RequestLogger, logger
# yapf: disable
......@@ -61,7 +61,7 @@ def parse_args():
"to the output URL.",
)
parser.add_argument("--response-role",
type=nullable_str,
type=optional_type(str),
default="assistant",
help="The role name to return if "
"`request.add_generation_prompt=True`.")
......
......@@ -10,6 +10,7 @@ from fastapi import Request
from pydantic import Field
from starlette.datastructures import Headers
import vllm.envs as envs
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
# yapf conflicts with isort for this block
......@@ -125,18 +126,29 @@ class OpenAIServing:
self,
request: AnyRequest,
) -> Optional[ErrorResponse]:
error_response = None
if self._is_model_supported(request.model):
return None
if request.model in [
lora.lora_name for lora in self.models.lora_requests
]:
return None
if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING and request.model and (
load_result := await self.models.resolve_lora(request.model)):
if isinstance(load_result, LoRARequest):
return None
if isinstance(load_result, ErrorResponse) and \
load_result.code == HTTPStatus.BAD_REQUEST.value:
error_response = load_result
if request.model in [
prompt_adapter.prompt_adapter_name
for prompt_adapter in self.models.prompt_adapter_requests
]:
return None
return self.create_error_response(
return error_response or self.create_error_response(
message=f"The model `{request.model}` does not exist.",
err_type="NotFoundError",
status_code=HTTPStatus.NOT_FOUND)
......
......@@ -2,6 +2,8 @@
import json
import pathlib
from asyncio import Lock
from collections import defaultdict
from dataclasses import dataclass
from http import HTTPStatus
from typing import Optional, Union
......@@ -15,6 +17,7 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse,
UnloadLoRAAdapterRequest)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.utils import AtomicCounter
......@@ -63,11 +66,19 @@ class OpenAIServingModels:
self.base_model_paths = base_model_paths
self.max_model_len = model_config.max_model_len
self.engine_client = engine_client
self.model_config = model_config
self.static_lora_modules = lora_modules
self.lora_requests: list[LoRARequest] = []
self.lora_id_counter = AtomicCounter(0)
self.lora_resolvers: list[LoRAResolver] = []
for lora_resolver_name in LoRAResolverRegistry.get_supported_resolvers(
):
self.lora_resolvers.append(
LoRAResolverRegistry.get_resolver(lora_resolver_name))
self.lora_resolver_lock: dict[str, Lock] = defaultdict(Lock)
self.prompt_adapter_requests = []
if prompt_adapters is not None:
for i, prompt_adapter in enumerate(prompt_adapters, start=1):
......@@ -234,6 +245,65 @@ class OpenAIServingModels:
return None
async def resolve_lora(
        self, lora_name: str) -> Union[LoRARequest, ErrorResponse]:
    """Attempt to resolve a LoRA adapter using available resolvers.

    Resolution for a given adapter name runs under that name's entry in
    ``self.lora_resolver_lock``. An adapter that is already present in
    ``self.lora_requests`` is returned directly. Otherwise each resolver
    is tried in order, and the first resolved adapter that the engine
    loads successfully is recorded and returned.

    Args:
        lora_name: Name/identifier of the LoRA adapter

    Returns:
        LoRARequest if found and loaded successfully.
        ErrorResponse (404) if no resolver finds the adapter.
        ErrorResponse (400) if adapter(s) are found but none load.
    """
    async with self.lora_resolver_lock[lora_name]:
        # First check if this LoRA is already loaded
        for existing in self.lora_requests:
            if existing.lora_name == lora_name:
                return existing

        base_model_name = self.model_config.model
        # One ID is drawn per resolution attempt and reused for every
        # candidate the resolvers return for this name.
        unique_id = self.lora_id_counter.inc(1)
        found_adapter = False

        # Try to resolve using available resolvers
        for resolver in self.lora_resolvers:
            lora_request = await resolver.resolve_lora(
                base_model_name, lora_name)
            if lora_request is None:
                continue

            found_adapter = True
            lora_request.lora_int_id = unique_id
            try:
                await self.engine_client.add_lora(lora_request)
            except Exception as e:
                # Only ordinary load failures fall through to the next
                # resolver. BaseException-only signals (task cancellation,
                # KeyboardInterrupt, SystemExit) must propagate instead of
                # being logged as a failed load.
                logger.warning(
                    "Failed to load LoRA '%s' resolved by %s: %s. "
                    "Trying next resolver.", lora_name,
                    resolver.__class__.__name__, e)
                continue

            self.lora_requests.append(lora_request)
            logger.info(
                "Resolved and loaded LoRA adapter '%s' using %s",
                lora_name, resolver.__class__.__name__)
            return lora_request

        if found_adapter:
            # An adapter was found, but all attempts to load it failed.
            return create_error_response(
                message=(f"LoRA adapter '{lora_name}' was found "
                         "but could not be loaded."),
                err_type="BadRequestError",
                status_code=HTTPStatus.BAD_REQUEST)

        # No adapter was found
        return create_error_response(
            message=f"LoRA adapter {lora_name} does not exist",
            err_type="NotFoundError",
            status_code=HTTPStatus.NOT_FOUND)
def create_error_response(
message: str,
......
......@@ -27,6 +27,7 @@ logger = init_logger(__name__)
@ToolParserManager.register_module("llama3_json")
@ToolParserManager.register_module("llama4_json")
class Llama3JsonToolParser(ToolParser):
"""
Tool call parser for Llama 3.1 models intended for use with the
......
......@@ -38,6 +38,10 @@ class MistralToolCall(ToolCall):
# https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299
return "".join(choices(ALPHANUMERIC, k=9))
@staticmethod
def is_valid_id(id: str) -> bool:
return id.isalnum() and len(id) == 9
@ToolParserManager.register_module("mistral")
class MistralToolParser(ToolParser):
......@@ -70,6 +74,19 @@ class MistralToolParser(ToolParser):
"Mistral Tool Parser could not locate the tool call token in "
"the tokenizer!")
def adjust_request(
        self, request: ChatCompletionRequest) -> ChatCompletionRequest:
    """Keep special tokens in the output when tool calls may be produced.

    The request is returned unchanged when the model tokenizer is a
    ``MistralTokenizer``. Otherwise, if the request declares tools and its
    ``tool_choice`` is not ``'none'``, ``skip_special_tokens`` is set to
    ``False`` on the request before it is returned.
    """
    # skip_special_tokens=False is incompatible with MistralTokenizer,
    # so such requests are left as they are.
    if isinstance(self.model_tokenizer, MistralTokenizer):
        return request

    # With a chat template and the Mistral parser, the TOOL_CALL token is
    # needed for tool detection, so special tokens must not be skipped.
    if request.tools and request.tool_choice != 'none':
        request.skip_special_tokens = False
    return request
def extract_tool_calls(
self,
model_output: str,
......
......@@ -8,8 +8,21 @@ import torch
# that interact with vllm workers.
# they are executed whenever `import vllm` is called.
# see https://github.com/NVIDIA/nccl/issues/1234
os.environ['NCCL_CUMEM_ENABLE'] = '0'
if not os.path.exists('/dev/nvidia-caps-imex-channels'):
# normally, we disable NCCL_CUMEM_ENABLE because it
# will cost 1~2 GiB GPU memory with cudagraph+allreduce,
# see https://github.com/NVIDIA/nccl/issues/1234
# for more details.
# However, NCCL requires NCCL_CUMEM_ENABLE to work with
# multi-node NVLink, typically on GB200-NVL72 systems.
# The ultimate way to detect multi-node NVLink is to use
# NVML APIs, which are too expensive to call here.
# As an approximation, we check the existence of
# /dev/nvidia-caps-imex-channels, used by
# multi-node NVLink to communicate across nodes.
# This will still cost some GPU memory, but it is worthwhile
# because we can get very fast cross-node bandwidth with NVLink.
os.environ['NCCL_CUMEM_ENABLE'] = '0'
# see https://github.com/vllm-project/vllm/pull/15951
# it avoids unintentional cuda initialization from torch.cuda.is_available()
......
......@@ -75,10 +75,12 @@ if TYPE_CHECKING:
VLLM_DISABLED_KERNELS: list[str] = []
VLLM_USE_V1: bool = True
VLLM_ROCM_USE_AITER: bool = False
VLLM_ROCM_USE_AITER_PAGED_ATTN: bool = False
VLLM_ROCM_USE_AITER_LINEAR: bool = True
VLLM_ROCM_USE_AITER_MOE: bool = True
VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE: bool = False
VLLM_ROCM_USE_AITER_RMSNORM: bool = True
VLLM_ROCM_USE_AITER_MLA: bool = True
VLLM_ROCM_USE_SKINNY_GEMM: bool = True
VLLM_ROCM_FP8_PADDING: bool = True
VLLM_ROCM_MOE_PADDING: bool = True
VLLM_ROCM_CUSTOM_PAGED_ATTN: bool = True
......@@ -96,6 +98,7 @@ if TYPE_CHECKING:
VLLM_RAY_BUNDLE_INDICES: str = ""
VLLM_CUDART_SO_PATH: Optional[str] = None
VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH: bool = True
VLLM_HPU_USE_DELAYED_SAMPLING: bool = False
VLLM_DP_RANK: int = 0
VLLM_DP_RANK_LOCAL: int = -1
VLLM_DP_SIZE: int = 1
......@@ -103,10 +106,10 @@ if TYPE_CHECKING:
VLLM_DP_MASTER_PORT: int = 0
VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
VLLM_V0_USE_OUTLINES_CACHE: bool = False
VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: bool = False
VLLM_TPU_BUCKET_PADDING_GAP: int = 0
VLLM_USE_DEEP_GEMM: bool = False
VLLM_XGRAMMAR_CACHE_MB: int = 0
VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
def get_default_cache_root():
......@@ -533,6 +536,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: (os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in
("true", "1")),
# Whether to use aiter paged attention.
# Disabled by default.
"VLLM_ROCM_USE_AITER_PAGED_ATTN":
lambda: (os.getenv("VLLM_ROCM_USE_AITER_PAGED_ATTN", "False").lower() in
("true", "1")),
# use aiter linear op if aiter ops are enabled
# The following list of related ops
# - scaled_mm (per-tensor / rowwise)
......@@ -546,18 +555,21 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: (os.getenv("VLLM_ROCM_USE_AITER_MOE", "True").lower() in
("true", "1")),
# Whether to use aiter block scaled moe kernel.
# By default this is disabled.
"VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE":
lambda:
(os.getenv("VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE", "false").lower() in
("true", "1")),
# use aiter rms norm op if aiter ops are enabled.
"VLLM_ROCM_USE_AITER_RMSNORM":
lambda: (os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in
("true", "1")),
# Whether to use aiter mla ops.
# Enabled by default.
"VLLM_ROCM_USE_AITER_MLA":
lambda: (os.getenv("VLLM_ROCM_USE_AITER_MLA", "True").lower() in
("true", "1")),
# use rocm skinny gemms
"VLLM_ROCM_USE_SKINNY_GEMM":
lambda: (os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in
("true", "1")),
# Pad the fp8 weights to 256 bytes for ROCm
"VLLM_ROCM_FP8_PADDING":
lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))),
......@@ -639,6 +651,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() in
("1", "true"),
# Use delayed sampling for HPU to reduce host CPU overhead
# between each step. Note that the value is read from the
# `VLLM_DELAYED_SAMPLING` environment variable.
"VLLM_HPU_USE_DELAYED_SAMPLING":
lambda: os.environ.get("VLLM_DELAYED_SAMPLING", "false").lower() in
("1", "true"),
# Rank of the process in the data parallel setting
"VLLM_DP_RANK":
lambda: int(os.getenv("VLLM_DP_RANK", "0")),
......@@ -684,11 +702,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_V0_USE_OUTLINES_CACHE":
lambda: os.environ.get("VLLM_V0_USE_OUTLINES_CACHE", "0") == "1",
# If set, disables TPU-specific optimization for top-k & top-p sampling
"VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION":
lambda: bool(int(os.environ["VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION"]))
if "VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION" in os.environ else None,
# Gap between padding buckets for the forward pass. So if we have
# 8, we will run forward pass with [16, 24, 32, ...].
"VLLM_TPU_BUCKET_PADDING_GAP":
......@@ -704,6 +717,16 @@ environment_variables: dict[str, Callable[[], Any]] = {
# It can be changed with this variable if needed for some reason.
"VLLM_XGRAMMAR_CACHE_MB":
lambda: int(os.getenv("VLLM_XGRAMMAR_CACHE_MB", "512")),
# Control the threshold for msgspec to use 'zero copy' for
# serialization/deserialization of tensors. Tensors below
# this limit will be encoded into the msgpack buffer, and
# tensors above will instead be sent via a separate message.
# While the sending side still actually copies the tensor
# in all cases, on the receiving side, tensors above this
# limit will actually be zero-copy decoded.
"VLLM_MSGPACK_ZERO_COPY_THRESHOLD":
lambda: int(os.getenv("VLLM_MSGPACK_ZERO_COPY_THRESHOLD", "256")),
}
# end-env-vars-definition
......@@ -742,7 +765,7 @@ def compute_hash() -> str:
variables, ensure that it is included in the factors list if
it affects the computation graph. For example, different values
of VLLM_PP_LAYER_PARTITION will generate different computation
graphs, so it is included in the factors list. The env vars that
graphs, so it is included in the factors list. The env vars that
affect the choice of different kernels or attention backends should
also be included in the factors list.
"""
......@@ -771,6 +794,7 @@ def compute_hash() -> str:
if key in environment_variables:
factorize(key)
hash_str = hashlib.md5(str(factors).encode()).hexdigest()
hash_str = hashlib.md5(str(factors).encode(),
usedforsecurity=False).hexdigest()
return hash_str
......@@ -34,13 +34,13 @@ class UniProcExecutor(ExecutorBase):
if len(device_info) > 1:
local_rank = int(device_info[1])
rank = 0
is_driver_worker = True
kwargs = dict(
vllm_config=self.vllm_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
is_driver_worker=(not self.parallel_config)
or (rank % self.parallel_config.tensor_parallel_size == 0),
is_driver_worker=is_driver_worker,
)
self.collective_rpc("init_worker", args=([kwargs], ))
self.collective_rpc("init_device")
......
......@@ -11,6 +11,10 @@ import torch.distributed as dist
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group,
is_v1_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.logger import init_logger
if TYPE_CHECKING:
......@@ -98,6 +102,17 @@ def set_forward_context(attn_metadata: Any,
virtual_engine=virtual_engine,
attn_metadata=attn_metadata,
dp_metadata=dp_metadata)
# KVConnector: trigger (possibly async) load before forward.
# Each attn layer will block until the reading is complete.
trigger_kv_transfer = (attn_metadata is not None
and has_kv_transfer_group()
and is_v1_kv_transfer_group())
if trigger_kv_transfer:
kv_connector = get_kv_transfer_group()
assert isinstance(kv_connector, KVConnectorBase_V1)
kv_connector.start_load_kv(_forward_context)
try:
yield
finally:
......@@ -133,4 +148,12 @@ def set_forward_context(attn_metadata: Any,
logger.info(("Batchsize forward time stats "
"(batchsize, count, median_time(ms)): %s"),
forward_stats)
# KVConnector: each attn layer triggers (possibly async) save.
# Ensure all those operations complete before forward() is done.
if trigger_kv_transfer:
kv_connector = get_kv_transfer_group()
assert isinstance(kv_connector, KVConnectorBase_V1)
kv_connector.wait_for_save()
_forward_context = prev_context
......@@ -2,10 +2,9 @@
from .data import (DecoderOnlyInputs, EncoderDecoderInputs,
ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType,
SingletonInputs, SingletonInputsAdapter, SingletonPrompt,
TextPrompt, TokenInputs, TokensPrompt,
build_explicit_enc_dec_prompt, to_enc_dec_tuple_list,
token_inputs, zip_enc_dec_prompts)
SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs,
TokensPrompt, build_explicit_enc_dec_prompt,
to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts)
from .registry import (DummyData, InputContext, InputProcessingContext,
InputRegistry)
......@@ -27,7 +26,6 @@ __all__ = [
"EncoderDecoderInputs",
"ProcessorInputs",
"SingletonInputs",
"SingletonInputsAdapter",
"build_explicit_enc_dec_prompt",
"to_enc_dec_tuple_list",
"zip_enc_dec_prompts",
......
# SPDX-License-Identifier: Apache-2.0
from collections.abc import Iterable
from dataclasses import dataclass
from functools import cached_property
from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, Union, cast
import torch
from typing_extensions import NotRequired, TypedDict, TypeVar, assert_never
from typing_extensions import NotRequired, TypedDict, TypeVar
if TYPE_CHECKING:
from vllm.multimodal import (MultiModalDataDict, MultiModalKwargs,
MultiModalPlaceholderDict)
from vllm.multimodal.inputs import MultiModalInputs
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalInputs
class TextPrompt(TypedDict):
......@@ -147,46 +141,11 @@ class TokenInputs(TypedDict):
The original prompt text corresponding to the token IDs, if available.
"""
multi_modal_data: NotRequired["MultiModalDataDict"]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
"""
multi_modal_inputs: NotRequired["MultiModalKwargs"]
"""
Optional multi-modal inputs to pass to the model,
if the model supports it.
"""
multi_modal_placeholders: NotRequired["MultiModalPlaceholderDict"]
"""
Placeholder ranges for the multi-modal data.
"""
multi_modal_hashes: NotRequired[list[str]]
"""
The hashes of the multi-modal data.
"""
mm_processor_kwargs: NotRequired[dict[str, Any]]
"""
Optional multi-modal processor kwargs to be forwarded to the
multimodal input mapper & processor. Note that if multiple modalities
have registered mappers etc for the model being considered, we attempt
to pass the mm_processor_kwargs to each of them.
"""
def token_inputs(
prompt_token_ids: list[int],
token_type_ids: Optional[list[int]] = None,
prompt: Optional[str] = None,
multi_modal_data: Optional["MultiModalDataDict"] = None,
multi_modal_inputs: Optional["MultiModalKwargs"] = None,
multi_modal_hashes: Optional[list[str]] = None,
multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None,
mm_processor_kwargs: Optional[dict[str, Any]] = None,
) -> TokenInputs:
"""Construct :class:`TokenInputs` from optional values."""
inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids)
......@@ -195,16 +154,6 @@ def token_inputs(
inputs["prompt"] = prompt
if token_type_ids is not None:
inputs["token_type_ids"] = token_type_ids
if multi_modal_data is not None:
inputs["multi_modal_data"] = multi_modal_data
if multi_modal_inputs is not None:
inputs["multi_modal_inputs"] = multi_modal_inputs
if multi_modal_hashes is not None:
inputs["multi_modal_hashes"] = multi_modal_hashes
if multi_modal_placeholders is not None:
inputs["multi_modal_placeholders"] = multi_modal_placeholders
if mm_processor_kwargs is not None:
inputs["mm_processor_kwargs"] = mm_processor_kwargs
return inputs
......@@ -237,112 +186,6 @@ A processed :class:`SingletonPrompt` which can be passed to
:class:`vllm.sequence.Sequence`.
"""
@dataclass
class SingletonInputsAdapter:
    """
    Attribute-style view over a :class:`SingletonInputs` dictionary.

    Every property dispatches on ``inputs["type"]`` (``"token"`` or
    ``"multimodal"``) and reads the matching key, so callers can use one
    set of names for both variants. Any other ``type`` value falls through
    to ``assert_never``. Results are memoized via ``cached_property``.
    """

    inputs: SingletonInputs

    @cached_property
    def prompt(self) -> Optional[str]:
        """Value stored under ``"prompt"``, or ``None`` if the key is absent."""
        wrapped = self.inputs

        if wrapped["type"] in ("token", "multimodal"):
            return wrapped.get("prompt")

        assert_never(wrapped)  # type: ignore[arg-type]

    @cached_property
    def prompt_token_ids(self) -> list[int]:
        """Value stored under ``"prompt_token_ids"``; ``[]`` if absent."""
        wrapped = self.inputs

        if wrapped["type"] in ("token", "multimodal"):
            return wrapped.get("prompt_token_ids", [])

        assert_never(wrapped)  # type: ignore[arg-type]

    @cached_property
    def token_type_ids(self) -> list[int]:
        """Value stored under ``"token_type_ids"``; ``[]`` if absent."""
        wrapped = self.inputs

        if wrapped["type"] in ("token", "multimodal"):
            return wrapped.get("token_type_ids", [])

        assert_never(wrapped)  # type: ignore[arg-type]

    @cached_property
    def prompt_embeds(self) -> Optional[torch.Tensor]:
        """Always ``None`` for both supported input variants."""
        wrapped = self.inputs

        if wrapped["type"] in ("token", "multimodal"):
            return None

        assert_never(wrapped)  # type: ignore[arg-type]

    @cached_property
    def multi_modal_data(self) -> "MultiModalDataDict":
        """
        ``"multi_modal_data"`` for token inputs, ``"mm_kwargs"`` for
        multimodal inputs; ``{}`` when the key is absent.
        """
        wrapped = self.inputs

        # The two variants are mutually exclusive, so test order is free.
        if wrapped["type"] == "multimodal":
            return wrapped.get("mm_kwargs", {})
        if wrapped["type"] == "token":
            return wrapped.get("multi_modal_data", {})

        assert_never(wrapped)  # type: ignore[arg-type]

    @cached_property
    def multi_modal_inputs(self) -> Union[dict, "MultiModalKwargs"]:
        """
        ``"multi_modal_inputs"`` for token inputs, ``"mm_kwargs"`` for
        multimodal inputs; ``{}`` when the key is absent.
        """
        wrapped = self.inputs

        if wrapped["type"] == "multimodal":
            return wrapped.get("mm_kwargs", {})
        if wrapped["type"] == "token":
            return wrapped.get("multi_modal_inputs", {})

        assert_never(wrapped)  # type: ignore[arg-type]

    @cached_property
    def multi_modal_hashes(self) -> list[str]:
        """
        ``"multi_modal_hashes"`` for token inputs, ``"mm_hashes"`` for
        multimodal inputs; ``[]`` when the key is absent.
        """
        wrapped = self.inputs

        if wrapped["type"] == "multimodal":
            # only the case when we use MultiModalInputs
            return wrapped.get("mm_hashes", [])  # type: ignore[return-value]
        if wrapped["type"] == "token":
            return wrapped.get("multi_modal_hashes", [])

        assert_never(wrapped)  # type: ignore[arg-type]

    @cached_property
    def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict":
        """
        ``"multi_modal_placeholders"`` for token inputs,
        ``"mm_placeholders"`` for multimodal inputs; ``{}`` when absent.
        """
        wrapped = self.inputs

        if wrapped["type"] == "multimodal":
            return wrapped.get("mm_placeholders", {})
        if wrapped["type"] == "token":
            return wrapped.get("multi_modal_placeholders", {})

        assert_never(wrapped)  # type: ignore[arg-type]

    @cached_property
    def mm_processor_kwargs(self) -> dict[str, Any]:
        """
        ``"mm_processor_kwargs"`` for token inputs (``{}`` when absent);
        always ``{}`` for multimodal inputs.
        """
        wrapped = self.inputs

        if wrapped["type"] == "multimodal":
            return {}
        if wrapped["type"] == "token":
            return wrapped.get("mm_processor_kwargs", {})

        assert_never(wrapped)  # type: ignore[arg-type]
ProcessorInputs = Union[DecoderOnlyInputs, EncoderDecoderInputs]
"""
The inputs to :data:`vllm.inputs.InputProcessor`.
......
......@@ -13,7 +13,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalInputs)
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from .data import (DecoderOnlyInputs, EncoderDecoderInputs, ProcessorInputs,
PromptType, SingletonInputs, SingletonPrompt, token_inputs)
......@@ -27,7 +27,7 @@ class InputPreprocessor:
def __init__(
self,
model_config: ModelConfig,
tokenizer: Optional[BaseTokenizerGroup],
tokenizer: Optional[TokenizerGroup],
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
) -> None:
super().__init__()
......@@ -36,7 +36,7 @@ class InputPreprocessor:
self.tokenizer = tokenizer
self.mm_registry = mm_registry
def get_tokenizer_group(self) -> BaseTokenizerGroup:
def get_tokenizer_group(self) -> TokenizerGroup:
if self.tokenizer is None:
raise ValueError("You cannot pass text prompts when "
"`skip_tokenizer_init` is True")
......@@ -223,28 +223,6 @@ class InputPreprocessor:
lora_request=lora_request,
add_special_tokens=add_special_tokens)
def _can_process_multimodal(self) -> bool:
    """
    Report whether the multi-modal registry has a processor for this model.

    Returns:
        The result of ``self.mm_registry.has_processor(model_config)``.

    Raises:
        ValueError: If ``model_config.is_multimodal_model`` is falsy.
    """
    cfg = self.model_config
    if not cfg.is_multimodal_model:
        raise ValueError("Your model does not support multi-modal inputs")

    has_new_processor = self.mm_registry.has_processor(cfg)
    if has_new_processor:
        return has_new_processor

    # Interim measure so we can handle models that have yet to be
    # updated to use the new multi-modal processor
    from vllm.model_executor.models.registry import _VLLM_MODELS

    # Only warn when none of the architectures is listed in _VLLM_MODELS.
    if all(arch not in _VLLM_MODELS for arch in cfg.architectures):
        logger.warning_once(
            "Your model uses the legacy input pipeline, which will be "
            "removed in an upcoming release. "
            "Please upgrade to the new multi-modal processing pipeline "
            "(https://docs.vllm.ai/en/latest/design/mm_processing.html)")

    return has_new_processor
def _process_multimodal(
self,
prompt: Union[str, list[int]],
......@@ -258,8 +236,7 @@ class InputPreprocessor:
returning the corresponding token IDs and metadata.
"""
# At the moment one model (PrithviGeoSpatialMAE) requires to be
# initialized without a tokenizer while using also multi-modal
# input.
# initialized without a tokenizer while using also multi-modal input
if not self.tokenizer:
tokenizer = object() # Dummy
else:
......@@ -285,8 +262,7 @@ class InputPreprocessor:
) -> MultiModalInputs:
"""Async version of :meth:`_process_multimodal`."""
# At the moment one model (PrithviGeoSpatialMAE) requires to be
# initialized without a tokenizer while using also multi-modal
# input.
# initialized without a tokenizer while using also multi-modal input
if not self.tokenizer:
tokenizer = object() # Dummy
else:
......@@ -343,7 +319,7 @@ class InputPreprocessor:
multi_modal_data = tokens_content.get("multi_modal_data")
mm_processor_kwargs = tokens_content.get("mm_processor_kwargs")
if multi_modal_data is not None and self._can_process_multimodal():
if multi_modal_data is not None:
return self._process_multimodal(
prompt_token_ids,
multi_modal_data,
......@@ -355,8 +331,6 @@ class InputPreprocessor:
return token_inputs(
prompt_token_ids=prompt_token_ids,
token_type_ids=token_type_ids,
multi_modal_data=multi_modal_data,
mm_processor_kwargs=mm_processor_kwargs,
)
if parsed["type"] == "text":
......@@ -366,7 +340,7 @@ class InputPreprocessor:
multi_modal_data = text_content.get("multi_modal_data")
mm_processor_kwargs = text_content.get("mm_processor_kwargs")
if multi_modal_data is not None and self._can_process_multimodal():
if multi_modal_data is not None:
return self._process_multimodal(
prompt_text,
multi_modal_data,
......@@ -383,8 +357,6 @@ class InputPreprocessor:
return token_inputs(
prompt=prompt_text,
prompt_token_ids=prompt_token_ids,
multi_modal_data=multi_modal_data,
mm_processor_kwargs=mm_processor_kwargs,
)
assert_never(parsed)
......@@ -417,7 +389,7 @@ class InputPreprocessor:
multi_modal_data = tokens_content.get("multi_modal_data")
mm_processor_kwargs = tokens_content.get("mm_processor_kwargs")
if multi_modal_data is not None and self._can_process_multimodal():
if multi_modal_data is not None:
return await self._process_multimodal_async(
prompt_token_ids,
multi_modal_data,
......@@ -426,11 +398,7 @@ class InputPreprocessor:
return_mm_hashes=return_mm_hashes,
)
return token_inputs(
prompt_token_ids=prompt_token_ids,
multi_modal_data=multi_modal_data,
mm_processor_kwargs=mm_processor_kwargs,
)
return token_inputs(prompt_token_ids=prompt_token_ids)
if parsed["type"] == "text":
text_content = parsed["content"]
......@@ -439,7 +407,7 @@ class InputPreprocessor:
multi_modal_data = text_content.get("multi_modal_data")
mm_processor_kwargs = text_content.get("mm_processor_kwargs")
if multi_modal_data is not None and self._can_process_multimodal():
if multi_modal_data is not None:
return await self._process_multimodal_async(
prompt_text,
multi_modal_data,
......@@ -456,8 +424,6 @@ class InputPreprocessor:
return token_inputs(
prompt=prompt_text,
prompt_token_ids=prompt_token_ids,
multi_modal_data=multi_modal_data,
mm_processor_kwargs=mm_processor_kwargs,
)
assert_never(parsed)
......@@ -594,15 +560,13 @@ class InputPreprocessor:
decoder_inputs = self._prompt_to_llm_inputs(decoder_input)
# For multimodal model, override decoder prompt from processor
# with explicit decoder prompt.
if self.model_config.is_multimodal_model and (
self._can_process_multimodal()):
if self.model_config.is_multimodal_model:
encoder_inputs, decoder_inputs = (
self._separate_enc_dec_inputs_from_mm_processor_outputs(
encoder_inputs, decoder_inputs))
else:
inputs = self._prompt_to_llm_inputs(prompt)
if self.model_config.is_multimodal_model and (
self._can_process_multimodal()):
if self.model_config.is_multimodal_model:
# Encoder-Decoder Multimodal model
encoder_inputs, decoder_inputs = (
self._separate_enc_dec_inputs_from_mm_processor_outputs(
......@@ -637,15 +601,13 @@ class InputPreprocessor:
# For multimodal model, override decoder prompt from processor
# with explicit decoder prompt.
if self.model_config.is_multimodal_model and (
self._can_process_multimodal()):
if self.model_config.is_multimodal_model:
encoder_inputs, decoder_inputs = (
self._separate_enc_dec_inputs_from_mm_processor_outputs(
encoder_inputs, decoder_inputs))
else:
inputs = await self._prompt_to_llm_inputs_async(prompt)
if self.model_config.is_multimodal_model and (
self._can_process_multimodal()):
if self.model_config.is_multimodal_model:
# Encoder-Decoder Multimodal model
encoder_inputs, decoder_inputs = (
self._separate_enc_dec_inputs_from_mm_processor_outputs(
......
# SPDX-License-Identifier: Apache-2.0
import functools
from collections import UserDict
from collections.abc import Mapping
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, NamedTuple, Optional,
Protocol, Union)
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union
from torch import nn
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
from typing_extensions import TypeVar, assert_never
from typing_extensions import TypeVar
from vllm.logger import init_logger
from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides,
resolve_mm_processor_kwargs)
from .data import ProcessorInputs, SingletonInputs
from .parse import split_enc_dec_inputs
from vllm.utils import resolve_mm_processor_kwargs
if TYPE_CHECKING:
from vllm.config import ModelConfig
......@@ -26,8 +16,6 @@ if TYPE_CHECKING:
MultiModalRegistry)
from vllm.sequence import SequenceData
logger = init_logger(__name__)
_T = TypeVar("_T")
_C = TypeVar("_C", bound=PretrainedConfig, default=PretrainedConfig)
_P = TypeVar("_P", bound=ProcessorMixin, default=ProcessorMixin)
......@@ -172,142 +160,23 @@ class InputProcessingContext(InputContext):
raise RuntimeError(msg) from exc
N = TypeVar("N", bound=type[nn.Module])
class DummyData(NamedTuple):
"""Dummy data used for profiling."""
"""
Dummy data used for profiling.
Note: This is only used in V0.
"""
seq_data: "SequenceData"
multi_modal_data: Optional["MultiModalDataDict"] = None
multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None
class DummyDataFactory(Protocol):
    """Structural type for callables that produce :class:`DummyData`."""

    def __call__(
        self,
        ctx: InputContext,
        seq_len: int,
        mm_counts: Mapping[str, int],
        **mm_processor_kwargs: Any,
    ) -> DummyData:
        """
        Build the dummy data that will be fed into the model.

        Note:
            The dummy data does not pass through :data:`InputProcessor`.

            ``mm_processor_kwargs`` holds initialization-time overrides for
            config values; those values may change how many tokens each
            instance occupies.
        """
class _MultiModalCounts(UserDict[str, int]):
    """
    ``mm_counts`` wrapper whose failed lookups raise a :class:`KeyError`
    that names the missing key and lists the keys that do exist.
    """

    def __getitem__(self, key: str) -> int:
        try:
            count = super().__getitem__(key)
        except KeyError as lookup_error:
            # Re-raise with a descriptive message, chained to the original.
            available = set(self.keys())
            raise KeyError(
                f"There is no multi-modal plugin with the key: {key}. "
                f"Available keys: {available}") from lookup_error
        return count
InputProcessor = Callable[[InputContext, ProcessorInputs], ProcessorInputs]
"""Preprocess the inputs to the model."""
class InputRegistry:
"""
A registry to dispatch data processing
according to the target model.
Note: This is only used in V0.
"""
def __init__(self) -> None:
self._dummy_factories_by_model_type = \
ClassRegistry[nn.Module, DummyDataFactory]()
self._dummy_encoder_factories_by_model_type = \
ClassRegistry[nn.Module, DummyDataFactory]()
self._input_processors_by_model_type = \
ClassRegistry[nn.Module, InputProcessor]()
def _default_dummy_data_factory(
self,
ctx: InputContext,
seq_len: int,
mm_counts: Mapping[str, int],
) -> DummyData:
"""
The default dummy data factory represents the longest possible text
that can be inputted to the model.
Note:
:data:`InputProcessor` is not applied to the dummy data.
"""
# Avoid circular import
from vllm.sequence import SequenceData
return DummyData(SequenceData.from_prompt_token_counts((0, seq_len)))
def register_dummy_data(self, factory: DummyDataFactory):
"""
Register a dummy data factory to a model class.
During memory profiling, the provided function is invoked to create
dummy data to be inputted into the model. The resulting memory usage
should be an upper bound of what the model would use at inference time.
"""
def wrapper(model_cls: N) -> N:
if self._dummy_factories_by_model_type.contains(model_cls,
strict=True):
logger.warning(
"Model class %s already has dummy data "
"registered to %s. It is overwritten by the new one.",
model_cls, self)
self._dummy_factories_by_model_type[model_cls] = factory
return model_cls
return wrapper
def _get_dummy_data_factory(self, model_cls: type[nn.Module]):
return self._dummy_factories_by_model_type \
.get(model_cls, self._default_dummy_data_factory)
def register_dummy_encoder_data(self, factory: DummyDataFactory):
"""
Register a dummy encoder data factory to a model class
This is similar to :meth:`~register_dummy_data`, but for encoder input.
"""
def wrapper(model_cls: N) -> N:
if self._dummy_encoder_factories_by_model_type.contains(
model_cls, strict=True):
logger.warning(
"Model class %s already has dummy encoder data "
"registered to %s. It is overwritten by the new one.",
model_cls, self)
self._dummy_encoder_factories_by_model_type[model_cls] = factory
return model_cls
return wrapper
def _get_dummy_encoder_data_factory(self, model_cls: type[nn.Module]):
return self._dummy_encoder_factories_by_model_type \
.get(model_cls, self._default_dummy_data_factory)
def dummy_data_for_profiling(
self,
model_config: "ModelConfig",
......@@ -319,169 +188,25 @@ class InputRegistry:
Create dummy data for profiling the memory usage of a model.
The model is identified by ``model_config``.
Note:
This should be called after
:meth:`~MultiModalRegistry.init_mm_limits_per_prompt`.
"""
# Avoid circular import
from vllm.model_executor.model_loader import get_model_architecture
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.profiling import MultiModalProfiler
from vllm.sequence import SequenceData
if mm_registry.has_processor(model_config):
processor = mm_registry.create_processor(model_config,
disable_cache=True)
profiler = MultiModalProfiler(processor)
dummy_data_v1 = (profiler.get_encoder_dummy_data(seq_len)
if is_encoder_data else
profiler.get_decoder_dummy_data(seq_len))
_seq_data = SequenceData.from_seqs(
dummy_data_v1.prompt_token_ids) # type: ignore[attr-defined]
dummy_data = DummyData(
seq_data=_seq_data,
multi_modal_data=getattr(dummy_data_v1, "multi_modal_data",
None),
multi_modal_placeholders=getattr(dummy_data_v1,
"multi_modal_placeholders",
None),
)
else:
model_cls, _ = get_model_architecture(model_config)
if is_encoder_data:
dummy_factory = self._get_dummy_encoder_data_factory(model_cls)
else:
dummy_factory = self._get_dummy_data_factory(model_cls)
mm_counts = mm_registry.get_mm_limits_per_prompt(model_config)
mm_processor_kwargs = get_allowed_kwarg_only_overrides(
dummy_factory,
overrides=model_config.mm_processor_kwargs,
requires_kw_only=False,
allow_var_kwargs=True,
)
dummy_data = dummy_factory(InputContext(model_config), seq_len,
_MultiModalCounts(mm_counts),
**mm_processor_kwargs)
# Having more tokens is over-conservative but otherwise fine
num_tokens = dummy_data.seq_data.prompt_token_ids
if len(num_tokens) < seq_len:
if is_encoder_data:
logger.warning_once(
f"Expected at least {seq_len} dummy encoder tokens for "
f"profiling, but found {len(num_tokens)} tokens instead.")
else:
raise AssertionError(
f"Expected at least {seq_len} dummy tokens for profiling, "
f"but found {len(num_tokens)} tokens instead.")
if (dummy_data.multi_modal_data is not None and
not isinstance(dummy_data.multi_modal_data, MultiModalKwargs)):
for k, v in dummy_data.multi_modal_data.items():
num_items = len(v) if isinstance(v, list) else 1
num_expected = mm_counts[k]
assert num_items >= num_expected, (
f"Expected at least {num_expected} dummy '{k}' instances "
f"for profiling, but found {num_items} instances instead.")
return dummy_data
def _default_input_processor(
self,
ctx: InputContext,
inputs: ProcessorInputs,
**kwargs: object,
) -> ProcessorInputs:
"""The default input processor is a no-op."""
return inputs
def register_input_processor(self, processor: InputProcessor):
"""
Register an input processor to a model class.
The provided function is invoked on each input to the model. This
happens before
:meth:`~vllm.multimodal.registry.MultiModalRegistry.map_input`.
"""
def wrapper(model_cls: N) -> N:
if self._input_processors_by_model_type.contains(model_cls,
strict=True):
logger.warning(
"Model class %s already has input processor "
"registered to %s. It is overwritten by the new one.",
model_cls, self)
self._input_processors_by_model_type[model_cls] = processor
return model_cls
if not model_config.is_multimodal_model:
seq_data = SequenceData.from_prompt_token_counts((0, seq_len))
return DummyData(seq_data=seq_data)
return wrapper
# Encoder dummy data does not contain multi-modal data
if is_encoder_data:
enc_data = mm_registry.get_encoder_dummy_data(
model_config, seq_len)
seq_data = SequenceData.from_seqs(enc_data.prompt_token_ids)
return DummyData(seq_data=seq_data)
def _get_model_input_processor(self, model_cls: type[nn.Module]):
return self._input_processors_by_model_type \
.get(model_cls, self._default_input_processor)
def _ensure_mm_kwargs(
self,
inputs: SingletonInputs,
mm_processor_kwargs: dict[str, Any],
):
if inputs["type"] == "token":
# In case the input processor for that model fails to set it
if "mm_processor_kwargs" not in inputs:
inputs["mm_processor_kwargs"] = mm_processor_kwargs
elif inputs["type"] == "multimodal":
# Be more strict in V2
assert "mm_kwargs" in inputs
else:
assert_never(inputs["type"]) # type: ignore[arg-type]
def process_input(self, model_config: "ModelConfig",
inputs: ProcessorInputs) -> ProcessorInputs:
"""
Apply an input processor to an instance of model inputs.
The model is identified by ``model_config``.
"""
# Avoid circular import
from vllm.model_executor.model_loader import get_model_architecture
model_cls, _ = get_model_architecture(model_config)
processor = self._get_model_input_processor(model_cls)
# Handle multimodal processor kwargs with priority:
# Inference kwargs -> Init kwargs -> {}
# If it's empty, it'll fall back to the default kwarg values
mm_processor_kwargs = resolve_mm_processor_kwargs(
model_config.mm_processor_kwargs,
inputs.get("mm_processor_kwargs", {}), # type: ignore
processor,
requires_kw_only=False,
allow_var_kwargs=True,
)
dec_data = mm_registry.get_decoder_dummy_data(model_config, seq_len)
processed_inputs = processor(
InputContext(model_config),
inputs,
**mm_processor_kwargs,
return DummyData(
seq_data=SequenceData.from_seqs(dec_data.prompt_token_ids),
multi_modal_data=dec_data.multi_modal_data,
multi_modal_placeholders=dec_data.multi_modal_placeholders,
)
encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
if encoder_inputs is not None:
self._ensure_mm_kwargs(encoder_inputs, mm_processor_kwargs)
if decoder_inputs is not None:
self._ensure_mm_kwargs(decoder_inputs, mm_processor_kwargs)
return processed_inputs
def create_input_processor(self, model_config: "ModelConfig"):
"""
Create an input processor (see :meth:`_process_input`) for a
specific model.
"""
return functools.partial(self.process_input, model_config)
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import AbstractSet, Dict, Optional
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
logger = init_logger(__name__)
class LoRAResolver(ABC):
    """Interface for components that resolve LoRA adapters by name.

    A concrete resolver is responsible for locating a LoRA adapter and
    downloading it from wherever it is stored (e.g. S3, cloud storage,
    etc.).
    """

    @abstractmethod
    async def resolve_lora(self, base_model_name: str,
                           lora_name: str) -> Optional[LoRARequest]:
        """Locate and fetch the LoRA adapter identified by ``lora_name``.

        Subclasses supply the lookup and download logic; the adapter may
        come from blob storage or any other source.

        Args:
            base_model_name: The name/identifier of the base model to resolve.
            lora_name: The name/identifier of the LoRA model to resolve.

        Returns:
            Optional[LoRARequest]: The resolved LoRA model information, or
                None if the LoRA model cannot be found.
        """
@dataclass
class _LoRAResolverRegistry:
    """Name-to-instance table of :class:`LoRAResolver` objects."""

    # Registered resolvers, keyed by the name they were registered under.
    resolvers: Dict[str, LoRAResolver] = field(default_factory=dict)

    def get_supported_resolvers(self) -> AbstractSet[str]:
        """Return the names of all registered resolvers."""
        names = self.resolvers.keys()
        return names

    def register_resolver(
        self,
        resolver_name: str,
        resolver: LoRAResolver,
    ) -> None:
        """Store ``resolver`` under ``resolver_name``.

        An existing entry with the same name is replaced after a warning
        is logged.

        Args:
            resolver_name: Name to register the resolver under.
            resolver: The LoRA resolver instance to register.
        """
        already_registered = resolver_name in self.resolvers
        if already_registered:
            logger.warning(
                "LoRA resolver %s is already registered, and will be "
                "overwritten by the new resolver instance %s.", resolver_name,
                resolver)

        self.resolvers[resolver_name] = resolver

    def get_resolver(self, resolver_name: str) -> LoRAResolver:
        """Look up a previously registered resolver.

        Args:
            resolver_name: Name of the resolver to get.

        Returns:
            The resolver instance.

        Raises:
            KeyError: If the resolver is not found in the registry.
        """
        if resolver_name in self.resolvers:
            return self.resolvers[resolver_name]

        known_names = list(self.resolvers.keys())
        raise KeyError(
            f"LoRA resolver '{resolver_name}' not found. "
            f"Available resolvers: {known_names}")
# Module-level registry instance, created when this module is imported.
LoRAResolverRegistry = _LoRAResolverRegistry()
......@@ -114,7 +114,7 @@ def parse_fine_tuned_lora_name(
is_bias whether the tensor is lora bias.
"""
# LoRA weight qualified name always starts with `base_model.model.`,
# LoRA weight qualified name usually starts with `base_model.model.`,
# so we remove the prefix `base_model.model.` to make the following
# mapping correctly.
if "base_model.model." in name:
......@@ -123,18 +123,23 @@ def parse_fine_tuned_lora_name(
# recover the prefix `base_model.model.`
name = "base_model.model." + name
# In some situations, we may not start with `base_model.model.`.
# If we don't (e.g., ibm-granite/granite-speech-3.3-8b),
# we should keep the prefix intact.
start_index = 2 if "base_model.model." in name else 0
parts = name.split(".")
if parts[-1] == "weight" and (parts[-2] == "lora_A"
or parts[-2] == "lora_B"):
new_name = ".".join(parts[2:-2])
new_name = ".".join(parts[start_index:-2])
return new_name, parts[-2] == "lora_A", False
if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
new_name = ".".join(parts[2:-1])
new_name = ".".join(parts[start_index:-1])
return new_name, parts[-1] == "lora_embedding_A", False
if parts[-1] == "bias":
new_name = ".".join(parts[2:-2])
new_name = ".".join(parts[start_index:-2])
return new_name, False, True
raise ValueError(f"{name} is unsupported LoRA weight")
......
......@@ -65,7 +65,7 @@ def maybe_backend_fallback(
fallback_or_error(
guided_params,
"xgrammar does not support advanced JSON schema features like "
"enums, patterns or numeric ranges.", "outlines")
"string length, item limits, or property bounds.", "outlines")
# xgrammar only supports GBNF grammars, so we must convert Lark.
# We must check if the grammar is likely Lark and if that
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment