Unverified Commit 980ad394 authored by Joe Runde's avatar Joe Runde Committed by GitHub
Browse files

[Frontend] Use request id from header (#10968)


Signed-off-by: default avatarJoe Runde <Joseph.Runde@ibm.com>
parent 391d7b27
......@@ -16,5 +16,6 @@ mistral_common >= 1.5.0
aiohttp
starlette
openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
fastapi # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
partial-json-parser # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
requests
......@@ -305,7 +305,7 @@ async def health(raw_request: Request) -> Response:
async def tokenize(request: TokenizeRequest, raw_request: Request):
handler = tokenization(raw_request)
generator = await handler.create_tokenize(request)
generator = await handler.create_tokenize(request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
......@@ -319,7 +319,7 @@ async def tokenize(request: TokenizeRequest, raw_request: Request):
async def detokenize(request: DetokenizeRequest, raw_request: Request):
handler = tokenization(raw_request)
generator = await handler.create_detokenize(request)
generator = await handler.create_detokenize(request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
......
......@@ -176,7 +176,8 @@ class OpenAIServingChat(OpenAIServing):
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
request_id = f"chatcmpl-{request.request_id}"
request_id = "chatcmpl-" \
f"{self._base_request_id(raw_request, request.request_id)}"
request_metadata = RequestResponseMetadata(request_id=request_id)
if raw_request:
......
......@@ -30,7 +30,7 @@ from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import merge_async_iterators, random_uuid
from vllm.utils import merge_async_iterators
logger = init_logger(__name__)
......@@ -86,7 +86,7 @@ class OpenAIServingCompletion(OpenAIServing):
"suffix is not currently supported")
model_name = self.base_model_paths[0].name
request_id = f"cmpl-{random_uuid()}"
request_id = f"cmpl-{self._base_request_id(raw_request)}"
created_time = int(time.time())
request_metadata = RequestResponseMetadata(request_id=request_id)
......
......@@ -19,7 +19,7 @@ from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
from vllm.logger import init_logger
from vllm.outputs import PoolingOutput, PoolingRequestOutput
from vllm.utils import merge_async_iterators, random_uuid
from vllm.utils import merge_async_iterators
logger = init_logger(__name__)
......@@ -110,7 +110,7 @@ class OpenAIServingEmbedding(OpenAIServing):
"dimensions is currently not supported")
model_name = request.model
request_id = f"embd-{random_uuid()}"
request_id = f"embd-{self._base_request_id(raw_request)}"
created_time = int(time.monotonic())
truncate_prompt_tokens = None
......
......@@ -6,6 +6,7 @@ from http import HTTPStatus
from typing import (Any, Callable, Dict, Iterable, Iterator, List, Mapping,
Optional, Sequence, Tuple, TypedDict, Union)
from fastapi import Request
from pydantic import Field
from starlette.datastructures import Headers
from typing_extensions import Annotated
......@@ -47,7 +48,7 @@ from vllm.sequence import Logprob
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import AtomicCounter, is_list_of, make_async
from vllm.utils import AtomicCounter, is_list_of, make_async, random_uuid
logger = init_logger(__name__)
......@@ -565,6 +566,14 @@ class OpenAIServing:
return None
@staticmethod
def _base_request_id(raw_request: Request,
default: Optional[str] = None) -> Optional[str]:
"""Pulls the request id to use from a header, if provided"""
default = default or random_uuid()
return raw_request.headers.get(
"X-Request-Id", default) if raw_request is not None else default
@staticmethod
def _get_decoded_token(logprob: Logprob,
token_id: int,
......
......@@ -15,7 +15,7 @@ from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.utils import make_async, merge_async_iterators, random_uuid
from vllm.utils import make_async, merge_async_iterators
logger = init_logger(__name__)
......@@ -102,7 +102,7 @@ class OpenAIServingScores(OpenAIServing):
return error_check_ret
model_name = request.model
request_id = f"score-{random_uuid()}"
request_id = f"score-{self._base_request_id(raw_request)}"
created_time = int(time.monotonic())
truncate_prompt_tokens = request.truncate_prompt_tokens
......
from typing import Final, List, Optional, Union
from fastapi import Request
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
......@@ -17,7 +19,6 @@ from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
LoRAModulePath,
OpenAIServing)
from vllm.logger import init_logger
from vllm.utils import random_uuid
logger = init_logger(__name__)
......@@ -48,12 +49,13 @@ class OpenAIServingTokenization(OpenAIServing):
async def create_tokenize(
self,
request: TokenizeRequest,
raw_request: Request,
) -> Union[TokenizeResponse, ErrorResponse]:
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
request_id = f"tokn-{random_uuid()}"
request_id = f"tokn-{self._base_request_id(raw_request)}"
try:
(
......@@ -112,12 +114,13 @@ class OpenAIServingTokenization(OpenAIServing):
async def create_detokenize(
self,
request: DetokenizeRequest,
raw_request: Request,
) -> Union[DetokenizeResponse, ErrorResponse]:
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
request_id = f"tokn-{random_uuid()}"
request_id = f"tokn-{self._base_request_id(raw_request)}"
(
lora_request,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment