Unverified commit 980ad394, authored by Joe Runde and committed by GitHub
Browse files

[Frontend] Use request id from header (#10968)


Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
parent 391d7b27
...@@ -16,5 +16,6 @@ mistral_common >= 1.5.0 ...@@ -16,5 +16,6 @@ mistral_common >= 1.5.0
aiohttp aiohttp
starlette starlette
openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
fastapi # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
partial-json-parser # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args partial-json-parser # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
requests requests
...@@ -305,7 +305,7 @@ async def health(raw_request: Request) -> Response: ...@@ -305,7 +305,7 @@ async def health(raw_request: Request) -> Response:
async def tokenize(request: TokenizeRequest, raw_request: Request): async def tokenize(request: TokenizeRequest, raw_request: Request):
handler = tokenization(raw_request) handler = tokenization(raw_request)
generator = await handler.create_tokenize(request) generator = await handler.create_tokenize(request, raw_request)
if isinstance(generator, ErrorResponse): if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(), return JSONResponse(content=generator.model_dump(),
status_code=generator.code) status_code=generator.code)
...@@ -319,7 +319,7 @@ async def tokenize(request: TokenizeRequest, raw_request: Request): ...@@ -319,7 +319,7 @@ async def tokenize(request: TokenizeRequest, raw_request: Request):
async def detokenize(request: DetokenizeRequest, raw_request: Request): async def detokenize(request: DetokenizeRequest, raw_request: Request):
handler = tokenization(raw_request) handler = tokenization(raw_request)
generator = await handler.create_detokenize(request) generator = await handler.create_detokenize(request, raw_request)
if isinstance(generator, ErrorResponse): if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(), return JSONResponse(content=generator.model_dump(),
status_code=generator.code) status_code=generator.code)
......
...@@ -176,7 +176,8 @@ class OpenAIServingChat(OpenAIServing): ...@@ -176,7 +176,8 @@ class OpenAIServingChat(OpenAIServing):
logger.exception("Error in preprocessing prompt inputs") logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e)) return self.create_error_response(str(e))
request_id = f"chatcmpl-{request.request_id}" request_id = "chatcmpl-" \
f"{self._base_request_id(raw_request, request.request_id)}"
request_metadata = RequestResponseMetadata(request_id=request_id) request_metadata = RequestResponseMetadata(request_id=request_id)
if raw_request: if raw_request:
......
...@@ -30,7 +30,7 @@ from vllm.outputs import RequestOutput ...@@ -30,7 +30,7 @@ from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import merge_async_iterators, random_uuid from vllm.utils import merge_async_iterators
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -86,7 +86,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -86,7 +86,7 @@ class OpenAIServingCompletion(OpenAIServing):
"suffix is not currently supported") "suffix is not currently supported")
model_name = self.base_model_paths[0].name model_name = self.base_model_paths[0].name
request_id = f"cmpl-{random_uuid()}" request_id = f"cmpl-{self._base_request_id(raw_request)}"
created_time = int(time.time()) created_time = int(time.time())
request_metadata = RequestResponseMetadata(request_id=request_id) request_metadata = RequestResponseMetadata(request_id=request_id)
......
...@@ -19,7 +19,7 @@ from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest, ...@@ -19,7 +19,7 @@ from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import PoolingOutput, PoolingRequestOutput from vllm.outputs import PoolingOutput, PoolingRequestOutput
from vllm.utils import merge_async_iterators, random_uuid from vllm.utils import merge_async_iterators
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -110,7 +110,7 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -110,7 +110,7 @@ class OpenAIServingEmbedding(OpenAIServing):
"dimensions is currently not supported") "dimensions is currently not supported")
model_name = request.model model_name = request.model
request_id = f"embd-{random_uuid()}" request_id = f"embd-{self._base_request_id(raw_request)}"
created_time = int(time.monotonic()) created_time = int(time.monotonic())
truncate_prompt_tokens = None truncate_prompt_tokens = None
......
...@@ -6,6 +6,7 @@ from http import HTTPStatus ...@@ -6,6 +6,7 @@ from http import HTTPStatus
from typing import (Any, Callable, Dict, Iterable, Iterator, List, Mapping, from typing import (Any, Callable, Dict, Iterable, Iterator, List, Mapping,
Optional, Sequence, Tuple, TypedDict, Union) Optional, Sequence, Tuple, TypedDict, Union)
from fastapi import Request
from pydantic import Field from pydantic import Field
from starlette.datastructures import Headers from starlette.datastructures import Headers
from typing_extensions import Annotated from typing_extensions import Annotated
...@@ -47,7 +48,7 @@ from vllm.sequence import Logprob ...@@ -47,7 +48,7 @@ from vllm.sequence import Logprob
from vllm.tracing import (contains_trace_headers, extract_trace_headers, from vllm.tracing import (contains_trace_headers, extract_trace_headers,
log_tracing_disabled_warning) log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import AtomicCounter, is_list_of, make_async from vllm.utils import AtomicCounter, is_list_of, make_async, random_uuid
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -565,6 +566,14 @@ class OpenAIServing: ...@@ -565,6 +566,14 @@ class OpenAIServing:
return None return None
@staticmethod
def _base_request_id(raw_request: Optional[Request],
                     default: Optional[str] = None) -> str:
    """Pick the base request id for an incoming API call.

    The id is resolved in this order:

    1. the ``X-Request-Id`` header of ``raw_request``, when a request is
       supplied and the header carries a non-empty value;
    2. ``default``, when it is a non-empty string;
    3. a freshly generated ``random_uuid()``.

    Args:
        raw_request: The incoming HTTP request, or ``None`` when the
            handler is invoked without one.
        default: Fallback id to use when the header does not provide one.

    Returns:
        A non-empty string; callers prepend their own prefix
        (e.g. ``"cmpl-"``) to build the final request id.
    """
    header_id = (raw_request.headers.get("X-Request-Id")
                 if raw_request is not None else None)
    # A header that is present but empty must not win: it would make
    # every such request share the same prefix-only id.
    if header_id:
        return header_id
    # Only generate a uuid when neither the header nor the caller-provided
    # default yields a usable id.
    return default or random_uuid()
@staticmethod @staticmethod
def _get_decoded_token(logprob: Logprob, def _get_decoded_token(logprob: Logprob,
token_id: int, token_id: int,
......
...@@ -15,7 +15,7 @@ from vllm.inputs.data import TokensPrompt ...@@ -15,7 +15,7 @@ from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput from vllm.outputs import PoolingRequestOutput
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.utils import make_async, merge_async_iterators, random_uuid from vllm.utils import make_async, merge_async_iterators
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -102,7 +102,7 @@ class OpenAIServingScores(OpenAIServing): ...@@ -102,7 +102,7 @@ class OpenAIServingScores(OpenAIServing):
return error_check_ret return error_check_ret
model_name = request.model model_name = request.model
request_id = f"score-{random_uuid()}" request_id = f"score-{self._base_request_id(raw_request)}"
created_time = int(time.monotonic()) created_time = int(time.monotonic())
truncate_prompt_tokens = request.truncate_prompt_tokens truncate_prompt_tokens = request.truncate_prompt_tokens
......
from typing import Final, List, Optional, Union from typing import Final, List, Optional, Union
from fastapi import Request
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
...@@ -17,7 +19,6 @@ from vllm.entrypoints.openai.serving_engine import (BaseModelPath, ...@@ -17,7 +19,6 @@ from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
LoRAModulePath, LoRAModulePath,
OpenAIServing) OpenAIServing)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import random_uuid
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -48,12 +49,13 @@ class OpenAIServingTokenization(OpenAIServing): ...@@ -48,12 +49,13 @@ class OpenAIServingTokenization(OpenAIServing):
async def create_tokenize( async def create_tokenize(
self, self,
request: TokenizeRequest, request: TokenizeRequest,
raw_request: Request,
) -> Union[TokenizeResponse, ErrorResponse]: ) -> Union[TokenizeResponse, ErrorResponse]:
error_check_ret = await self._check_model(request) error_check_ret = await self._check_model(request)
if error_check_ret is not None: if error_check_ret is not None:
return error_check_ret return error_check_ret
request_id = f"tokn-{random_uuid()}" request_id = f"tokn-{self._base_request_id(raw_request)}"
try: try:
( (
...@@ -112,12 +114,13 @@ class OpenAIServingTokenization(OpenAIServing): ...@@ -112,12 +114,13 @@ class OpenAIServingTokenization(OpenAIServing):
async def create_detokenize( async def create_detokenize(
self, self,
request: DetokenizeRequest, request: DetokenizeRequest,
raw_request: Request,
) -> Union[DetokenizeResponse, ErrorResponse]: ) -> Union[DetokenizeResponse, ErrorResponse]:
error_check_ret = await self._check_model(request) error_check_ret = await self._check_model(request)
if error_check_ret is not None: if error_check_ret is not None:
return error_check_ret return error_check_ret
request_id = f"tokn-{random_uuid()}" request_id = f"tokn-{self._base_request_id(raw_request)}"
( (
lora_request, lora_request,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment