Commit 96ae75ad authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.6.6.post1' into v0.6.6.post1-dev

parents f9f4a735 2339d59f
...@@ -65,7 +65,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -65,7 +65,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
@staticmethod @staticmethod
@functools.lru_cache @functools.lru_cache
def _log_prompt_logprob_unsupported_warning_once(): def _log_prompt_logprob_unsupported_warning_once():
# Reminder: Please update docs/source/usage/compatibility_matrix.rst # Reminder: Please update docs/source/usage/compatibility_matrix.md
# If the feature combo become valid # If the feature combo become valid
logger.warning( logger.warning(
"Prompt logprob is not supported by multi step workers. " "Prompt logprob is not supported by multi step workers. "
......
...@@ -21,7 +21,7 @@ from vllm.entrypoints.utils import with_cancellation ...@@ -21,7 +21,7 @@ from vllm.entrypoints.utils import with_cancellation
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, random_uuid from vllm.utils import FlexibleArgumentParser, random_uuid, set_ulimit
from vllm.version import __version__ as VLLM_VERSION from vllm.version import __version__ as VLLM_VERSION
logger = init_logger("vllm.entrypoints.api_server") logger = init_logger("vllm.entrypoints.api_server")
...@@ -119,6 +119,8 @@ async def run_server(args: Namespace, ...@@ -119,6 +119,8 @@ async def run_server(args: Namespace,
logger.info("vLLM API server version %s", VLLM_VERSION) logger.info("vLLM API server version %s", VLLM_VERSION)
logger.info("args: %s", args) logger.info("args: %s", args)
set_ulimit()
app = await init_app(args, llm_engine) app = await init_app(args, llm_engine)
assert engine is not None assert engine is not None
......
...@@ -115,7 +115,7 @@ class LLM: ...@@ -115,7 +115,7 @@ class LLM:
integer, it is used as the level of compilation optimization. If it integer, it is used as the level of compilation optimization. If it
is a dictionary, it can specify the full compilation configuration. is a dictionary, it can specify the full compilation configuration.
**kwargs: Arguments for :class:`~vllm.EngineArgs`. (See **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
:ref:`engine_args`) :ref:`engine-args`)
Note: Note:
This class is intended to be used for offline inference. For online This class is intended to be used for offline inference. For online
...@@ -233,7 +233,8 @@ class LLM: ...@@ -233,7 +233,8 @@ class LLM:
self.request_counter = Counter() self.request_counter = Counter()
def __del__(self): def __del__(self):
if self.llm_engine and hasattr(self.llm_engine, "shutdown"): if hasattr(self, 'llm_engine') and self.llm_engine and hasattr(
self.llm_engine, "shutdown"):
self.llm_engine.shutdown() self.llm_engine.shutdown()
@staticmethod @staticmethod
...@@ -258,6 +259,13 @@ class LLM: ...@@ -258,6 +259,13 @@ class LLM:
else: else:
tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer) tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer)
def get_default_sampling_params(self) -> SamplingParams:
diff_sampling_param = (
self.llm_engine.model_config.get_diff_sampling_param())
if diff_sampling_param:
return SamplingParams.from_optional(**diff_sampling_param)
return SamplingParams()
@overload @overload
def generate( def generate(
self, self,
...@@ -441,7 +449,7 @@ class LLM: ...@@ -441,7 +449,7 @@ class LLM:
if sampling_params is None: if sampling_params is None:
# Use default sampling params. # Use default sampling params.
sampling_params = SamplingParams() sampling_params = self.get_default_sampling_params()
self._validate_and_add_requests( self._validate_and_add_requests(
prompts=parsed_prompts, prompts=parsed_prompts,
......
...@@ -27,6 +27,7 @@ from typing_extensions import assert_never ...@@ -27,6 +27,7 @@ from typing_extensions import assert_never
import vllm.envs as envs import vllm.envs as envs
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine # type: ignore
from vllm.engine.multiprocessing.client import MQLLMEngineClient from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.engine.multiprocessing.engine import run_mp_engine from vllm.engine.multiprocessing.engine import run_mp_engine
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
...@@ -44,8 +45,11 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, ...@@ -44,8 +45,11 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DetokenizeRequest, DetokenizeRequest,
DetokenizeResponse, DetokenizeResponse,
EmbeddingRequest, EmbeddingRequest,
EmbeddingResponse, ErrorResponse, EmbeddingResponse,
EmbeddingResponseData,
ErrorResponse,
LoadLoraAdapterRequest, LoadLoraAdapterRequest,
PoolingRequest, PoolingResponse,
ScoreRequest, ScoreResponse, ScoreRequest, ScoreResponse,
TokenizeRequest, TokenizeRequest,
TokenizeResponse, TokenizeResponse,
...@@ -55,6 +59,7 @@ from vllm.entrypoints.openai.serving_chat import OpenAIServingChat ...@@ -55,6 +59,7 @@ from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
from vllm.entrypoints.openai.serving_score import OpenAIServingScores from vllm.entrypoints.openai.serving_score import OpenAIServingScores
from vllm.entrypoints.openai.serving_tokenization import ( from vllm.entrypoints.openai.serving_tokenization import (
OpenAIServingTokenization) OpenAIServingTokenization)
...@@ -63,14 +68,9 @@ from vllm.entrypoints.utils import with_cancellation ...@@ -63,14 +68,9 @@ from vllm.entrypoints.utils import with_cancellation
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import (FlexibleArgumentParser, get_open_zmq_ipc_path, from vllm.utils import (FlexibleArgumentParser, get_open_zmq_ipc_path,
is_valid_ipv6_address) is_valid_ipv6_address, set_ulimit)
from vllm.version import __version__ as VLLM_VERSION from vllm.version import __version__ as VLLM_VERSION
if envs.VLLM_USE_V1:
from vllm.v1.engine.async_llm import AsyncLLMEngine # type: ignore
else:
from vllm.engine.async_llm_engine import AsyncLLMEngine # type: ignore
TIMEOUT_KEEP_ALIVE = 5 # seconds TIMEOUT_KEEP_ALIVE = 5 # seconds
prometheus_multiproc_dir: tempfile.TemporaryDirectory prometheus_multiproc_dir: tempfile.TemporaryDirectory
...@@ -288,6 +288,10 @@ def completion(request: Request) -> Optional[OpenAIServingCompletion]: ...@@ -288,6 +288,10 @@ def completion(request: Request) -> Optional[OpenAIServingCompletion]:
return request.app.state.openai_serving_completion return request.app.state.openai_serving_completion
def pooling(request: Request) -> Optional[OpenAIServingPooling]:
return request.app.state.openai_serving_pooling
def embedding(request: Request) -> Optional[OpenAIServingEmbedding]: def embedding(request: Request) -> Optional[OpenAIServingEmbedding]:
return request.app.state.openai_serving_embedding return request.app.state.openai_serving_embedding
...@@ -399,10 +403,36 @@ async def create_completion(request: CompletionRequest, raw_request: Request): ...@@ -399,10 +403,36 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
async def create_embedding(request: EmbeddingRequest, raw_request: Request): async def create_embedding(request: EmbeddingRequest, raw_request: Request):
handler = embedding(raw_request) handler = embedding(raw_request)
if handler is None: if handler is None:
return base(raw_request).create_error_response( fallback_handler = pooling(raw_request)
message="The model does not support Embeddings API") if fallback_handler is None:
return base(raw_request).create_error_response(
message="The model does not support Embeddings API")
logger.warning(
"Embeddings API will become exclusive to embedding models "
"in a future release. To return the hidden states directly, "
"use the Pooling API (`/pooling`) instead.")
res = await fallback_handler.create_pooling(request, raw_request)
if isinstance(res, PoolingResponse):
generator = EmbeddingResponse(
id=res.id,
object=res.object,
created=res.created,
model=res.model,
data=[
EmbeddingResponseData(
index=d.index,
embedding=d.data, # type: ignore
) for d in res.data
],
usage=res.usage,
)
else:
generator = res
else:
generator = await handler.create_embedding(request, raw_request)
generator = await handler.create_embedding(request, raw_request)
if isinstance(generator, ErrorResponse): if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(), return JSONResponse(content=generator.model_dump(),
status_code=generator.code) status_code=generator.code)
...@@ -412,6 +442,24 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request): ...@@ -412,6 +442,24 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
assert_never(generator) assert_never(generator)
@router.post("/pooling")
@with_cancellation
async def create_pooling(request: PoolingRequest, raw_request: Request):
handler = pooling(raw_request)
if handler is None:
return base(raw_request).create_error_response(
message="The model does not support Pooling API")
generator = await handler.create_pooling(request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
elif isinstance(generator, PoolingResponse):
return JSONResponse(content=generator.model_dump())
assert_never(generator)
@router.post("/score") @router.post("/score")
@with_cancellation @with_cancellation
async def create_score(request: ScoreRequest, raw_request: Request): async def create_score(request: ScoreRequest, raw_request: Request):
...@@ -537,12 +585,18 @@ def build_app(args: Namespace) -> FastAPI: ...@@ -537,12 +585,18 @@ def build_app(args: Namespace) -> FastAPI:
status_code=401) status_code=401)
return await call_next(request) return await call_next(request)
@app.middleware("http") if args.enable_request_id_headers:
async def add_request_id(request: Request, call_next): logger.warning(
request_id = request.headers.get("X-Request-Id") or uuid.uuid4().hex "CAUTION: Enabling X-Request-Id headers in the API Server. "
response = await call_next(request) "This can harm performance at high QPS.")
response.headers["X-Request-Id"] = request_id
return response @app.middleware("http")
async def add_request_id(request: Request, call_next):
request_id = request.headers.get(
"X-Request-Id") or uuid.uuid4().hex
response = await call_next(request)
response.headers["X-Request-Id"] = request_id
return response
for middleware in args.middleware: for middleware in args.middleware:
module_path, object_name = middleware.rsplit(".", 1) module_path, object_name = middleware.rsplit(".", 1)
...@@ -609,7 +663,7 @@ def init_app_state( ...@@ -609,7 +663,7 @@ def init_app_state(
request_logger=request_logger, request_logger=request_logger,
return_tokens_as_token_ids=args.return_tokens_as_token_ids, return_tokens_as_token_ids=args.return_tokens_as_token_ids,
) if model_config.runner_type == "generate" else None ) if model_config.runner_type == "generate" else None
state.openai_serving_embedding = OpenAIServingEmbedding( state.openai_serving_pooling = OpenAIServingPooling(
engine_client, engine_client,
model_config, model_config,
base_model_paths, base_model_paths,
...@@ -617,13 +671,20 @@ def init_app_state( ...@@ -617,13 +671,20 @@ def init_app_state(
chat_template=resolved_chat_template, chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format, chat_template_content_format=args.chat_template_content_format,
) if model_config.runner_type == "pooling" else None ) if model_config.runner_type == "pooling" else None
state.openai_serving_embedding = OpenAIServingEmbedding(
engine_client,
model_config,
base_model_paths,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
) if model_config.task == "embed" else None
state.openai_serving_scores = OpenAIServingScores( state.openai_serving_scores = OpenAIServingScores(
engine_client, engine_client,
model_config, model_config,
base_model_paths, base_model_paths,
request_logger=request_logger request_logger=request_logger
) if (model_config.runner_type == "pooling" \ ) if model_config.task == "score" else None
and model_config.is_cross_encoder) else None
state.openai_serving_tokenization = OpenAIServingTokenization( state.openai_serving_tokenization = OpenAIServingTokenization(
engine_client, engine_client,
model_config, model_config,
...@@ -666,6 +727,10 @@ async def run_server(args, **uvicorn_kwargs) -> None: ...@@ -666,6 +727,10 @@ async def run_server(args, **uvicorn_kwargs) -> None:
sock_addr = (args.host or "", args.port) sock_addr = (args.host or "", args.port)
sock = create_server_socket(sock_addr) sock = create_server_socket(sock_addr)
# workaround to avoid footguns where uvicorn drops requests with too
# many concurrent requests active
set_ulimit()
def signal_handler(*_) -> None: def signal_handler(*_) -> None:
# Interrupt server on sigterm while initializing # Interrupt server on sigterm while initializing
raise KeyboardInterrupt("terminated") raise KeyboardInterrupt("terminated")
......
...@@ -196,7 +196,11 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: ...@@ -196,7 +196,11 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
action="store_true", action="store_true",
help="If specified, will run the OpenAI frontend server in the same " help="If specified, will run the OpenAI frontend server in the same "
"process as the model serving engine.") "process as the model serving engine.")
parser.add_argument(
"--enable-request-id-headers",
action="store_true",
help="If specified, API server will add X-Request-Id header to "
"responses. Caution: this hurts performance at high QPS.")
parser.add_argument( parser.add_argument(
"--enable-auto-tool-choice", "--enable-auto-tool-choice",
action="store_true", action="store_true",
......
...@@ -46,7 +46,15 @@ class OpenAIBaseModel(BaseModel): ...@@ -46,7 +46,15 @@ class OpenAIBaseModel(BaseModel):
@classmethod @classmethod
def __log_extra_fields__(cls, data): def __log_extra_fields__(cls, data):
if isinstance(data, dict): if isinstance(data, dict):
extra_fields = data.keys() - cls.model_fields.keys() # Get all class field names and their potential aliases
field_names = set()
for field_name, field in cls.model_fields.items():
field_names.add(field_name)
if hasattr(field, 'alias') and field.alias:
field_names.add(field.alias)
# Compare against both field names and aliases
extra_fields = data.keys() - field_names
if extra_fields: if extra_fields:
logger.warning( logger.warning(
"The following fields were present in the request " "The following fields were present in the request "
...@@ -211,8 +219,8 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -211,8 +219,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
stop: Optional[Union[str, List[str]]] = Field(default_factory=list) stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False stream: Optional[bool] = False
stream_options: Optional[StreamOptions] = None stream_options: Optional[StreamOptions] = None
temperature: Optional[float] = 1.0 temperature: Optional[float] = None
top_p: Optional[float] = 1.0 top_p: Optional[float] = None
tools: Optional[List[ChatCompletionToolsParam]] = None tools: Optional[List[ChatCompletionToolsParam]] = None
tool_choice: Optional[Union[Literal["none"], Literal["auto"], tool_choice: Optional[Union[Literal["none"], Literal["auto"],
ChatCompletionNamedToolChoiceParam]] = "none" ChatCompletionNamedToolChoiceParam]] = "none"
...@@ -224,9 +232,9 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -224,9 +232,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
# doc: begin-chat-completion-sampling-params # doc: begin-chat-completion-sampling-params
best_of: Optional[int] = None best_of: Optional[int] = None
use_beam_search: bool = False use_beam_search: bool = False
top_k: int = -1 top_k: Optional[int] = None
min_p: float = 0.0 min_p: Optional[float] = None
repetition_penalty: float = 1.0 repetition_penalty: Optional[float] = None
length_penalty: float = 1.0 length_penalty: float = 1.0
stop_token_ids: Optional[List[int]] = Field(default_factory=list) stop_token_ids: Optional[List[int]] = Field(default_factory=list)
include_stop_str_in_output: bool = False include_stop_str_in_output: bool = False
...@@ -348,15 +356,32 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -348,15 +356,32 @@ class ChatCompletionRequest(OpenAIBaseModel):
# doc: end-chat-completion-extra-params # doc: end-chat-completion-extra-params
def to_beam_search_params(self, # Default sampling parameters for chat completion requests
default_max_tokens: int) -> BeamSearchParams: _DEFAULT_SAMPLING_PARAMS: dict = {
"repetition_penalty": 1.0,
"temperature": 1.0,
"top_p": 1.0,
"top_k": -1,
"min_p": 0.0,
}
def to_beam_search_params(
self,
default_max_tokens: int,
default_sampling_params: Optional[dict] = None
) -> BeamSearchParams:
# TODO(#9845): remove max_tokens when field is removed from OpenAI API # TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens = self.max_completion_tokens or self.max_tokens max_tokens = self.max_completion_tokens or self.max_tokens
if max_tokens is None: if max_tokens is None:
max_tokens = default_max_tokens max_tokens = default_max_tokens
if default_sampling_params is None:
default_sampling_params = {}
n = self.n if self.n is not None else 1 n = self.n if self.n is not None else 1
temperature = self.temperature if self.temperature is not None else 0.0
if (temperature := self.temperature) is None:
temperature = default_sampling_params.get(
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
return BeamSearchParams( return BeamSearchParams(
beam_width=n, beam_width=n,
...@@ -367,13 +392,36 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -367,13 +392,36 @@ class ChatCompletionRequest(OpenAIBaseModel):
include_stop_str_in_output=self.include_stop_str_in_output) include_stop_str_in_output=self.include_stop_str_in_output)
def to_sampling_params( def to_sampling_params(
self, default_max_tokens: int, self,
logits_processor_pattern: Optional[str]) -> SamplingParams: default_max_tokens: int,
logits_processor_pattern: Optional[str],
default_sampling_params: Optional[dict] = None) -> SamplingParams:
# TODO(#9845): remove max_tokens when field is removed from OpenAI API # TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens = self.max_completion_tokens or self.max_tokens max_tokens = self.max_completion_tokens or self.max_tokens
if max_tokens is None: if max_tokens is None:
max_tokens = default_max_tokens max_tokens = default_max_tokens
if default_sampling_params is None:
default_sampling_params = {}
# Default parameters
if (repetition_penalty := self.repetition_penalty) is None:
repetition_penalty = default_sampling_params.get(
"repetition_penalty",
self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
)
if (temperature := self.temperature) is None:
temperature = default_sampling_params.get(
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
if (top_p := self.top_p) is None:
top_p = default_sampling_params.get(
"top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"])
if (top_k := self.top_k) is None:
top_k = default_sampling_params.get(
"top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"])
if (min_p := self.min_p) is None:
min_p = default_sampling_params.get(
"min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"])
prompt_logprobs = self.prompt_logprobs prompt_logprobs = self.prompt_logprobs
if prompt_logprobs is None and self.echo: if prompt_logprobs is None and self.echo:
prompt_logprobs = self.top_logprobs prompt_logprobs = self.top_logprobs
...@@ -403,11 +451,11 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -403,11 +451,11 @@ class ChatCompletionRequest(OpenAIBaseModel):
best_of=self.best_of, best_of=self.best_of,
presence_penalty=self.presence_penalty, presence_penalty=self.presence_penalty,
frequency_penalty=self.frequency_penalty, frequency_penalty=self.frequency_penalty,
repetition_penalty=self.repetition_penalty, repetition_penalty=repetition_penalty,
temperature=self.temperature, temperature=temperature,
top_p=self.top_p, top_p=top_p,
top_k=self.top_k, top_k=top_k,
min_p=self.min_p, min_p=min_p,
seed=self.seed, seed=self.seed,
stop=self.stop, stop=self.stop,
stop_token_ids=self.stop_token_ids, stop_token_ids=self.stop_token_ids,
...@@ -584,15 +632,15 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -584,15 +632,15 @@ class CompletionRequest(OpenAIBaseModel):
stream: Optional[bool] = False stream: Optional[bool] = False
stream_options: Optional[StreamOptions] = None stream_options: Optional[StreamOptions] = None
suffix: Optional[str] = None suffix: Optional[str] = None
temperature: Optional[float] = 1.0 temperature: Optional[float] = None
top_p: Optional[float] = 1.0 top_p: Optional[float] = None
user: Optional[str] = None user: Optional[str] = None
# doc: begin-completion-sampling-params # doc: begin-completion-sampling-params
use_beam_search: bool = False use_beam_search: bool = False
top_k: int = -1 top_k: Optional[int] = None
min_p: float = 0.0 min_p: Optional[float] = None
repetition_penalty: float = 1.0 repetition_penalty: Optional[float] = None
length_penalty: float = 1.0 length_penalty: float = 1.0
stop_token_ids: Optional[List[int]] = Field(default_factory=list) stop_token_ids: Optional[List[int]] = Field(default_factory=list)
include_stop_str_in_output: bool = False include_stop_str_in_output: bool = False
...@@ -669,14 +717,30 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -669,14 +717,30 @@ class CompletionRequest(OpenAIBaseModel):
# doc: end-completion-extra-params # doc: end-completion-extra-params
def to_beam_search_params(self, # Default sampling parameters for completion requests
default_max_tokens: int) -> BeamSearchParams: _DEFAULT_SAMPLING_PARAMS: dict = {
"repetition_penalty": 1.0,
"temperature": 1.0,
"top_p": 1.0,
"top_k": -1,
"min_p": 0.0,
}
def to_beam_search_params(
self,
default_max_tokens: int,
default_sampling_params: Optional[dict] = None
) -> BeamSearchParams:
max_tokens = self.max_tokens max_tokens = self.max_tokens
if max_tokens is None: if max_tokens is None:
max_tokens = default_max_tokens max_tokens = default_max_tokens
if default_sampling_params is None:
default_sampling_params = {}
n = self.n if self.n is not None else 1 n = self.n if self.n is not None else 1
temperature = self.temperature if self.temperature is not None else 0.0
if (temperature := self.temperature) is None:
temperature = default_sampling_params.get("temperature", 1.0)
return BeamSearchParams( return BeamSearchParams(
beam_width=n, beam_width=n,
...@@ -687,12 +751,35 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -687,12 +751,35 @@ class CompletionRequest(OpenAIBaseModel):
include_stop_str_in_output=self.include_stop_str_in_output) include_stop_str_in_output=self.include_stop_str_in_output)
def to_sampling_params( def to_sampling_params(
self, default_max_tokens: int, self,
logits_processor_pattern: Optional[str]) -> SamplingParams: default_max_tokens: int,
logits_processor_pattern: Optional[str],
default_sampling_params: Optional[dict] = None) -> SamplingParams:
max_tokens = self.max_tokens max_tokens = self.max_tokens
if max_tokens is None: if max_tokens is None:
max_tokens = default_max_tokens max_tokens = default_max_tokens
if default_sampling_params is None:
default_sampling_params = {}
# Default parameters
if (repetition_penalty := self.repetition_penalty) is None:
repetition_penalty = default_sampling_params.get(
"repetition_penalty",
self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
)
if (temperature := self.temperature) is None:
temperature = default_sampling_params.get(
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
if (top_p := self.top_p) is None:
top_p = default_sampling_params.get(
"top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"])
if (top_k := self.top_k) is None:
top_k = default_sampling_params.get(
"top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"])
if (min_p := self.min_p) is None:
min_p = default_sampling_params.get(
"min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"])
prompt_logprobs = self.prompt_logprobs prompt_logprobs = self.prompt_logprobs
if prompt_logprobs is None and self.echo: if prompt_logprobs is None and self.echo:
prompt_logprobs = self.logprobs prompt_logprobs = self.logprobs
...@@ -718,11 +805,11 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -718,11 +805,11 @@ class CompletionRequest(OpenAIBaseModel):
best_of=self.best_of, best_of=self.best_of,
presence_penalty=self.presence_penalty, presence_penalty=self.presence_penalty,
frequency_penalty=self.frequency_penalty, frequency_penalty=self.frequency_penalty,
repetition_penalty=self.repetition_penalty, repetition_penalty=repetition_penalty,
temperature=self.temperature, temperature=temperature,
top_p=self.top_p, top_p=top_p,
top_k=self.top_k, top_k=top_k,
min_p=self.min_p, min_p=min_p,
seed=self.seed, seed=self.seed,
stop=self.stop, stop=self.stop,
stop_token_ids=self.stop_token_ids, stop_token_ids=self.stop_token_ids,
...@@ -876,6 +963,10 @@ class EmbeddingChatRequest(OpenAIBaseModel): ...@@ -876,6 +963,10 @@ class EmbeddingChatRequest(OpenAIBaseModel):
EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest] EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest]
PoolingCompletionRequest = EmbeddingCompletionRequest
PoolingChatRequest = EmbeddingChatRequest
PoolingRequest = Union[PoolingCompletionRequest, PoolingChatRequest]
class ScoreRequest(OpenAIBaseModel): class ScoreRequest(OpenAIBaseModel):
model: str model: str
...@@ -971,6 +1062,21 @@ class EmbeddingResponse(OpenAIBaseModel): ...@@ -971,6 +1062,21 @@ class EmbeddingResponse(OpenAIBaseModel):
usage: UsageInfo usage: UsageInfo
class PoolingResponseData(OpenAIBaseModel):
index: int
object: str = "pooling"
data: Union[List[List[float]], List[float], str]
class PoolingResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"pool-{random_uuid()}")
object: str = "list"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
data: List[PoolingResponseData]
usage: UsageInfo
class ScoreResponseData(OpenAIBaseModel): class ScoreResponseData(OpenAIBaseModel):
index: int index: int
object: str = "score" object: str = "score"
......
...@@ -232,7 +232,7 @@ async def main(args): ...@@ -232,7 +232,7 @@ async def main(args):
request_logger=request_logger, request_logger=request_logger,
chat_template=None, chat_template=None,
chat_template_content_format="auto", chat_template_content_format="auto",
) if model_config.runner_type == "pooling" else None ) if model_config.task == "embed" else None
tracker = BatchProgressTracker() tracker = BatchProgressTracker()
logger.info("Reading batch from %s...", args.input_file) logger.info("Reading batch from %s...", args.input_file)
......
...@@ -91,6 +91,10 @@ class OpenAIServingChat(OpenAIServing): ...@@ -91,6 +91,10 @@ class OpenAIServingChat(OpenAIServing):
"been registered") from e "been registered") from e
self.enable_prompt_tokens_details = enable_prompt_tokens_details self.enable_prompt_tokens_details = enable_prompt_tokens_details
diff_sampling_param = self.model_config.get_diff_sampling_param()
if diff_sampling_param:
logger.info("Overwriting default chat sampling param with: %s",
diff_sampling_param)
async def create_chat_completion( async def create_chat_completion(
self, self,
...@@ -191,13 +195,17 @@ class OpenAIServingChat(OpenAIServing): ...@@ -191,13 +195,17 @@ class OpenAIServingChat(OpenAIServing):
sampling_params: Union[SamplingParams, BeamSearchParams] sampling_params: Union[SamplingParams, BeamSearchParams]
default_max_tokens = self.max_model_len - len( default_max_tokens = self.max_model_len - len(
engine_prompt["prompt_token_ids"]) engine_prompt["prompt_token_ids"])
# Build default sampling params
default_sampling_params = (
self.model_config.get_diff_sampling_param())
if request.use_beam_search: if request.use_beam_search:
sampling_params = request.to_beam_search_params( sampling_params = request.to_beam_search_params(
default_max_tokens) default_max_tokens, default_sampling_params)
else: else:
sampling_params = request.to_sampling_params( sampling_params = request.to_sampling_params(
default_max_tokens, default_max_tokens,
self.model_config.logits_processor_pattern) self.model_config.logits_processor_pattern,
default_sampling_params)
self._log_inputs(request_id, self._log_inputs(request_id,
request_prompts[i], request_prompts[i],
......
...@@ -55,6 +55,11 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -55,6 +55,11 @@ class OpenAIServingCompletion(OpenAIServing):
prompt_adapters=prompt_adapters, prompt_adapters=prompt_adapters,
request_logger=request_logger, request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids) return_tokens_as_token_ids=return_tokens_as_token_ids)
diff_sampling_param = self.model_config.get_diff_sampling_param()
if diff_sampling_param:
logger.info(
"Overwriting default completion sampling param with: %s",
diff_sampling_param)
async def create_completion( async def create_completion(
self, self,
...@@ -118,13 +123,17 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -118,13 +123,17 @@ class OpenAIServingCompletion(OpenAIServing):
sampling_params: Union[SamplingParams, BeamSearchParams] sampling_params: Union[SamplingParams, BeamSearchParams]
default_max_tokens = self.max_model_len - len( default_max_tokens = self.max_model_len - len(
engine_prompt["prompt_token_ids"]) engine_prompt["prompt_token_ids"])
# Build default sampling params
default_sampling_params = (
self.model_config.get_diff_sampling_param())
if request.use_beam_search: if request.use_beam_search:
sampling_params = request.to_beam_search_params( sampling_params = request.to_beam_search_params(
default_max_tokens) default_max_tokens, default_sampling_params)
else: else:
sampling_params = request.to_sampling_params( sampling_params = request.to_sampling_params(
default_max_tokens, default_max_tokens,
self.model_config.logits_processor_pattern) self.model_config.logits_processor_pattern,
default_sampling_params)
request_id_item = f"{request_id}-{i}" request_id_item = f"{request_id}-{i}"
......
...@@ -40,36 +40,6 @@ def _get_embedding( ...@@ -40,36 +40,6 @@ def _get_embedding(
assert_never(encoding_format) assert_never(encoding_format)
def request_output_to_embedding_response(
final_res_batch: List[PoolingRequestOutput], request_id: str,
created_time: int, model_name: str,
encoding_format: Literal["float", "base64"]) -> EmbeddingResponse:
data: List[EmbeddingResponseData] = []
num_prompt_tokens = 0
for idx, final_res in enumerate(final_res_batch):
embedding_res = EmbeddingRequestOutput.from_base(final_res)
prompt_token_ids = final_res.prompt_token_ids
embedding = _get_embedding(embedding_res.outputs, encoding_format)
embedding_data = EmbeddingResponseData(index=idx, embedding=embedding)
data.append(embedding_data)
num_prompt_tokens += len(prompt_token_ids)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
total_tokens=num_prompt_tokens,
)
return EmbeddingResponse(
id=request_id,
created=created_time,
model=model_name,
data=data,
usage=usage,
)
class OpenAIServingEmbedding(OpenAIServing): class OpenAIServingEmbedding(OpenAIServing):
def __init__( def __init__(
...@@ -114,7 +84,7 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -114,7 +84,7 @@ class OpenAIServingEmbedding(OpenAIServing):
model_name = request.model model_name = request.model
request_id = f"embd-{self._base_request_id(raw_request)}" request_id = f"embd-{self._base_request_id(raw_request)}"
created_time = int(time.monotonic()) created_time = int(time.time())
truncate_prompt_tokens = None truncate_prompt_tokens = None
...@@ -218,9 +188,13 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -218,9 +188,13 @@ class OpenAIServingEmbedding(OpenAIServing):
final_res_batch_checked = cast(List[PoolingRequestOutput], final_res_batch_checked = cast(List[PoolingRequestOutput],
final_res_batch) final_res_batch)
response = request_output_to_embedding_response( response = self.request_output_to_embedding_response(
final_res_batch_checked, request_id, created_time, model_name, final_res_batch_checked,
encoding_format) request_id,
created_time,
model_name,
encoding_format,
)
except asyncio.CancelledError: except asyncio.CancelledError:
return self.create_error_response("Client disconnected") return self.create_error_response("Client disconnected")
except ValueError as e: except ValueError as e:
...@@ -228,3 +202,40 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -228,3 +202,40 @@ class OpenAIServingEmbedding(OpenAIServing):
return self.create_error_response(str(e)) return self.create_error_response(str(e))
return response return response
def request_output_to_embedding_response(
self,
final_res_batch: List[PoolingRequestOutput],
request_id: str,
created_time: int,
model_name: str,
encoding_format: Literal["float", "base64"],
) -> EmbeddingResponse:
items: List[EmbeddingResponseData] = []
num_prompt_tokens = 0
for idx, final_res in enumerate(final_res_batch):
embedding_res = EmbeddingRequestOutput.from_base(final_res)
item = EmbeddingResponseData(
index=idx,
embedding=_get_embedding(embedding_res.outputs,
encoding_format),
)
prompt_token_ids = final_res.prompt_token_ids
items.append(item)
num_prompt_tokens += len(prompt_token_ids)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
total_tokens=num_prompt_tokens,
)
return EmbeddingResponse(
id=request_id,
created=created_time,
model=model_name,
data=items,
usage=usage,
)
import asyncio
import base64
import time
from typing import AsyncGenerator, Final, List, Literal, Optional, Union, cast
import numpy as np
from fastapi import Request
from typing_extensions import assert_never
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (ErrorResponse,
PoolingChatRequest,
PoolingRequest, PoolingResponse,
PoolingResponseData, UsageInfo)
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
from vllm.logger import init_logger
from vllm.outputs import PoolingOutput, PoolingRequestOutput
from vllm.utils import merge_async_iterators
logger = init_logger(__name__)
def _get_data(
output: PoolingOutput,
encoding_format: Literal["float", "base64"],
) -> Union[List[float], str]:
if encoding_format == "float":
return output.data.tolist()
elif encoding_format == "base64":
# Force to use float32 for base64 encoding
# to match the OpenAI python client behavior
pooling_bytes = np.array(output.data, dtype="float32").tobytes()
return base64.b64encode(pooling_bytes).decode("utf-8")
assert_never(encoding_format)
class OpenAIServingPooling(OpenAIServing):
def __init__(
self,
engine_client: EngineClient,
model_config: ModelConfig,
base_model_paths: List[BaseModelPath],
*,
request_logger: Optional[RequestLogger],
chat_template: Optional[str],
chat_template_content_format: ChatTemplateContentFormatOption,
) -> None:
super().__init__(engine_client=engine_client,
model_config=model_config,
base_model_paths=base_model_paths,
lora_modules=None,
prompt_adapters=None,
request_logger=request_logger)
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
async def create_pooling(
self,
request: PoolingRequest,
raw_request: Optional[Request] = None,
) -> Union[PoolingResponse, ErrorResponse]:
"""
See https://platform.openai.com/docs/api-reference/embeddings/create
for the API specification. This API mimics the OpenAI Embedding API.
"""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
encoding_format = request.encoding_format
if request.dimensions is not None:
return self.create_error_response(
"dimensions is currently not supported")
model_name = request.model
request_id = f"pool-{self._base_request_id(raw_request)}"
created_time = int(time.time())
truncate_prompt_tokens = None
if request.truncate_prompt_tokens is not None:
if request.truncate_prompt_tokens <= self.max_model_len:
truncate_prompt_tokens = request.truncate_prompt_tokens
else:
return self.create_error_response(
"truncate_prompt_tokens value is "
"greater than max_model_len."
" Please, select a smaller truncation size.")
try:
(
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer(lora_request)
if prompt_adapter_request is not None:
raise NotImplementedError("Prompt adapter is not supported "
"for pooling models")
if isinstance(request, PoolingChatRequest):
(
_,
request_prompts,
engine_prompts,
) = await self._preprocess_chat(
request,
tokenizer,
request.messages,
chat_template=request.chat_template or self.chat_template,
chat_template_content_format=self.
chat_template_content_format,
# In pooling requests, we are not generating tokens,
# so there is no need to append extra tokens to the input
add_generation_prompt=False,
continue_final_message=False,
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
else:
(request_prompts,
engine_prompts) = await self._preprocess_completion(
request,
tokenizer,
request.input,
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
except ValueError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
# Schedule the request and get the result generator.
generators: List[AsyncGenerator[PoolingRequestOutput, None]] = []
try:
pooling_params = request.to_pooling_params()
for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}-{i}"
self._log_inputs(request_id_item,
request_prompts[i],
params=pooling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
trace_headers = (None if raw_request is None else await
self._get_trace_headers(raw_request.headers))
generator = self.engine_client.encode(
engine_prompt,
pooling_params,
request_id_item,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
)
generators.append(generator)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
result_generator = merge_async_iterators(*generators)
num_prompts = len(engine_prompts)
# Non-streaming response
final_res_batch: List[Optional[PoolingRequestOutput]]
final_res_batch = [None] * num_prompts
try:
async for i, res in result_generator:
final_res_batch[i] = res
assert all(final_res is not None for final_res in final_res_batch)
final_res_batch_checked = cast(List[PoolingRequestOutput],
final_res_batch)
response = self.request_output_to_pooling_response(
final_res_batch_checked,
request_id,
created_time,
model_name,
encoding_format,
)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
return response
def request_output_to_pooling_response(
self,
final_res_batch: List[PoolingRequestOutput],
request_id: str,
created_time: int,
model_name: str,
encoding_format: Literal["float", "base64"],
) -> PoolingResponse:
items: List[PoolingResponseData] = []
num_prompt_tokens = 0
for idx, final_res in enumerate(final_res_batch):
item = PoolingResponseData(
index=idx,
data=_get_data(final_res.outputs, encoding_format),
)
prompt_token_ids = final_res.prompt_token_ids
items.append(item)
num_prompt_tokens += len(prompt_token_ids)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
total_tokens=num_prompt_tokens,
)
return PoolingResponse(
id=request_id,
created=created_time,
model=model_name,
data=items,
usage=usage,
)
...@@ -20,32 +20,6 @@ from vllm.utils import make_async, merge_async_iterators ...@@ -20,32 +20,6 @@ from vllm.utils import make_async, merge_async_iterators
logger = init_logger(__name__) logger = init_logger(__name__)
def request_output_to_score_response(
final_res_batch: List[PoolingRequestOutput], request_id: str,
created_time: int, model_name: str) -> ScoreResponse:
data: List[ScoreResponseData] = []
num_prompt_tokens = 0
for idx, final_res in enumerate(final_res_batch):
classify_res = ScoringRequestOutput.from_base(final_res)
score_data = ScoreResponseData(index=idx,
score=classify_res.outputs.score)
data.append(score_data)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
total_tokens=num_prompt_tokens,
)
return ScoreResponse(
id=request_id,
created=created_time,
model=model_name,
data=data,
usage=usage,
)
def make_pairs(text_1: Union[List[str], str], text_2: Union[List[str], def make_pairs(text_1: Union[List[str], str], text_2: Union[List[str],
str]) -> List: str]) -> List:
if isinstance(text_1, (str, dict)): if isinstance(text_1, (str, dict)):
...@@ -103,7 +77,7 @@ class OpenAIServingScores(OpenAIServing): ...@@ -103,7 +77,7 @@ class OpenAIServingScores(OpenAIServing):
model_name = request.model model_name = request.model
request_id = f"score-{self._base_request_id(raw_request)}" request_id = f"score-{self._base_request_id(raw_request)}"
created_time = int(time.monotonic()) created_time = int(time.time())
truncate_prompt_tokens = request.truncate_prompt_tokens truncate_prompt_tokens = request.truncate_prompt_tokens
request_prompts = [] request_prompts = []
...@@ -203,8 +177,12 @@ class OpenAIServingScores(OpenAIServing): ...@@ -203,8 +177,12 @@ class OpenAIServingScores(OpenAIServing):
final_res_batch_checked = cast(List[PoolingRequestOutput], final_res_batch_checked = cast(List[PoolingRequestOutput],
final_res_batch) final_res_batch)
response = request_output_to_score_response( response = self.request_output_to_score_response(
final_res_batch_checked, request_id, created_time, model_name) final_res_batch_checked,
request_id,
created_time,
model_name,
)
except asyncio.CancelledError: except asyncio.CancelledError:
return self.create_error_response("Client disconnected") return self.create_error_response("Client disconnected")
except ValueError as e: except ValueError as e:
...@@ -212,3 +190,38 @@ class OpenAIServingScores(OpenAIServing): ...@@ -212,3 +190,38 @@ class OpenAIServingScores(OpenAIServing):
return self.create_error_response(str(e)) return self.create_error_response(str(e))
return response return response
def request_output_to_score_response(
self,
final_res_batch: List[PoolingRequestOutput],
request_id: str,
created_time: int,
model_name: str,
) -> ScoreResponse:
items: List[ScoreResponseData] = []
num_prompt_tokens = 0
for idx, final_res in enumerate(final_res_batch):
classify_res = ScoringRequestOutput.from_base(final_res)
item = ScoreResponseData(
index=idx,
score=classify_res.outputs.score,
)
prompt_token_ids = final_res.prompt_token_ids
items.append(item)
num_prompt_tokens += len(prompt_token_ids)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
total_tokens=num_prompt_tokens,
)
return ScoreResponse(
id=request_id,
created=created_time,
model=model_name,
data=items,
usage=usage,
)
...@@ -35,13 +35,18 @@ class GraniteToolParser(ToolParser): ...@@ -35,13 +35,18 @@ class GraniteToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer): def __init__(self, tokenizer: AnyTokenizer):
super().__init__(tokenizer) super().__init__(tokenizer)
# for granite 3.0, the token `<|tool_call|>`
self.bot_token = "<|tool_call|>" self.bot_token = "<|tool_call|>"
# for granite 3.1, the string `<tool_call>`
self.bot_string = "<tool_call>"
def extract_tool_calls( def extract_tool_calls(
self, model_output: str, self, model_output: str,
request: ChatCompletionRequest) -> ExtractedToolCallInformation: request: ChatCompletionRequest) -> ExtractedToolCallInformation:
# remove whitespace and the BOT token if it exists stripped = model_output.strip()\
stripped = model_output.strip().removeprefix(self.bot_token).lstrip() .removeprefix(self.bot_token)\
.removeprefix(self.bot_string)\
.lstrip()
if not stripped or stripped[0] != '[': if not stripped or stripped[0] != '[':
return ExtractedToolCallInformation(tools_called=False, return ExtractedToolCallInformation(tools_called=False,
tool_calls=[], tool_calls=[],
...@@ -91,6 +96,9 @@ class GraniteToolParser(ToolParser): ...@@ -91,6 +96,9 @@ class GraniteToolParser(ToolParser):
if current_text[start_idx:].startswith(self.bot_token): if current_text[start_idx:].startswith(self.bot_token):
start_idx = consume_space(start_idx + len(self.bot_token), start_idx = consume_space(start_idx + len(self.bot_token),
current_text) current_text)
if current_text[start_idx:].startswith(self.bot_string):
start_idx = consume_space(start_idx + len(self.bot_string),
current_text)
if not current_text or start_idx >= len(current_text)\ if not current_text or start_idx >= len(current_text)\
or current_text[start_idx] != '[': or current_text[start_idx] != '[':
return DeltaMessage(content=delta_text) return DeltaMessage(content=delta_text)
......
...@@ -35,7 +35,7 @@ if TYPE_CHECKING: ...@@ -35,7 +35,7 @@ if TYPE_CHECKING:
VLLM_LOGGING_CONFIG_PATH: Optional[str] = None VLLM_LOGGING_CONFIG_PATH: Optional[str] = None
VLLM_TRACE_FUNCTION: int = 0 VLLM_TRACE_FUNCTION: int = 0
VLLM_ATTENTION_BACKEND: Optional[str] = None VLLM_ATTENTION_BACKEND: Optional[str] = None
VLLM_USE_FLASHINFER_SAMPLER: bool = False VLLM_USE_FLASHINFER_SAMPLER: Optional[bool] = None
VLLM_USE_FLASHINFER_REJECTION_SAMPLER: bool = False VLLM_USE_FLASHINFER_REJECTION_SAMPLER: bool = False
VLLM_FLASHINFER_FORCE_TENSOR_CORES: bool = False VLLM_FLASHINFER_FORCE_TENSOR_CORES: bool = False
VLLM_PP_LAYER_PARTITION: Optional[str] = None VLLM_PP_LAYER_PARTITION: Optional[str] = None
...@@ -308,7 +308,8 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -308,7 +308,8 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# If set, vllm will use flashinfer sampler # If set, vllm will use flashinfer sampler
"VLLM_USE_FLASHINFER_SAMPLER": "VLLM_USE_FLASHINFER_SAMPLER":
lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_SAMPLER", "0"))), lambda: bool(int(os.environ["VLLM_USE_FLASHINFER_SAMPLER"]))
if "VLLM_USE_FLASHINFER_SAMPLER" in os.environ else None,
# If set, vllm will force flashinfer to use tensor cores; # If set, vllm will force flashinfer to use tensor cores;
# otherwise will use heuristic based on model architecture. # otherwise will use heuristic based on model architecture.
......
...@@ -22,7 +22,7 @@ class CPUExecutor(ExecutorBase): ...@@ -22,7 +22,7 @@ class CPUExecutor(ExecutorBase):
def _init_executor(self) -> None: def _init_executor(self) -> None:
assert self.device_config.device_type == "cpu" assert self.device_config.device_type == "cpu"
# Reminder: Please update docs/source/usage/compatibility_matrix.rst # Reminder: Please update docs/source/usage/compatibility_matrix.md
# If the feature combo become valid # If the feature combo become valid
assert self.lora_config is None, "cpu backend doesn't support LoRA" assert self.lora_config is None, "cpu backend doesn't support LoRA"
......
...@@ -123,6 +123,7 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -123,6 +123,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
# Create the workers. # Create the workers.
driver_ip = get_ip() driver_ip = get_ip()
workers = []
for bundle_id, bundle in enumerate(placement_group.bundle_specs): for bundle_id, bundle in enumerate(placement_group.bundle_specs):
if not bundle.get("GPU", 0): if not bundle.get("GPU", 0):
continue continue
...@@ -138,20 +139,30 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -138,20 +139,30 @@ class RayGPUExecutor(DistributedGPUExecutor):
scheduling_strategy=scheduling_strategy, scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs, **ray_remote_kwargs,
)(RayWorkerWrapper).remote(vllm_config=self.vllm_config) )(RayWorkerWrapper).remote(vllm_config=self.vllm_config)
workers.append(worker)
if self.use_ray_spmd_worker: worker_ip_refs = [
self.workers.append(worker) worker.get_node_ip.remote() # type: ignore[attr-defined]
else: for worker in workers
worker_ip = ray.get(worker.get_node_ip.remote()) ]
if worker_ip == driver_ip and self.driver_dummy_worker is None: worker_ips = ray.get(worker_ip_refs)
if not self.use_ray_spmd_worker:
for i in range(len(workers)):
worker = workers[i]
worker_ip = worker_ips[i]
if self.driver_dummy_worker is None and worker_ip == driver_ip:
# If the worker is on the same node as the driver, we use it # If the worker is on the same node as the driver, we use it
# as the resource holder for the driver process. # as the resource holder for the driver process.
self.driver_dummy_worker = worker self.driver_dummy_worker = worker
self.driver_worker = RayWorkerWrapper( self.driver_worker = RayWorkerWrapper(
vllm_config=self.vllm_config) vllm_config=self.vllm_config)
else: workers.pop(i)
# Else, added to the list of workers. worker_ips.pop(i)
self.workers.append(worker) self.workers = workers
break
else:
self.workers = workers
logger.debug("workers: %s", self.workers) logger.debug("workers: %s", self.workers)
logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker) logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker)
...@@ -161,14 +172,12 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -161,14 +172,12 @@ class RayGPUExecutor(DistributedGPUExecutor):
"adjusting the Ray placement group or running the driver on a " "adjusting the Ray placement group or running the driver on a "
"GPU node.") "GPU node.")
worker_ips = [
ray.get(worker.get_node_ip.remote()) # type: ignore[attr-defined]
for worker in self.workers
]
ip_counts: Dict[str, int] = {} ip_counts: Dict[str, int] = {}
for ip in worker_ips: for ip in worker_ips:
ip_counts[ip] = ip_counts.get(ip, 0) + 1 ip_counts[ip] = ip_counts.get(ip, 0) + 1
worker_to_ip = dict(zip(self.workers, worker_ips))
def sort_by_driver_then_worker_ip(worker): def sort_by_driver_then_worker_ip(worker):
""" """
Sort the workers based on 3 properties: Sort the workers based on 3 properties:
...@@ -179,7 +188,7 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -179,7 +188,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
3. Finally, if the work is on a node with smaller IP address, it 3. Finally, if the work is on a node with smaller IP address, it
should be placed first. should be placed first.
""" """
ip = ray.get(worker.get_node_ip.remote()) ip = worker_to_ip[worker]
return (ip != driver_ip, ip_counts[ip], ip) return (ip != driver_ip, ip_counts[ip], ip)
# After sorting, the workers on the same node will be # After sorting, the workers on the same node will be
......
...@@ -13,7 +13,7 @@ The global :class:`~InputRegistry` which is used by :class:`~vllm.LLMEngine` ...@@ -13,7 +13,7 @@ The global :class:`~InputRegistry` which is used by :class:`~vllm.LLMEngine`
to dispatch data processing according to the target model. to dispatch data processing according to the target model.
See also: See also:
:ref:`input_processing_pipeline` :ref:`input-processing-pipeline`
""" """
__all__ = [ __all__ = [
......
...@@ -162,6 +162,11 @@ class TokenInputs(TypedDict): ...@@ -162,6 +162,11 @@ class TokenInputs(TypedDict):
Placeholder ranges for the multi-modal data. Placeholder ranges for the multi-modal data.
""" """
multi_modal_hashes: NotRequired[List[str]]
"""
The hashes of the multi-modal data.
"""
mm_processor_kwargs: NotRequired[Dict[str, Any]] mm_processor_kwargs: NotRequired[Dict[str, Any]]
""" """
Optional multi-modal processor kwargs to be forwarded to the Optional multi-modal processor kwargs to be forwarded to the
...@@ -177,6 +182,7 @@ def token_inputs( ...@@ -177,6 +182,7 @@ def token_inputs(
prompt: Optional[str] = None, prompt: Optional[str] = None,
multi_modal_data: Optional["MultiModalDataDict"] = None, multi_modal_data: Optional["MultiModalDataDict"] = None,
multi_modal_inputs: Optional["MultiModalKwargs"] = None, multi_modal_inputs: Optional["MultiModalKwargs"] = None,
multi_modal_hashes: Optional[List[str]] = None,
multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None, multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None, mm_processor_kwargs: Optional[Dict[str, Any]] = None,
) -> TokenInputs: ) -> TokenInputs:
...@@ -191,6 +197,8 @@ def token_inputs( ...@@ -191,6 +197,8 @@ def token_inputs(
inputs["multi_modal_data"] = multi_modal_data inputs["multi_modal_data"] = multi_modal_data
if multi_modal_inputs is not None: if multi_modal_inputs is not None:
inputs["multi_modal_inputs"] = multi_modal_inputs inputs["multi_modal_inputs"] = multi_modal_inputs
if multi_modal_hashes is not None:
inputs["multi_modal_hashes"] = multi_modal_hashes
if multi_modal_placeholders is not None: if multi_modal_placeholders is not None:
inputs["multi_modal_placeholders"] = multi_modal_placeholders inputs["multi_modal_placeholders"] = multi_modal_placeholders
if mm_processor_kwargs is not None: if mm_processor_kwargs is not None:
...@@ -295,6 +303,18 @@ class SingletonInputsAdapter: ...@@ -295,6 +303,18 @@ class SingletonInputsAdapter:
assert_never(inputs) assert_never(inputs)
@cached_property
def multi_modal_hashes(self) -> List[str]:
inputs = self.inputs
if inputs["type"] == "token":
return inputs.get("multi_modal_hashes", [])
if inputs["type"] == "multimodal":
return inputs.get("mm_hashes", [])
assert_never(inputs)
@cached_property @cached_property
def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict": def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict":
inputs = self.inputs inputs = self.inputs
......
import functools import functools
from collections import UserDict from collections import UserDict
from dataclasses import dataclass from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, NamedTuple, from typing import (TYPE_CHECKING, Any, Callable, Mapping, NamedTuple,
Optional, Protocol, Type) Optional, Protocol, Union)
from torch import nn from torch import nn
from transformers import PretrainedConfig, ProcessorMixin from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
from typing_extensions import TypeVar, assert_never from typing_extensions import TypeVar, assert_never
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -26,6 +26,7 @@ if TYPE_CHECKING: ...@@ -26,6 +26,7 @@ if TYPE_CHECKING:
logger = init_logger(__name__) logger = init_logger(__name__)
C = TypeVar("C", bound=PretrainedConfig, default=PretrainedConfig) C = TypeVar("C", bound=PretrainedConfig, default=PretrainedConfig)
P = TypeVar("P", bound=ProcessorMixin, default=ProcessorMixin)
@dataclass(frozen=True) @dataclass(frozen=True)
...@@ -38,24 +39,28 @@ class InputContext: ...@@ -38,24 +39,28 @@ class InputContext:
model_config: "ModelConfig" model_config: "ModelConfig"
"""The configuration of the model.""" """The configuration of the model."""
def get_hf_config(self, hf_config_type: Type[C] = PretrainedConfig) -> C: def get_hf_config(
self,
typ: Union[type[C], tuple[type[C], ...]] = PretrainedConfig,
/,
) -> C:
""" """
Get the HuggingFace configuration Get the HuggingFace configuration
(:class:`transformers.PretrainedConfig`) of the model, (:class:`transformers.PretrainedConfig`) of the model,
additionally checking its type. additionally checking its type.
Raises: Raises:
TypeError: If the model is not of the specified type. TypeError: If the configuration is not of the specified type.
""" """
hf_config = self.model_config.hf_config hf_config = self.model_config.hf_config
if not isinstance(hf_config, hf_config_type): if not isinstance(hf_config, typ):
raise TypeError("Invalid type of HuggingFace config. " raise TypeError("Invalid type of HuggingFace config. "
f"Expected type: {hf_config_type}, but " f"Expected type: {typ}, but "
f"found type: {type(hf_config)}") f"found type: {type(hf_config)}")
return hf_config return hf_config
def get_hf_image_processor_config(self) -> Dict[str, Any]: def get_hf_image_processor_config(self) -> dict[str, Any]:
""" """
Get the HuggingFace image processor configuration of the model. Get the HuggingFace image processor configuration of the model.
""" """
...@@ -74,18 +79,37 @@ class InputContext: ...@@ -74,18 +79,37 @@ class InputContext:
return mm_config return mm_config
def get_hf_processor(self, **kwargs: object) -> ProcessorMixin: def get_hf_processor(
self,
typ: Union[type[P], tuple[type[P], ...]] = ProcessorMixin,
/,
**kwargs: object,
) -> P:
"""
Get the HuggingFace processor
(:class:`transformers.ProcessorMixin`) of the model,
additionally checking its type.
Raises:
TypeError: If the processor is not of the specified type.
"""
base_kwargs = self.model_config.mm_processor_kwargs base_kwargs = self.model_config.mm_processor_kwargs
if base_kwargs is None: if base_kwargs is None:
base_kwargs = {} base_kwargs = {}
merged_kwargs = {**base_kwargs, **kwargs} merged_kwargs = {**base_kwargs, **kwargs}
return cached_get_processor( hf_processor = cached_get_processor(
self.model_config.model, self.model_config.model,
trust_remote_code=self.model_config.trust_remote_code, trust_remote_code=self.model_config.trust_remote_code,
**merged_kwargs, **merged_kwargs,
) )
if not isinstance(hf_processor, typ):
raise TypeError("Invalid type of HuggingFace processor. "
f"Expected type: {typ}, but "
f"found type: {type(hf_processor)}")
return hf_processor
@dataclass(frozen=True) @dataclass(frozen=True)
...@@ -93,39 +117,55 @@ class InputProcessingContext(InputContext): ...@@ -93,39 +117,55 @@ class InputProcessingContext(InputContext):
tokenizer: AnyTokenizer tokenizer: AnyTokenizer
"""The tokenizer used to tokenize the inputs.""" """The tokenizer used to tokenize the inputs."""
def get_hf_processor(self, **kwargs: object) -> ProcessorMixin: def get_hf_processor(
base_kwargs = self.model_config.mm_processor_kwargs self,
if base_kwargs is None: typ: Union[type[P], tuple[type[P], ...]] = ProcessorMixin,
base_kwargs = {} /,
**kwargs: object,
merged_kwargs = {**base_kwargs, **kwargs} ) -> P:
return super().get_hf_processor(
return cached_get_processor( typ,
self.model_config.model, tokenizer=self.tokenizer,
tokenizer=self.tokenizer, # Override the tokenizer with ours **kwargs,
trust_remote_code=self.model_config.trust_remote_code,
**merged_kwargs,
) )
def resolve_hf_processor_call_kwargs( def call_hf_processor(
self, self,
hf_processor: ProcessorMixin, hf_processor: ProcessorMixin,
prompt: str,
processor_data: Mapping[str, object],
inference_kwargs: Mapping[str, object], inference_kwargs: Mapping[str, object],
) -> Mapping[str, object]: ) -> BatchFeature:
assert callable(hf_processor) assert callable(hf_processor)
base_kwargs = self.model_config.mm_processor_kwargs base_kwargs = self.model_config.mm_processor_kwargs
if base_kwargs is None: if base_kwargs is None:
base_kwargs = {} base_kwargs = {}
return resolve_mm_processor_kwargs( merged_kwargs = resolve_mm_processor_kwargs(
base_kwargs, base_kwargs,
inference_kwargs, inference_kwargs,
hf_processor, hf_processor,
requires_kw_only=False,
allow_var_kwargs=True,
) )
try:
return hf_processor(
text=prompt,
**processor_data,
**merged_kwargs,
return_tensors="pt",
)
except Exception as exc:
data = dict(text=prompt, **processor_data)
msg = (f"Failed to apply {type(hf_processor).__name__} "
f"on data={data} with kwargs={merged_kwargs}")
raise RuntimeError(msg) from exc
N = TypeVar("N", bound=Type[nn.Module]) N = TypeVar("N", bound=type[nn.Module])
class DummyData(NamedTuple): class DummyData(NamedTuple):
...@@ -232,7 +272,7 @@ class InputRegistry: ...@@ -232,7 +272,7 @@ class InputRegistry:
return wrapper return wrapper
def _get_dummy_data_factory(self, model_cls: Type[nn.Module]): def _get_dummy_data_factory(self, model_cls: type[nn.Module]):
return self._dummy_factories_by_model_type \ return self._dummy_factories_by_model_type \
.get(model_cls, self._default_dummy_data_factory) .get(model_cls, self._default_dummy_data_factory)
...@@ -257,7 +297,7 @@ class InputRegistry: ...@@ -257,7 +297,7 @@ class InputRegistry:
return wrapper return wrapper
def _get_dummy_encoder_data_factory(self, model_cls: Type[nn.Module]): def _get_dummy_encoder_data_factory(self, model_cls: type[nn.Module]):
return self._dummy_encoder_factories_by_model_type \ return self._dummy_encoder_factories_by_model_type \
.get(model_cls, self._default_dummy_data_factory) .get(model_cls, self._default_dummy_data_factory)
...@@ -274,7 +314,7 @@ class InputRegistry: ...@@ -274,7 +314,7 @@ class InputRegistry:
The model is identified by ``model_config``. The model is identified by ``model_config``.
See also: See also:
:ref:`enabling_multimodal_inputs` :ref:`enabling-multimodal-inputs`
Note: Note:
This should be called after This should be called after
...@@ -351,7 +391,7 @@ class InputRegistry: ...@@ -351,7 +391,7 @@ class InputRegistry:
happens before :meth:`~vllm.multimodal.MultiModalRegistry.map_input`. happens before :meth:`~vllm.multimodal.MultiModalRegistry.map_input`.
See also: See also:
:ref:`input_processing_pipeline` :ref:`input-processing-pipeline`
""" """
def wrapper(model_cls: N) -> N: def wrapper(model_cls: N) -> N:
...@@ -368,14 +408,14 @@ class InputRegistry: ...@@ -368,14 +408,14 @@ class InputRegistry:
return wrapper return wrapper
def _get_model_input_processor(self, model_cls: Type[nn.Module]): def _get_model_input_processor(self, model_cls: type[nn.Module]):
return self._input_processors_by_model_type \ return self._input_processors_by_model_type \
.get(model_cls, self._default_input_processor) .get(model_cls, self._default_input_processor)
def _ensure_mm_kwargs( def _ensure_mm_kwargs(
self, self,
inputs: SingletonInputs, inputs: SingletonInputs,
mm_processor_kwargs: Dict[str, Any], mm_processor_kwargs: dict[str, Any],
): ):
if inputs["type"] == "token": if inputs["type"] == "token":
# In case the input processor for that model fails to set it # In case the input processor for that model fails to set it
...@@ -395,7 +435,7 @@ class InputRegistry: ...@@ -395,7 +435,7 @@ class InputRegistry:
The model is identified by ``model_config``. The model is identified by ``model_config``.
See also: See also:
:ref:`input_processing_pipeline` :ref:`input-processing-pipeline`
""" """
# Avoid circular import # Avoid circular import
from vllm.model_executor.model_loader import get_model_architecture from vllm.model_executor.model_loader import get_model_architecture
......
...@@ -425,8 +425,9 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA): ...@@ -425,8 +425,9 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
if self.base_layer.skip_bias_add else None) if self.base_layer.skip_bias_add else None)
return output, output_bias return output, output_bias
# ReplicatedLinear should always be replaced, regardless of the fully
# sharded LoRAs setting, because it is, by definition, copied per GPU.
@classmethod @classmethod
@_not_fully_sharded_can_replace
def can_replace_layer( def can_replace_layer(
cls, cls,
source_layer: nn.Module, source_layer: nn.Module,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment