Unverified Commit 06e0bc21 authored by Sage's avatar Sage Committed by GitHub
Browse files

[Frontend] Split `OpenAIServingModels` into `OpenAIModelRegistry` + `OpenAIServingModels` (#36536)


Signed-off-by: default avatarSage Ahrac <sagiahrak@gmail.com>
parent 5a71cdd7
...@@ -414,11 +414,19 @@ async def init_render_app_state( ...@@ -414,11 +414,19 @@ async def init_render_app_state(
directly from the :class:`~vllm.config.VllmConfig`. directly from the :class:`~vllm.config.VllmConfig`.
""" """
from vllm.entrypoints.chat_utils import load_chat_template from vllm.entrypoints.chat_utils import load_chat_template
from vllm.entrypoints.openai.models.serving import OpenAIModelRegistry
from vllm.entrypoints.serve.render.serving import OpenAIServingRender from vllm.entrypoints.serve.render.serving import OpenAIServingRender
from vllm.plugins.io_processors import get_io_processor from vllm.plugins.io_processors import get_io_processor
from vllm.renderers import renderer_from_config from vllm.renderers import renderer_from_config
served_model_names = args.served_model_name or [args.model] served_model_names = args.served_model_name or [args.model]
model_registry = OpenAIModelRegistry(
model_config=vllm_config.model_config,
base_model_paths=[
BaseModelPath(name=name, model_path=args.model)
for name in served_model_names
],
)
if args.enable_log_requests: if args.enable_log_requests:
request_logger = RequestLogger(max_log_len=args.max_log_len) request_logger = RequestLogger(max_log_len=args.max_log_len)
...@@ -435,7 +443,7 @@ async def init_render_app_state( ...@@ -435,7 +443,7 @@ async def init_render_app_state(
model_config=vllm_config.model_config, model_config=vllm_config.model_config,
renderer=renderer, renderer=renderer,
io_processor=io_processor, io_processor=io_processor,
served_model_names=served_model_names, model_registry=model_registry,
request_logger=request_logger, request_logger=request_logger,
chat_template=resolved_chat_template, chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format, chat_template_content_format=args.chat_template_content_format,
...@@ -447,8 +455,7 @@ async def init_render_app_state( ...@@ -447,8 +455,7 @@ async def init_render_app_state(
log_error_stack=args.log_error_stack, log_error_stack=args.log_error_stack,
) )
# Expose models endpoint via the render handler. state.openai_serving_models = model_registry
state.openai_serving_models = state.openai_serving_render
state.vllm_config = vllm_config state.vllm_config = vllm_config
# Disable stats logging — there is no engine to poll. # Disable stats logging — there is no engine to poll.
......
...@@ -169,9 +169,7 @@ async def init_generate_state( ...@@ -169,9 +169,7 @@ async def init_generate_state(
model_config=engine_client.model_config, model_config=engine_client.model_config,
renderer=engine_client.renderer, renderer=engine_client.renderer,
io_processor=engine_client.io_processor, io_processor=engine_client.io_processor,
served_model_names=[ model_registry=state.openai_serving_models.registry,
mp.name for mp in state.openai_serving_models.base_model_paths
],
request_logger=request_logger, request_logger=request_logger,
chat_template=resolved_chat_template, chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format, chat_template_content_format=args.chat_template_content_format,
......
...@@ -5,6 +5,7 @@ from asyncio import Lock ...@@ -5,6 +5,7 @@ from asyncio import Lock
from collections import defaultdict from collections import defaultdict
from http import HTTPStatus from http import HTTPStatus
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.engine.protocol import ( from vllm.entrypoints.openai.engine.protocol import (
ErrorResponse, ErrorResponse,
...@@ -27,6 +28,51 @@ from vllm.utils.counter import AtomicCounter ...@@ -27,6 +28,51 @@ from vllm.utils.counter import AtomicCounter
logger = init_logger(__name__) logger = init_logger(__name__)
class OpenAIModelRegistry:
"""Read-only view of the loaded base models with no engine dependency.
Suitable for CPU-only / render-only contexts that have no engine client
and no LoRA support.
"""
def __init__(
self,
model_config: ModelConfig,
base_model_paths: list[BaseModelPath],
) -> None:
self.model_config = model_config
self.base_model_paths = base_model_paths
def is_base_model(self, model_name: str) -> bool:
return any(model.name == model_name for model in self.base_model_paths)
async def check_model(self, model_name: str | None) -> ErrorResponse | None:
"""Return an ErrorResponse if model_name is not served, else None."""
if not model_name or self.is_base_model(model_name):
return None
return create_error_response(
message=f"The model `{model_name}` does not exist.",
err_type="NotFoundError",
status_code=HTTPStatus.NOT_FOUND,
param="model",
)
async def show_available_models(self) -> ModelList:
"""Show available models (base models only)."""
max_model_len = self.model_config.max_model_len
return ModelList(
data=[
ModelCard(
id=base_model.name,
max_model_len=max_model_len,
root=base_model.model_path,
permission=[ModelPermission()],
)
for base_model in self.base_model_paths
]
)
class OpenAIServingModels: class OpenAIServingModels:
"""Shared instance to hold data about the loaded base model(s) and adapters. """Shared instance to hold data about the loaded base model(s) and adapters.
...@@ -45,6 +91,11 @@ class OpenAIServingModels: ...@@ -45,6 +91,11 @@ class OpenAIServingModels:
): ):
super().__init__() super().__init__()
self.registry = OpenAIModelRegistry(
model_config=engine_client.model_config,
base_model_paths=base_model_paths,
)
self.engine_client = engine_client self.engine_client = engine_client
self.base_model_paths = base_model_paths self.base_model_paths = base_model_paths
...@@ -79,34 +130,18 @@ class OpenAIServingModels: ...@@ -79,34 +130,18 @@ class OpenAIServingModels:
if isinstance(load_result, ErrorResponse): if isinstance(load_result, ErrorResponse):
raise ValueError(load_result.error.message) raise ValueError(load_result.error.message)
def is_base_model(self, model_name) -> bool: def is_base_model(self, model_name: str) -> bool:
return any(model.name == model_name for model in self.base_model_paths) return self.registry.is_base_model(model_name)
def model_name(self, lora_request: LoRARequest | None = None) -> str: def model_name(self, lora_request: LoRARequest | None = None) -> str:
"""Returns the appropriate model name depending on the availability
and support of the LoRA or base model.
Parameters:
- lora: LoRARequest that contain a base_model_name.
Returns:
- str: The name of the base model or the first available model path.
"""
if lora_request is not None: if lora_request is not None:
return lora_request.lora_name return lora_request.lora_name
return self.base_model_paths[0].name return self.base_model_paths[0].name
async def show_available_models(self) -> ModelList: async def show_available_models(self) -> ModelList:
"""Show available models. This includes the base model and all adapters.""" """Show available models. This includes the base model and all
max_model_len = self.model_config.max_model_len adapters."""
model_list = await self.registry.show_available_models()
model_cards = [
ModelCard(
id=base_model.name,
max_model_len=max_model_len,
root=base_model.model_path,
permission=[ModelPermission()],
)
for base_model in self.base_model_paths
]
lora_cards = [ lora_cards = [
ModelCard( ModelCard(
id=lora.lora_name, id=lora.lora_name,
...@@ -118,8 +153,8 @@ class OpenAIServingModels: ...@@ -118,8 +153,8 @@ class OpenAIServingModels:
) )
for lora in self.lora_requests.values() for lora in self.lora_requests.values()
] ]
model_cards.extend(lora_cards) model_list.data.extend(lora_cards)
return ModelList(data=model_cards) return model_list
async def load_lora_adapter( async def load_lora_adapter(
self, request: LoadLoRAAdapterRequest, base_model_name: str | None = None self, request: LoadLoRAAdapterRequest, base_model_name: str | None = None
......
...@@ -16,10 +16,8 @@ from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionReque ...@@ -16,10 +16,8 @@ from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionReque
from vllm.entrypoints.openai.completion.protocol import CompletionRequest from vllm.entrypoints.openai.completion.protocol import CompletionRequest
from vllm.entrypoints.openai.engine.protocol import ( from vllm.entrypoints.openai.engine.protocol import (
ErrorResponse, ErrorResponse,
ModelCard,
ModelList,
ModelPermission,
) )
from vllm.entrypoints.openai.models.serving import OpenAIModelRegistry
from vllm.entrypoints.openai.parser.harmony_utils import ( from vllm.entrypoints.openai.parser.harmony_utils import (
get_developer_message, get_developer_message,
get_system_message, get_system_message,
...@@ -46,7 +44,7 @@ class OpenAIServingRender: ...@@ -46,7 +44,7 @@ class OpenAIServingRender:
model_config: ModelConfig, model_config: ModelConfig,
renderer: BaseRenderer, renderer: BaseRenderer,
io_processor: Any, io_processor: Any,
served_model_names: list[str], model_registry: OpenAIModelRegistry,
*, *,
request_logger: RequestLogger | None, request_logger: RequestLogger | None,
chat_template: str | None, chat_template: str | None,
...@@ -61,7 +59,7 @@ class OpenAIServingRender: ...@@ -61,7 +59,7 @@ class OpenAIServingRender:
self.model_config = model_config self.model_config = model_config
self.renderer = renderer self.renderer = renderer
self.io_processor = io_processor self.io_processor = io_processor
self.served_model_names = served_model_names self.model_registry = model_registry
self.request_logger = request_logger self.request_logger = request_logger
self.chat_template = chat_template self.chat_template = chat_template
self.chat_template_content_format: ChatTemplateContentFormatOption = ( self.chat_template_content_format: ChatTemplateContentFormatOption = (
...@@ -252,21 +250,6 @@ class OpenAIServingRender: ...@@ -252,21 +250,6 @@ class OpenAIServingRender:
return messages, [engine_prompt] return messages, [engine_prompt]
async def show_available_models(self) -> ModelList:
"""Returns the models served by this render server."""
max_model_len = self.model_config.max_model_len
return ModelList(
data=[
ModelCard(
id=name,
max_model_len=max_model_len,
root=self.model_config.model,
permission=[ModelPermission()],
)
for name in self.served_model_names
]
)
def create_error_response( def create_error_response(
self, self,
message: str | Exception, message: str | Exception,
...@@ -276,23 +259,11 @@ class OpenAIServingRender: ...@@ -276,23 +259,11 @@ class OpenAIServingRender:
) -> ErrorResponse: ) -> ErrorResponse:
return create_error_response(message, err_type, status_code, param) return create_error_response(message, err_type, status_code, param)
def _is_model_supported(self, model_name: str) -> bool:
"""Simplified from OpenAIServing._is_model_supported (no LoRA support)."""
return model_name in self.served_model_names
async def _check_model( async def _check_model(
self, self,
request: Any, request: Any,
) -> ErrorResponse | None: ) -> ErrorResponse | None:
"""Simplified from OpenAIServing._check_model (no LoRA support).""" return await self.model_registry.check_model(request.model)
if self._is_model_supported(request.model):
return None
return self.create_error_response(
message=f"The model `{request.model}` does not exist.",
err_type="NotFoundError",
status_code=HTTPStatus.NOT_FOUND,
param="model",
)
def _validate_chat_template( def _validate_chat_template(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment