"vllm/vscode:/vscode.git/clone" did not exist on "3b61cb450d899dc423feb264c297d4d18d701678"
Commit 38d80967 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.10.2rc2' into v0.10.2rc2-ori

parents 33650733 880c741b
...@@ -43,10 +43,10 @@ from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( ...@@ -43,10 +43,10 @@ from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
from vllm.entrypoints.utils import get_max_tokens from vllm.entrypoints.utils import get_max_tokens
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.outputs import CompletionOutput, RequestOutput from vllm.outputs import CompletionOutput, RequestOutput
from vllm.reasoning import ReasoningParser, ReasoningParserManager from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.transformers_utils.tokenizers import (maybe_serialize_tool_calls, from vllm.transformers_utils.tokenizers import (maybe_serialize_tool_calls,
truncate_tool_call_ids, truncate_tool_call_ids,
...@@ -186,7 +186,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -186,7 +186,7 @@ class OpenAIServingChat(OpenAIServing):
lora_request = self._maybe_get_adapters( lora_request = self._maybe_get_adapters(
request, supports_default_mm_loras=True) request, supports_default_mm_loras=True)
model_name = self._get_model_name(request.model, lora_request) model_name = self.models.model_name(lora_request)
tokenizer = await self.engine_client.get_tokenizer(lora_request) tokenizer = await self.engine_client.get_tokenizer(lora_request)
...@@ -489,6 +489,8 @@ class OpenAIServingChat(OpenAIServing): ...@@ -489,6 +489,8 @@ class OpenAIServingChat(OpenAIServing):
get_streamable_parser_for_assistant() get_streamable_parser_for_assistant()
for _ in range(num_choices) for _ in range(num_choices)
] ]
harmony_tools_streamed = [False] * num_choices
tools_streamed = [False] * num_choices
if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam): if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
tool_choice_function_name = request.tool_choice.function.name tool_choice_function_name = request.tool_choice.function.name
...@@ -662,13 +664,11 @@ class OpenAIServingChat(OpenAIServing): ...@@ -662,13 +664,11 @@ class OpenAIServingChat(OpenAIServing):
if self.use_harmony: if self.use_harmony:
harmony_parser = harmony_parsers[i] harmony_parser = harmony_parsers[i]
prev_recipient = harmony_parser.current_recipient
for token_id in output.token_ids: for token_id in output.token_ids:
harmony_parser.process(token_id) harmony_parser.process(token_id)
is_reasoning = \ cur_channel = harmony_parser.current_channel
harmony_parser.current_channel == "analysis" cur_recipient = harmony_parser.current_recipient
if not request.include_reasoning and is_reasoning:
# Skip the reasoning content.
continue
delta_text = harmony_parser.last_content_delta or "" delta_text = harmony_parser.last_content_delta or ""
else: else:
delta_text = output.text delta_text = output.text
...@@ -681,8 +681,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -681,8 +681,7 @@ class OpenAIServingChat(OpenAIServing):
delta_message: Optional[DeltaMessage] delta_message: Optional[DeltaMessage]
# just update previous_texts and previous_token_ids # just update previous_texts and previous_token_ids
if ((tool_choice_auto or self.reasoning_parser) if tool_choice_auto or self.reasoning_parser:
and not self.use_harmony):
assert previous_texts is not None assert previous_texts is not None
assert all_previous_token_ids is not None assert all_previous_token_ids is not None
previous_text = previous_texts[i] previous_text = previous_texts[i]
...@@ -696,11 +695,54 @@ class OpenAIServingChat(OpenAIServing): ...@@ -696,11 +695,54 @@ class OpenAIServingChat(OpenAIServing):
current_token_ids = as_list(output.token_ids) current_token_ids = as_list(output.token_ids)
if self.use_harmony: if self.use_harmony:
if is_reasoning: if cur_channel == "final":
delta_message = DeltaMessage(
reasoning_content=delta_text)
else:
delta_message = DeltaMessage(content=delta_text) delta_message = DeltaMessage(content=delta_text)
elif cur_channel == "analysis":
if request.include_reasoning:
delta_message = DeltaMessage(
reasoning_content=delta_text)
else:
delta_message = None
elif (cur_channel == "commentary" and cur_recipient
and cur_recipient.startswith("functions.")):
# Count completed tool calls to determine index
base_index = 0
for msg in harmony_parser.messages:
if (msg.channel == "commentary"
and msg.recipient
and msg.recipient.startswith(
"functions.")):
base_index += 1
if prev_recipient != cur_recipient:
tool_name = cur_recipient.split(
"functions.", 1)[1]
delta_message = DeltaMessage(tool_calls=[
DeltaToolCall(
id=make_tool_call_id(),
type="function",
function=DeltaFunctionCall(
name=tool_name,
arguments="",
),
index=base_index,
)
])
elif delta_text:
delta_message = DeltaMessage(tool_calls=[
DeltaToolCall(
index=base_index,
function=DeltaFunctionCall(
arguments=delta_text),
)
])
else:
delta_message = None
if delta_message is not None:
harmony_tools_streamed[i] = True
else:
delta_message = None
# handle streaming deltas for tools with named tool_choice # handle streaming deltas for tools with named tool_choice
elif tool_choice_function_name: elif tool_choice_function_name:
if (self.reasoning_parser and not reasoning_end_arr[i] if (self.reasoning_parser and not reasoning_end_arr[i]
...@@ -758,6 +800,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -758,6 +800,7 @@ class OpenAIServingChat(OpenAIServing):
delta_message = DeltaMessage(tool_calls=[ delta_message = DeltaMessage(tool_calls=[
delta_tool_call, delta_tool_call,
]) ])
tools_streamed[i] = True
elif request.tool_choice == "required": elif request.tool_choice == "required":
assert previous_texts is not None assert previous_texts is not None
...@@ -783,6 +826,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -783,6 +826,7 @@ class OpenAIServingChat(OpenAIServing):
if (delta_message and delta_message.tool_calls and if (delta_message and delta_message.tool_calls and
delta_message.tool_calls[0].id is not None): delta_message.tool_calls[0].id is not None):
history_tool_call_cnt += 1 history_tool_call_cnt += 1
tools_streamed[i] = True
# update the previous values for the next iteration # update the previous values for the next iteration
previous_texts[i] = current_text previous_texts[i] = current_text
...@@ -859,6 +903,8 @@ class OpenAIServingChat(OpenAIServing): ...@@ -859,6 +903,8 @@ class OpenAIServingChat(OpenAIServing):
current_token_ids=current_token_ids, current_token_ids=current_token_ids,
delta_token_ids=delta_token_ids, delta_token_ids=delta_token_ids,
request=request)) request=request))
if delta_message and delta_message.tool_calls:
tools_streamed[i] = True
# when only tool calls # when only tool calls
elif tool_choice_auto: elif tool_choice_auto:
assert tool_parser is not None assert tool_parser is not None
...@@ -871,6 +917,8 @@ class OpenAIServingChat(OpenAIServing): ...@@ -871,6 +917,8 @@ class OpenAIServingChat(OpenAIServing):
current_token_ids=current_token_ids, current_token_ids=current_token_ids,
delta_token_ids=output.token_ids, delta_token_ids=output.token_ids,
request=request)) request=request))
if delta_message and delta_message.tool_calls:
tools_streamed[i] = True
# when only reasoning # when only reasoning
elif self.reasoning_parser: elif self.reasoning_parser:
...@@ -907,7 +955,10 @@ class OpenAIServingChat(OpenAIServing): ...@@ -907,7 +955,10 @@ class OpenAIServingChat(OpenAIServing):
# wasn't ready to send a token, then # wasn't ready to send a token, then
# get the next token without streaming a chunk # get the next token without streaming a chunk
if delta_message is None: if delta_message is None:
continue if output.finish_reason is None:
continue
else:
delta_message = DeltaMessage()
# Log streaming delta if output logging is enabled # Log streaming delta if output logging is enabled
if self.enable_log_outputs and self.request_logger: if self.enable_log_outputs and self.request_logger:
...@@ -993,12 +1044,18 @@ class OpenAIServingChat(OpenAIServing): ...@@ -993,12 +1044,18 @@ class OpenAIServingChat(OpenAIServing):
]) ])
# Send the finish response for each request.n only once # Send the finish response for each request.n only once
if auto_tools_called or tools_streamed[i] or (
self.use_harmony
and harmony_tools_streamed[i]):
finish_reason_ = "tool_calls"
else:
finish_reason_ = output.finish_reason \
if output.finish_reason else "stop"
choice_data = ChatCompletionResponseStreamChoice( choice_data = ChatCompletionResponseStreamChoice(
index=i, index=i,
delta=delta_message, delta=delta_message,
logprobs=logprobs, logprobs=logprobs,
finish_reason=output.finish_reason finish_reason=finish_reason_,
if not auto_tools_called else "tool_calls",
stop_reason=output.stop_reason, stop_reason=output.stop_reason,
token_ids=(as_list(output.token_ids) token_ids=(as_list(output.token_ids)
if request.return_token_ids else None)) if request.return_token_ids else None))
...@@ -1117,6 +1174,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1117,6 +1174,7 @@ class OpenAIServingChat(OpenAIServing):
for output in final_res.outputs: for output in final_res.outputs:
token_ids = output.token_ids token_ids = output.token_ids
out_logprobs = output.logprobs out_logprobs = output.logprobs
tool_call_info = None
if request.logprobs and request.top_logprobs is not None: if request.logprobs and request.top_logprobs is not None:
assert out_logprobs is not None, "Did not output logprobs" assert out_logprobs is not None, "Did not output logprobs"
...@@ -1131,31 +1189,42 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1131,31 +1189,42 @@ class OpenAIServingChat(OpenAIServing):
logprobs = None logprobs = None
if self.use_harmony: if self.use_harmony:
reasoning_content, final_content, is_tool_call = ( if self.tool_parser is not None:
parse_chat_output(token_ids)) tool_parser = self.tool_parser(tokenizer)
if not request.include_reasoning: # NOTE: We use token_ids for openai tool parser
reasoning_content = None tool_call_info = tool_parser.extract_tool_calls(
"",
if is_tool_call: request=request,
# TODO(woosuk): Implement tool call for gpt-oss. token_ids=token_ids, # type: ignore
# For now, only Responses API supports tool call for )
# gpt-oss. reasoning_content, content = None, tool_call_info.content
raise NotImplementedError( if request.include_reasoning:
"Tool call in Chat Completion API is not supported " reasoning_content, content, _ = parse_chat_output(
"for gpt-oss yet. Please use Responses API instead.") token_ids)
message = ChatMessage(
role=role,
reasoning_content=reasoning_content,
content=content,
tool_calls=tool_call_info.tool_calls,
)
else: else:
# Normal message reasoning_content, content, _ = parse_chat_output(
token_ids)
if not request.include_reasoning:
reasoning_content = None
message = ChatMessage( message = ChatMessage(
role=role, role=role,
reasoning_content=reasoning_content, reasoning_content=reasoning_content,
content=final_content, content=content,
) )
choice_data = ChatCompletionResponseChoice( choice_data = ChatCompletionResponseChoice(
index=output.index, index=output.index,
message=message, message=message,
logprobs=logprobs, logprobs=logprobs,
finish_reason="tool_calls" if is_tool_call else finish_reason="tool_calls" if
(tool_call_info is not None
and tool_call_info.tools_called) else
output.finish_reason if output.finish_reason else "stop", output.finish_reason if output.finish_reason else "stop",
stop_reason=output.stop_reason, stop_reason=output.stop_reason,
) )
...@@ -1419,9 +1488,10 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1419,9 +1488,10 @@ class OpenAIServingChat(OpenAIServing):
step_top_logprobs = top_logprobs[i] step_top_logprobs = top_logprobs[i]
if step_top_logprobs is None or step_top_logprobs.get( if step_top_logprobs is None or step_top_logprobs.get(
token_id) is None: token_id) is None:
token = tokenizer.decode(token_id)
if should_return_as_token_id: if should_return_as_token_id:
token = f"token_id:{token_id}" token = f"token_id:{token_id}"
else:
token = tokenizer.decode(token_id)
logprobs_content.append( logprobs_content.append(
ChatCompletionLogProbsContent( ChatCompletionLogProbsContent(
...@@ -1503,12 +1573,12 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1503,12 +1573,12 @@ class OpenAIServingChat(OpenAIServing):
messages.append(sys_msg) messages.append(sys_msg)
# Add developer message. # Add developer message.
dev_msg = get_developer_message() dev_msg = get_developer_message(tools=request.tools)
messages.append(dev_msg) messages.append(dev_msg)
# Add user message. # Add user message.
for chat_msg in request.messages: for chat_msg in request.messages:
messages.append(parse_chat_input(chat_msg)) messages.extend(parse_chat_input(chat_msg))
# Render prompt token ids. # Render prompt token ids.
prompt_token_ids = render_for_completion(messages) prompt_token_ids = render_for_completion(messages)
......
...@@ -20,6 +20,7 @@ from vllm.entrypoints.openai.serving_engine import (ClassificationServeContext, ...@@ -20,6 +20,7 @@ from vllm.entrypoints.openai.serving_engine import (ClassificationServeContext,
OpenAIServing, OpenAIServing,
ServeContext) ServeContext)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.renderer import RenderConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import ClassificationOutput, PoolingRequestOutput from vllm.outputs import ClassificationOutput, PoolingRequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
...@@ -54,14 +55,10 @@ class ClassificationMixin(OpenAIServing): ...@@ -54,14 +55,10 @@ class ClassificationMixin(OpenAIServing):
ctx.tokenizer = await self.engine_client.get_tokenizer( ctx.tokenizer = await self.engine_client.get_tokenizer(
ctx.lora_request) ctx.lora_request)
( renderer = self._get_renderer(ctx.tokenizer)
ctx.request_prompts, ctx.engine_prompts = await renderer.render_prompt(
ctx.engine_prompts, prompt_or_prompts=ctx.request.input,
) = await self._preprocess_completion( config=self._build_render_config(ctx.request))
ctx.request,
ctx.tokenizer,
ctx.request.input,
)
return None return None
...@@ -117,6 +114,12 @@ class ClassificationMixin(OpenAIServing): ...@@ -117,6 +114,12 @@ class ClassificationMixin(OpenAIServing):
usage=usage, usage=usage,
) )
def _build_render_config(self,
request: ClassificationRequest) -> RenderConfig:
return RenderConfig(
max_length=self.max_model_len,
truncate_prompt_tokens=request.truncate_prompt_tokens)
class ServingClassification(ClassificationMixin): class ServingClassification(ClassificationMixin):
request_id_prefix = "classify" request_id_prefix = "classify"
...@@ -143,7 +146,7 @@ class ServingClassification(ClassificationMixin): ...@@ -143,7 +146,7 @@ class ServingClassification(ClassificationMixin):
request: ClassificationRequest, request: ClassificationRequest,
raw_request: Request, raw_request: Request,
) -> Union[ClassificationResponse, ErrorResponse]: ) -> Union[ClassificationResponse, ErrorResponse]:
model_name = self._get_model_name(request.model) model_name = self.models.model_name()
request_id = (f"{self.request_id_prefix}-" request_id = (f"{self.request_id_prefix}-"
f"{self._base_request_id(raw_request)}") f"{self._base_request_id(raw_request)}")
......
...@@ -26,21 +26,18 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs, ...@@ -26,21 +26,18 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
PromptTokenUsageInfo, PromptTokenUsageInfo,
RequestResponseMetadata, RequestResponseMetadata,
UsageInfo) UsageInfo)
from vllm.entrypoints.openai.serving_engine import (
EmbedsPrompt as ServingEngineEmbedsPrompt)
from vllm.entrypoints.openai.serving_engine import (OpenAIServing, from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
TextTokensPrompt, clamp_prompt_logprobs)
clamp_prompt_logprobs,
is_text_tokens_prompt)
# yapf: enable # yapf: enable
from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.renderer import RenderConfig
from vllm.entrypoints.utils import get_max_tokens from vllm.entrypoints.utils import get_max_tokens
from vllm.inputs.data import (EmbedsPrompt, TokensPrompt, is_embeds_prompt, from vllm.inputs.data import (EmbedsPrompt, TokensPrompt, is_embeds_prompt,
is_tokens_prompt) is_tokens_prompt)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import as_list, merge_async_iterators from vllm.utils import as_list, merge_async_iterators
...@@ -132,12 +129,12 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -132,12 +129,12 @@ class OpenAIServingCompletion(OpenAIServing):
else: else:
tokenizer = await self.engine_client.get_tokenizer(lora_request tokenizer = await self.engine_client.get_tokenizer(lora_request
) )
renderer = self._get_renderer(tokenizer)
request_prompts, engine_prompts = await self._preprocess_completion( engine_prompts = await renderer.render_prompt_and_embeds(
request, prompt_or_prompts=request.prompt,
tokenizer, prompt_embeds=request.prompt_embeds,
request.prompt, config=self._build_render_config(request),
add_special_tokens=request.add_special_tokens,
) )
except ValueError as e: except ValueError as e:
logger.exception("Error in preprocessing prompt inputs") logger.exception("Error in preprocessing prompt inputs")
...@@ -198,7 +195,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -198,7 +195,7 @@ class OpenAIServingCompletion(OpenAIServing):
self._log_inputs( self._log_inputs(
request_id_item, request_id_item,
request_prompts[i], engine_prompt,
params=sampling_params, params=sampling_params,
lora_request=lora_request, lora_request=lora_request,
) )
...@@ -235,7 +232,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -235,7 +232,7 @@ class OpenAIServingCompletion(OpenAIServing):
result_generator = merge_async_iterators(*generators) result_generator = merge_async_iterators(*generators)
model_name = self._get_model_name(request.model, lora_request) model_name = self.models.model_name(lora_request)
num_prompts = len(engine_prompts) num_prompts = len(engine_prompts)
# Similar to the OpenAI API, when n != best_of, we do not stream the # Similar to the OpenAI API, when n != best_of, we do not stream the
...@@ -249,7 +246,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -249,7 +246,7 @@ class OpenAIServingCompletion(OpenAIServing):
if stream: if stream:
return self.completion_stream_generator( return self.completion_stream_generator(
request, request,
request_prompts, engine_prompts,
result_generator, result_generator,
request_id, request_id,
created_time, created_time,
...@@ -273,11 +270,9 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -273,11 +270,9 @@ class OpenAIServingCompletion(OpenAIServing):
# We did not pass it into vLLM engine to avoid being redundant # We did not pass it into vLLM engine to avoid being redundant
# with the inputs token IDs # with the inputs token IDs
if final_res.prompt is None: if final_res.prompt is None:
request_prompt = request_prompts[i] engine_prompt = engine_prompts[i]
if is_text_tokens_prompt(request_prompt): final_res.prompt = None if is_embeds_prompt(
final_res.prompt = request_prompt["prompt"] engine_prompt) else engine_prompt.get("prompt")
else:
final_res.prompt = None
final_res_batch_checked = cast(list[RequestOutput], final_res_batch_checked = cast(list[RequestOutput],
final_res_batch) final_res_batch)
...@@ -313,8 +308,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -313,8 +308,7 @@ class OpenAIServingCompletion(OpenAIServing):
async def completion_stream_generator( async def completion_stream_generator(
self, self,
request: CompletionRequest, request: CompletionRequest,
request_prompts: list[Union[TextTokensPrompt, engine_prompts: list[Union[TokensPrompt, EmbedsPrompt]],
ServingEngineEmbedsPrompt]],
result_generator: AsyncIterator[tuple[int, RequestOutput]], result_generator: AsyncIterator[tuple[int, RequestOutput]],
request_id: str, request_id: str,
created_time: int, created_time: int,
...@@ -350,14 +344,11 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -350,14 +344,11 @@ class OpenAIServingCompletion(OpenAIServing):
num_cached_tokens = res.num_cached_tokens num_cached_tokens = res.num_cached_tokens
first_iteration = False first_iteration = False
if res.prompt is not None: prompt_text = res.prompt
prompt_text = res.prompt if prompt_text is None:
else: engine_prompt = engine_prompts[prompt_idx]
request_prompt = request_prompts[prompt_idx] prompt_text = None if is_embeds_prompt(
if is_text_tokens_prompt(request_prompt): engine_prompt) else engine_prompt.get("prompt")
prompt_text = request_prompt["prompt"]
else:
prompt_text = None
# Prompt details are excluded from later streamed outputs # Prompt details are excluded from later streamed outputs
if prompt_token_ids is not None: if prompt_token_ids is not None:
...@@ -378,6 +369,8 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -378,6 +369,8 @@ class OpenAIServingCompletion(OpenAIServing):
assert request.max_tokens is not None assert request.max_tokens is not None
if request.echo and not has_echoed[i]: if request.echo and not has_echoed[i]:
assert prompt_token_ids is not None assert prompt_token_ids is not None
if request.return_token_ids:
prompt_text = ""
assert prompt_text is not None assert prompt_text is not None
if request.max_tokens == 0: if request.max_tokens == 0:
# only return the prompt # only return the prompt
...@@ -525,6 +518,8 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -525,6 +518,8 @@ class OpenAIServingCompletion(OpenAIServing):
for output in final_res.outputs: for output in final_res.outputs:
assert request.max_tokens is not None assert request.max_tokens is not None
if request.echo: if request.echo:
if request.return_token_ids:
prompt_text = ""
assert prompt_text is not None assert prompt_text is not None
if request.max_tokens == 0: if request.max_tokens == 0:
token_ids = prompt_token_ids token_ids = prompt_token_ids
...@@ -676,3 +671,18 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -676,3 +671,18 @@ class OpenAIServingCompletion(OpenAIServing):
tokens=out_tokens, tokens=out_tokens,
top_logprobs=out_top_logprobs, top_logprobs=out_top_logprobs,
) )
def _build_render_config(
self,
request: CompletionRequest,
max_input_length: Optional[int] = None,
) -> RenderConfig:
max_input_tokens_len = self.max_model_len - (request.max_tokens or 0)
return RenderConfig(
max_length=max_input_tokens_len,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
cache_salt=request.cache_salt,
needs_detokenization=bool(request.echo
and not request.return_token_ids),
)
...@@ -24,12 +24,11 @@ from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest, ...@@ -24,12 +24,11 @@ from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
ErrorResponse, UsageInfo) ErrorResponse, UsageInfo)
from vllm.entrypoints.openai.serving_engine import (EmbeddingServeContext, from vllm.entrypoints.openai.serving_engine import (EmbeddingServeContext,
OpenAIServing, OpenAIServing,
RequestPrompt,
ServeContext, ServeContext,
TextTokensPrompt) TextTokensPrompt)
# yapf: enable # yapf: enable
from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt from vllm.entrypoints.renderer import RenderConfig
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput, from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput,
...@@ -79,11 +78,12 @@ class EmbeddingMixin(OpenAIServing): ...@@ -79,11 +78,12 @@ class EmbeddingMixin(OpenAIServing):
tokenizer = await self.engine_client.get_tokenizer(ctx.lora_request tokenizer = await self.engine_client.get_tokenizer(ctx.lora_request
) )
renderer = self._get_renderer(tokenizer)
if isinstance(ctx.request, EmbeddingChatRequest): if isinstance(ctx.request, EmbeddingChatRequest):
( (
_, _,
ctx.request_prompts, _,
ctx.engine_prompts, ctx.engine_prompts,
) = await self._preprocess_chat( ) = await self._preprocess_chat(
ctx.request, ctx.request,
...@@ -93,25 +93,33 @@ class EmbeddingMixin(OpenAIServing): ...@@ -93,25 +93,33 @@ class EmbeddingMixin(OpenAIServing):
or ctx.chat_template, or ctx.chat_template,
chat_template_content_format=ctx. chat_template_content_format=ctx.
chat_template_content_format, chat_template_content_format,
# In embedding requests, we are not generating tokens, add_generation_prompt=ctx.request.add_generation_prompt,
# so there is no need to append extra tokens to the input
add_generation_prompt=False,
continue_final_message=False, continue_final_message=False,
add_special_tokens=ctx.request.add_special_tokens, add_special_tokens=ctx.request.add_special_tokens,
) )
else: else:
(ctx.request_prompts, ctx.engine_prompts = await renderer.render_prompt(
ctx.engine_prompts) = await self._preprocess_completion( prompt_or_prompts=ctx.request.input,
ctx.request, config=self._build_render_config(ctx.request),
tokenizer, )
ctx.request.input,
add_special_tokens=ctx.request.add_special_tokens,
)
return None return None
except (ValueError, TypeError) as e: except (ValueError, TypeError) as e:
logger.exception("Error in preprocessing prompt inputs") logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e)) return self.create_error_response(str(e))
def _build_render_config(
self, request: EmbeddingCompletionRequest) -> RenderConfig:
# Set max_length based on chunked processing capability
if self._should_use_chunked_processing(request):
max_length = None
else:
max_length = self.max_embed_len or self.max_model_len
return RenderConfig(
max_length=max_length,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens)
@override @override
def _build_response( def _build_response(
self, self,
...@@ -287,8 +295,7 @@ class EmbeddingMixin(OpenAIServing): ...@@ -287,8 +295,7 @@ class EmbeddingMixin(OpenAIServing):
async def _create_single_prompt_generator( async def _create_single_prompt_generator(
self, self,
ctx: EmbeddingServeContext, ctx: EmbeddingServeContext,
engine_prompt: Union[EngineTokensPrompt, EngineEmbedsPrompt], engine_prompt: EngineTokensPrompt,
request_prompt: RequestPrompt,
pooling_params: PoolingParams, pooling_params: PoolingParams,
trace_headers: Optional[Mapping[str, str]], trace_headers: Optional[Mapping[str, str]],
prompt_index: int, prompt_index: int,
...@@ -297,16 +304,10 @@ class EmbeddingMixin(OpenAIServing): ...@@ -297,16 +304,10 @@ class EmbeddingMixin(OpenAIServing):
request_id_item = f"{ctx.request_id}-{prompt_index}" request_id_item = f"{ctx.request_id}-{prompt_index}"
self._log_inputs(request_id_item, self._log_inputs(request_id_item,
request_prompt, engine_prompt,
params=pooling_params, params=pooling_params,
lora_request=ctx.lora_request) lora_request=ctx.lora_request)
# Mypy has an existing bug related to inferring the variance
# of TypedDicts with `builtins.enumerate`:
# https://github.com/python/mypy/issues/8586#issuecomment-2867698435
engine_prompt = cast(Union[EngineTokensPrompt, EngineEmbedsPrompt],
engine_prompt)
# Return the original generator without wrapping # Return the original generator without wrapping
return self.engine_client.encode( return self.engine_client.encode(
engine_prompt, engine_prompt,
...@@ -355,20 +356,14 @@ class EmbeddingMixin(OpenAIServing): ...@@ -355,20 +356,14 @@ class EmbeddingMixin(OpenAIServing):
return self.create_error_response( return self.create_error_response(
"Engine prompts not available") "Engine prompts not available")
if ctx.request_prompts is None:
return self.create_error_response(
"Request prompts not available")
max_pos_embeddings = self._get_max_position_embeddings() max_pos_embeddings = self._get_max_position_embeddings()
for i, engine_prompt in enumerate(ctx.engine_prompts): for i, engine_prompt in enumerate(ctx.engine_prompts):
request_prompt = ctx.request_prompts[i]
# Check if this specific prompt needs chunked processing # Check if this specific prompt needs chunked processing
if self._is_text_tokens_prompt(request_prompt): if self._is_text_tokens_prompt(engine_prompt):
# Cast to TextTokensPrompt since we've verified # Cast to TextTokensPrompt since we've verified
# prompt_token_ids # prompt_token_ids
text_tokens_prompt = cast(TextTokensPrompt, request_prompt) text_tokens_prompt = cast(TextTokensPrompt, engine_prompt)
if (len(text_tokens_prompt["prompt_token_ids"]) if (len(text_tokens_prompt["prompt_token_ids"])
> max_pos_embeddings): > max_pos_embeddings):
# Use chunked processing for this prompt # Use chunked processing for this prompt
...@@ -379,13 +374,8 @@ class EmbeddingMixin(OpenAIServing): ...@@ -379,13 +374,8 @@ class EmbeddingMixin(OpenAIServing):
continue continue
# Normal processing for short prompts or non-token prompts # Normal processing for short prompts or non-token prompts
# Cast engine_prompt to the expected type for mypy
engine_prompt_typed = cast(
Union[EngineTokensPrompt, EngineEmbedsPrompt],
engine_prompt)
generator = await self._create_single_prompt_generator( generator = await self._create_single_prompt_generator(
ctx, engine_prompt_typed, request_prompt, pooling_params, ctx, engine_prompt, pooling_params, trace_headers, i)
trace_headers, i)
generators.append(generator) generators.append(generator)
from vllm.utils import merge_async_iterators from vllm.utils import merge_async_iterators
...@@ -421,10 +411,6 @@ class EmbeddingMixin(OpenAIServing): ...@@ -421,10 +411,6 @@ class EmbeddingMixin(OpenAIServing):
if not use_chunked: if not use_chunked:
return await super()._collect_batch(ctx=ctx) return await super()._collect_batch(ctx=ctx)
if ctx.request_prompts is None:
return self.create_error_response(
"Request prompts not available")
if ctx.result_generator is None: if ctx.result_generator is None:
return self.create_error_response( return self.create_error_response(
"Result generator not available") "Result generator not available")
...@@ -540,7 +526,7 @@ class EmbeddingMixin(OpenAIServing): ...@@ -540,7 +526,7 @@ class EmbeddingMixin(OpenAIServing):
data=final_embedding) data=final_embedding)
# Get original prompt token IDs for this prompt # Get original prompt token IDs for this prompt
original_prompt = ctx.request_prompts[prompt_idx] original_prompt = ctx.engine_prompts[prompt_idx]
if not self._is_text_tokens_prompt(original_prompt): if not self._is_text_tokens_prompt(original_prompt):
return self.create_error_response( return self.create_error_response(
f"Chunked prompt {prompt_idx} is not a " f"Chunked prompt {prompt_idx} is not a "
...@@ -613,7 +599,7 @@ class OpenAIServingEmbedding(EmbeddingMixin): ...@@ -613,7 +599,7 @@ class OpenAIServingEmbedding(EmbeddingMixin):
See https://platform.openai.com/docs/api-reference/embeddings/create See https://platform.openai.com/docs/api-reference/embeddings/create
for the API specification. This API mimics the OpenAI Embedding API. for the API specification. This API mimics the OpenAI Embedding API.
""" """
model_name = self._get_model_name(request.model) model_name = self.models.model_name()
request_id = ( request_id = (
f"{self.request_id_prefix}-" f"{self.request_id_prefix}-"
f"{self._base_request_id(raw_request, request.request_id)}") f"{self._base_request_id(raw_request, request.request_id)}")
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import io
import json import json
import sys import sys
import time import time
...@@ -9,10 +7,8 @@ import traceback ...@@ -9,10 +7,8 @@ import traceback
from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from http import HTTPStatus from http import HTTPStatus
from typing import (Annotated, Any, Callable, ClassVar, Generic, Optional, from typing import Any, Callable, ClassVar, Generic, Optional, TypeVar, Union
TypeVar, Union, cast, overload)
import pybase64
import torch import torch
from fastapi import Request from fastapi import Request
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
...@@ -62,18 +58,19 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, ...@@ -62,18 +58,19 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
TranslationRequest) TranslationRequest)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParser from vllm.entrypoints.openai.tool_parsers import ToolParser
from vllm.entrypoints.renderer import (BaseRenderer, CompletionRenderer,
RenderConfig)
# yapf: enable # yapf: enable
from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt from vllm.inputs.data import PromptType
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob, PromptLogprobs
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import ( # noqa: F401 - Required to resolve Pydantic error in RequestProcessingMixin from vllm.multimodal import ( # noqa: F401 - Required to resolve Pydantic error in RequestProcessingMixin
MultiModalDataDict, MultiModalUUIDDict) MultiModalDataDict, MultiModalUUIDDict)
from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob, PromptLogprobs
from vllm.tracing import (contains_trace_headers, extract_trace_headers, from vllm.tracing import (contains_trace_headers, extract_trace_headers,
log_tracing_disabled_warning) log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
...@@ -82,16 +79,26 @@ from vllm.utils import (AsyncMicrobatchTokenizer, is_list_of, ...@@ -82,16 +79,26 @@ from vllm.utils import (AsyncMicrobatchTokenizer, is_list_of,
logger = init_logger(__name__) logger = init_logger(__name__)
CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest, CompletionLikeRequest = Union[
EmbeddingCompletionRequest, RerankRequest, CompletionRequest,
ClassificationRequest, ScoreRequest, DetokenizeRequest,
TokenizeCompletionRequest] EmbeddingCompletionRequest,
RerankRequest,
ClassificationRequest,
ScoreRequest,
TokenizeCompletionRequest,
]
ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest, ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest,
TokenizeChatRequest] TokenizeChatRequest]
SpeechToTextRequest = Union[TranscriptionRequest, TranslationRequest] SpeechToTextRequest = Union[TranscriptionRequest, TranslationRequest]
AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest, SpeechToTextRequest, AnyRequest = Union[
ResponsesRequest, IOProcessorRequest] CompletionLikeRequest,
ChatLikeRequest,
SpeechToTextRequest,
ResponsesRequest,
IOProcessorRequest,
]
AnyResponse = Union[ AnyResponse = Union[
CompletionResponse, CompletionResponse,
...@@ -135,9 +142,9 @@ class RequestProcessingMixin(BaseModel): ...@@ -135,9 +142,9 @@ class RequestProcessingMixin(BaseModel):
Mixin for request processing, Mixin for request processing,
handling prompt preparation and engine input. handling prompt preparation and engine input.
""" """
request_prompts: Optional[Sequence[RequestPrompt]] = [] request_prompts: Optional[Sequence[RequestPrompt]] = []
engine_prompts: Optional[Union[list[EngineTokensPrompt], engine_prompts: Optional[list[EngineTokensPrompt]] = []
list[EngineEmbedsPrompt]]] = []
model_config = ConfigDict(arbitrary_types_allowed=True) model_config = ConfigDict(arbitrary_types_allowed=True)
...@@ -147,6 +154,7 @@ class ResponseGenerationMixin(BaseModel): ...@@ -147,6 +154,7 @@ class ResponseGenerationMixin(BaseModel):
Mixin for response generation, Mixin for response generation,
managing result generators and final batch results. managing result generators and final batch results.
""" """
result_generator: Optional[AsyncGenerator[tuple[int, Union[ result_generator: Optional[AsyncGenerator[tuple[int, Union[
RequestOutput, PoolingRequestOutput]], None]] = None RequestOutput, PoolingRequestOutput]], None]] = None
final_res_batch: list[Union[RequestOutput, PoolingRequestOutput]] = Field( final_res_batch: list[Union[RequestOutput, PoolingRequestOutput]] = Field(
...@@ -155,8 +163,12 @@ class ResponseGenerationMixin(BaseModel): ...@@ -155,8 +163,12 @@ class ResponseGenerationMixin(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True) model_config = ConfigDict(arbitrary_types_allowed=True)
class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, BaseModel, class ServeContext(
Generic[RequestT]): RequestProcessingMixin,
ResponseGenerationMixin,
BaseModel,
Generic[RequestT],
):
# Shared across all requests # Shared across all requests
request: RequestT request: RequestT
raw_request: Optional[Request] = None raw_request: Optional[Request] = None
...@@ -227,6 +239,29 @@ class OpenAIServing: ...@@ -227,6 +239,29 @@ class OpenAIServing:
AsyncMicrobatchTokenizer] = {} AsyncMicrobatchTokenizer] = {}
self.log_error_stack = log_error_stack self.log_error_stack = log_error_stack
def _get_renderer(self, tokenizer: Optional[AnyTokenizer]) -> BaseRenderer:
"""
Get a Renderer instance with the provided tokenizer.
Uses shared async tokenizer pool for efficiency.
"""
return CompletionRenderer(
model_config=self.model_config,
tokenizer=tokenizer,
async_tokenizer_pool=self._async_tokenizer_pool)
def _build_render_config(
self,
request: Any,
) -> RenderConfig:
"""
Build and return a `RenderConfig` for an endpoint.
Used by the renderer to control how prompts are prepared
(e.g., tokenization and length handling). Endpoints should
implement this with logic appropriate to their request type.
"""
raise NotImplementedError
def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer: def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer:
""" """
Return (and cache) an `AsyncMicrobatchTokenizer` bound to the Return (and cache) an `AsyncMicrobatchTokenizer` bound to the
...@@ -298,8 +333,8 @@ class OpenAIServing: ...@@ -298,8 +333,8 @@ class OpenAIServing:
truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens", truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens",
None) None)
if truncate_prompt_tokens is not None and \ if (truncate_prompt_tokens is not None
truncate_prompt_tokens > self.max_model_len: and truncate_prompt_tokens > self.max_model_len):
return self.create_error_response( return self.create_error_response(
"truncate_prompt_tokens value is " "truncate_prompt_tokens value is "
"greater than max_model_len." "greater than max_model_len."
...@@ -340,21 +375,13 @@ class OpenAIServing: ...@@ -340,21 +375,13 @@ class OpenAIServing:
for i, engine_prompt in enumerate(ctx.engine_prompts): for i, engine_prompt in enumerate(ctx.engine_prompts):
request_id_item = f"{ctx.request_id}-{i}" request_id_item = f"{ctx.request_id}-{i}"
if ctx.request_prompts is None: self._log_inputs(
return self.create_error_response( request_id_item,
"Request prompts not available") engine_prompt,
params=pooling_params,
self._log_inputs(request_id_item, lora_request=ctx.lora_request,
ctx.request_prompts[i], )
params=pooling_params,
lora_request=ctx.lora_request)
# Mypy has an existing bug related to inferring the variance of
# TypedDicts with `builtins.enumerate`:
# https://github.com/python/mypy/issues/8586#issuecomment-2867698435
engine_prompt = cast(
Union[EngineTokensPrompt, EngineEmbedsPrompt],
engine_prompt)
generator = self.engine_client.encode( generator = self.engine_client.encode(
engine_prompt, engine_prompt,
pooling_params, pooling_params,
...@@ -410,10 +437,11 @@ class OpenAIServing: ...@@ -410,10 +437,11 @@ class OpenAIServing:
return self.create_error_response(str(e)) return self.create_error_response(str(e))
def create_error_response( def create_error_response(
self, self,
message: str, message: str,
err_type: str = "BadRequestError", err_type: str = "BadRequestError",
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse: status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
) -> ErrorResponse:
if self.log_error_stack: if self.log_error_stack:
exc_type, _, _ = sys.exc_info() exc_type, _, _ = sys.exc_info()
if exc_type is not None: if exc_type is not None:
...@@ -424,10 +452,11 @@ class OpenAIServing: ...@@ -424,10 +452,11 @@ class OpenAIServing:
message=message, type=err_type, code=status_code.value)) message=message, type=err_type, code=status_code.value))
def create_streaming_error_response( def create_streaming_error_response(
self, self,
message: str, message: str,
err_type: str = "BadRequestError", err_type: str = "BadRequestError",
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str: status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
) -> str:
json_str = json.dumps( json_str = json.dumps(
self.create_error_response(message=message, self.create_error_response(message=message,
err_type=err_type, err_type=err_type,
...@@ -438,25 +467,25 @@ class OpenAIServing: ...@@ -438,25 +467,25 @@ class OpenAIServing:
self, self,
request: AnyRequest, request: AnyRequest,
) -> Optional[ErrorResponse]: ) -> Optional[ErrorResponse]:
error_response = None error_response = None
if self._is_model_supported(request.model): if self._is_model_supported(request.model):
return None return None
if request.model in self.models.lora_requests: if request.model in self.models.lora_requests:
return None return None
if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING and request.model and ( if (envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING and request.model and
load_result := await self.models.resolve_lora(request.model)): (load_result := await self.models.resolve_lora(request.model))):
if isinstance(load_result, LoRARequest): if isinstance(load_result, LoRARequest):
return None return None
if isinstance(load_result, ErrorResponse) and \ if (isinstance(load_result, ErrorResponse) and
load_result.error.code == HTTPStatus.BAD_REQUEST.value: load_result.error.code == HTTPStatus.BAD_REQUEST.value):
error_response = load_result error_response = load_result
return error_response or self.create_error_response( return error_response or self.create_error_response(
message=f"The model `{request.model}` does not exist.", message=f"The model `{request.model}` does not exist.",
err_type="NotFoundError", err_type="NotFoundError",
status_code=HTTPStatus.NOT_FOUND) status_code=HTTPStatus.NOT_FOUND,
)
def _get_active_default_mm_loras( def _get_active_default_mm_loras(
self, request: AnyRequest) -> Optional[LoRARequest]: self, request: AnyRequest) -> Optional[LoRARequest]:
...@@ -487,7 +516,6 @@ class OpenAIServing: ...@@ -487,7 +516,6 @@ class OpenAIServing:
request: AnyRequest, request: AnyRequest,
supports_default_mm_loras: bool = False, supports_default_mm_loras: bool = False,
) -> Optional[LoRARequest]: ) -> Optional[LoRARequest]:
if request.model in self.models.lora_requests: if request.model in self.models.lora_requests:
return self.models.lora_requests[request.model] return self.models.lora_requests[request.model]
...@@ -548,13 +576,15 @@ class OpenAIServing: ...@@ -548,13 +576,15 @@ class OpenAIServing:
prompt, prompt,
add_special_tokens=add_special_tokens, add_special_tokens=add_special_tokens,
truncation=True, truncation=True,
max_length=self.max_model_len) max_length=self.max_model_len,
)
else: else:
encoded = await async_tokenizer( encoded = await async_tokenizer(
prompt, prompt,
add_special_tokens=add_special_tokens, add_special_tokens=add_special_tokens,
truncation=True, truncation=True,
max_length=truncate_prompt_tokens) max_length=truncate_prompt_tokens,
)
input_ids = encoded.input_ids input_ids = encoded.input_ids
input_text = prompt input_text = prompt
...@@ -595,16 +625,22 @@ class OpenAIServing: ...@@ -595,16 +625,22 @@ class OpenAIServing:
# Note: EmbeddingRequest, ClassificationRequest, # Note: EmbeddingRequest, ClassificationRequest,
# and ScoreRequest doesn't have max_tokens # and ScoreRequest doesn't have max_tokens
if isinstance(request, if isinstance(
(EmbeddingChatRequest, EmbeddingCompletionRequest, request,
ScoreRequest, RerankRequest, ClassificationRequest)): (
EmbeddingChatRequest,
EmbeddingCompletionRequest,
ScoreRequest,
RerankRequest,
ClassificationRequest,
),
):
# Note: input length can be up to the entire model context length # Note: input length can be up to the entire model context length
# since these requests don't generate tokens. # since these requests don't generate tokens.
if token_num > self.max_model_len: if token_num > self.max_model_len:
operations: dict[type[AnyRequest], str] = { operations: dict[type[AnyRequest], str] = {
ScoreRequest: "score", ScoreRequest: "score",
ClassificationRequest: "classification" ClassificationRequest: "classification",
} }
operation = operations.get(type(request), operation = operations.get(type(request),
"embedding generation") "embedding generation")
...@@ -618,8 +654,11 @@ class OpenAIServing: ...@@ -618,8 +654,11 @@ class OpenAIServing:
# Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens # Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
# and does not require model context length validation # and does not require model context length validation
if isinstance(request, (TokenizeCompletionRequest, TokenizeChatRequest, if isinstance(
DetokenizeRequest)): request,
(TokenizeCompletionRequest, TokenizeChatRequest,
DetokenizeRequest),
):
return TextTokensPrompt(prompt=input_text, return TextTokensPrompt(prompt=input_text,
prompt_token_ids=input_ids) prompt_token_ids=input_ids)
...@@ -639,8 +678,8 @@ class OpenAIServing: ...@@ -639,8 +678,8 @@ class OpenAIServing:
f"{token_num} input tokens. Please reduce the length of " f"{token_num} input tokens. Please reduce the length of "
"the input messages.") "the input messages.")
if max_tokens is not None and \ if (max_tokens is not None
token_num + max_tokens > self.max_model_len: and token_num + max_tokens > self.max_model_len):
raise ValueError( raise ValueError(
"'max_tokens' or 'max_completion_tokens' is too large: " "'max_tokens' or 'max_completion_tokens' is too large: "
f"{max_tokens}. This model's maximum context length is " f"{max_tokens}. This model's maximum context length is "
...@@ -698,156 +737,6 @@ class OpenAIServing: ...@@ -698,156 +737,6 @@ class OpenAIServing:
tokenizer=tokenizer, tokenizer=tokenizer,
) )
async def _tokenize_prompt_input_or_inputs_async(
self,
request: AnyRequest,
tokenizer: Optional[AnyTokenizer],
input_or_inputs: Optional[Union[str, list[str], list[int],
list[list[int]]]],
add_special_tokens: bool = True,
) -> tuple[list[TextTokensPrompt], list[EmbedsPrompt]]:
"""
Tokenize/detokenize depending on the input format.
According to `OpenAI API <https://platform.openai.com/docs/api-reference/embeddings/create>`_
, each input can be a string or array of tokens. Note that each request
can pass one or more inputs.
"""
inputs_embeds = list[EmbedsPrompt]()
inputs_text = list[TextTokensPrompt]()
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
None)
if (truncate_prompt_tokens or 0) < 0:
truncate_prompt_tokens = self.max_model_len
if (isinstance(request, CompletionRequest)
and request.prompt_embeds is not None):
inputs_embeds.extend(
self._load_prompt_embeds(request.prompt_embeds,
truncate_prompt_tokens))
# Empty prompts are okay as long as there are prompt embeddings
if input_or_inputs is None or (inputs_embeds
and input_or_inputs == ""):
return [], inputs_embeds
# Although our type checking is based on mypy,
# VSCode Pyright extension should still work properly
# "is False" is required for Pyright to perform type narrowing
# See: https://github.com/microsoft/pyright/issues/7672
# Parse and batch the input prompts
batch_inputs = parse_and_batch_prompt(input_or_inputs)
# Process each input in the batch concurrently
tasks = []
for prompt_input in batch_inputs:
if prompt_input["is_tokens"] is False:
assert tokenizer is not None, \
"Tokenizer is required for text prompts"
task = self._normalize_prompt_text_to_input(
request,
prompt_input["content"],
tokenizer=tokenizer,
add_special_tokens=add_special_tokens)
else:
task = self._normalize_prompt_tokens_to_input(
request, prompt_input["content"], tokenizer=tokenizer)
tasks.append(task)
# Wait for all tokenization tasks to complete
results = await asyncio.gather(*tasks)
inputs_text.extend(results)
return inputs_text, inputs_embeds
@overload
async def _preprocess_completion(
self,
request: Union[DetokenizeRequest, EmbeddingCompletionRequest,
RerankRequest, ClassificationRequest, ScoreRequest,
TokenizeCompletionRequest],
tokenizer: Optional[AnyTokenizer],
input_or_inputs: Union[str, list[str], list[int], list[list[int]]],
add_special_tokens: bool = ...,
) -> tuple[list[TextTokensPrompt], list[EngineTokensPrompt]]:
...
@overload
async def _preprocess_completion(
self,
request: CompletionRequest,
tokenizer: Optional[AnyTokenizer],
input_or_inputs: Optional[Union[str, list[str], list[int],
list[list[int]]]],
add_special_tokens: bool = ...,
) -> tuple[list[Union[TextTokensPrompt, EmbedsPrompt]], list[Union[
EngineTokensPrompt, EngineEmbedsPrompt]]]:
...
async def _preprocess_completion(
self,
request: CompletionLikeRequest,
tokenizer: Optional[AnyTokenizer],
input_or_inputs: Optional[Union[str, list[str], list[int],
list[list[int]]]],
add_special_tokens: bool = True,
) -> tuple[Union[list[TextTokensPrompt], list[Union[
TextTokensPrompt, EmbedsPrompt]]], Union[
list[EngineTokensPrompt], list[Union[EngineTokensPrompt,
EngineEmbedsPrompt]]]]:
if not isinstance(request,
CompletionRequest) and input_or_inputs is None:
raise ValueError(
"Prompt embeds with non-completion requests is not"
" currently supported.")
(request_prompts_text, request_prompts_embeds
) = await self._tokenize_prompt_input_or_inputs_async(
request,
tokenizer,
input_or_inputs,
add_special_tokens=add_special_tokens,
)
engine_prompts_text = [
EngineTokensPrompt(
prompt_token_ids=request_prompt_text["prompt_token_ids"])
for request_prompt_text in request_prompts_text
]
cache_salt = request.cache_salt if (
hasattr(request, "cache_salt")
and request.cache_salt is not None) else None
if cache_salt:
for prompt_text in engine_prompts_text:
prompt_text["cache_salt"] = cache_salt
# This check is equivalent to simply checking if
# `request_prompts_embeds` is empty, but it's difficult to propagate
# overloads to the private helper functions to enable this check.
# This overload is needed because only TextPrompts are allowed for
# non-completion requests and if we don't add the overload here,
# everywhere this function is used outside of serving_completion will
# need logic asserting that only text prompts are in the request.
if not isinstance(request,
CompletionRequest) and input_or_inputs is not None:
return request_prompts_text, engine_prompts_text
engine_prompts_embeds = [
EngineEmbedsPrompt(
prompt_embeds=request_prompt_embeds["prompt_embeds"])
for request_prompt_embeds in request_prompts_embeds
]
if cache_salt:
for prompt_embed in engine_prompts_embeds:
prompt_embed["cache_salt"] = cache_salt
request_prompts = request_prompts_embeds + request_prompts_text
engine_prompts = engine_prompts_embeds + engine_prompts_text
return request_prompts, engine_prompts
async def _preprocess_chat( async def _preprocess_chat(
self, self,
request: Union[ChatLikeRequest, ResponsesRequest], request: Union[ChatLikeRequest, ResponsesRequest],
...@@ -862,8 +751,11 @@ class OpenAIServing: ...@@ -862,8 +751,11 @@ class OpenAIServing:
chat_template_kwargs: Optional[dict[str, Any]] = None, chat_template_kwargs: Optional[dict[str, Any]] = None,
tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None, tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None,
add_special_tokens: bool = False, add_special_tokens: bool = False,
) -> tuple[list[ConversationMessage], Sequence[RequestPrompt], ) -> tuple[
list[EngineTokensPrompt]]: list[ConversationMessage],
Sequence[RequestPrompt],
list[EngineTokensPrompt],
]:
model_config = self.model_config model_config = self.model_config
resolved_content_format = resolve_chat_template_content_format( resolved_content_format = resolve_chat_template_content_format(
...@@ -873,7 +765,7 @@ class OpenAIServing: ...@@ -873,7 +765,7 @@ class OpenAIServing:
tokenizer, tokenizer,
model_config=model_config, model_config=model_config,
) )
conversation, mm_data_future = parse_chat_messages_futures( conversation, mm_data_future, mm_uuids = parse_chat_messages_futures(
messages, messages,
model_config, model_config,
tokenizer, tokenizer,
...@@ -925,8 +817,8 @@ class OpenAIServing: ...@@ -925,8 +817,8 @@ class OpenAIServing:
if tokenizer is None: if tokenizer is None:
assert isinstance(request_prompt, str), ( assert isinstance(request_prompt, str), (
"Prompt has to be a string", \ "Prompt has to be a string",
"when the tokenizer is not initialised" "when the tokenizer is not initialised",
) )
prompt_inputs = TextTokensPrompt(prompt=request_prompt, prompt_inputs = TextTokensPrompt(prompt=request_prompt,
prompt_token_ids=[1]) prompt_token_ids=[1])
...@@ -943,12 +835,17 @@ class OpenAIServing: ...@@ -943,12 +835,17 @@ class OpenAIServing:
"Prompt has to be either a string or a list of token ids") "Prompt has to be either a string or a list of token ids")
prompt_inputs = TextTokensPrompt( prompt_inputs = TextTokensPrompt(
prompt=tokenizer.decode(request_prompt), prompt=tokenizer.decode(request_prompt),
prompt_token_ids=request_prompt) prompt_token_ids=request_prompt,
)
engine_prompt = EngineTokensPrompt( engine_prompt = EngineTokensPrompt(
prompt_token_ids=prompt_inputs["prompt_token_ids"]) prompt_token_ids=prompt_inputs["prompt_token_ids"])
if mm_data is not None: if mm_data is not None:
engine_prompt["multi_modal_data"] = mm_data engine_prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
engine_prompt["multi_modal_uuids"] = mm_uuids
if request.mm_processor_kwargs is not None: if request.mm_processor_kwargs is not None:
engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs
...@@ -1007,49 +904,15 @@ class OpenAIServing: ...@@ -1007,49 +904,15 @@ class OpenAIServing:
prompt_token_ids=prompt_token_ids) prompt_token_ids=prompt_token_ids)
request_prompt = prompt_token_ids request_prompt = prompt_token_ids
# Update the sampling params. # Update the sampling params.
sampling_params.max_tokens = (self.max_model_len - sampling_params.max_tokens = self.max_model_len - len(
len(prompt_token_ids)) prompt_token_ids)
# OPTIMIZATION # OPTIMIZATION
priority = orig_priority - 1 priority = orig_priority - 1
@staticmethod
def _load_prompt_embeds(
prompt_embeds: Optional[Union[bytes, list[bytes]]],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
) -> list[EmbedsPrompt]:
def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt:
tensor = torch.load(io.BytesIO(
pybase64.b64decode(embed, validate=True)),
weights_only=True,
map_location=torch.device("cpu"))
assert isinstance(tensor, torch.Tensor) and tensor.dtype in (
torch.float32,
torch.bfloat16,
torch.float16,
)
tensor = tensor.to_dense()
if tensor.dim() > 2:
tensor = tensor.squeeze(0)
assert tensor.dim() == 2
if truncate_prompt_tokens is not None:
tensor = tensor[-truncate_prompt_tokens:]
return {"prompt_embeds": tensor}
if prompt_embeds:
if isinstance(prompt_embeds, list):
return [
_load_and_validate_embed(embed) for embed in prompt_embeds
]
else:
return [_load_and_validate_embed(prompt_embeds)]
else:
return []
def _log_inputs( def _log_inputs(
self, self,
request_id: str, request_id: str,
inputs: RequestPrompt, inputs: Union[RequestPrompt, PromptType],
params: Optional[Union[SamplingParams, PoolingParams, params: Optional[Union[SamplingParams, PoolingParams,
BeamSearchParams]], BeamSearchParams]],
lora_request: Optional[LoRARequest], lora_request: Optional[LoRARequest],
...@@ -1061,11 +924,9 @@ class OpenAIServing: ...@@ -1061,11 +924,9 @@ class OpenAIServing:
prompt = inputs prompt = inputs
elif isinstance(inputs, list): elif isinstance(inputs, list):
prompt_token_ids = inputs prompt_token_ids = inputs
elif 'prompt_embeds' in inputs:
prompt_embeds = inputs.get("prompt_embeds")
else: else:
prompt = inputs["prompt"] prompt = getattr(inputs, 'prompt', None)
prompt_token_ids = inputs["prompt_token_ids"] prompt_token_ids = getattr(inputs, 'prompt_token_ids', None)
self.request_logger.log_inputs( self.request_logger.log_inputs(
request_id, request_id,
...@@ -1101,10 +962,12 @@ class OpenAIServing: ...@@ -1101,10 +962,12 @@ class OpenAIServing:
return raw_request.headers.get("X-Request-Id", default) return raw_request.headers.get("X-Request-Id", default)
@staticmethod @staticmethod
def _get_decoded_token(logprob: Logprob, def _get_decoded_token(
token_id: int, logprob: Logprob,
tokenizer: AnyTokenizer, token_id: int,
return_as_token_id: bool = False) -> str: tokenizer: AnyTokenizer,
return_as_token_id: bool = False,
) -> str:
if return_as_token_id: if return_as_token_id:
return f"token_id:{token_id}" return f"token_id:{token_id}"
...@@ -1117,19 +980,10 @@ class OpenAIServing: ...@@ -1117,19 +980,10 @@ class OpenAIServing:
return True return True
return self.models.is_base_model(model_name) return self.models.is_base_model(model_name)
def _get_model_name(self,
model_name: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> str:
if lora_request:
return lora_request.lora_name
if not model_name:
return self.models.base_model_paths[0].name
return model_name
def clamp_prompt_logprobs( def clamp_prompt_logprobs(
prompt_logprobs: Union[PromptLogprobs, prompt_logprobs: Union[PromptLogprobs,
None]) -> Union[PromptLogprobs, None]: None], ) -> Union[PromptLogprobs, None]:
if prompt_logprobs is None: if prompt_logprobs is None:
return prompt_logprobs return prompt_logprobs
...@@ -1137,6 +991,6 @@ def clamp_prompt_logprobs( ...@@ -1137,6 +991,6 @@ def clamp_prompt_logprobs(
if logprob_dict is None: if logprob_dict is None:
continue continue
for logprob_values in logprob_dict.values(): for logprob_values in logprob_dict.values():
if logprob_values.logprob == float('-inf'): if logprob_values.logprob == float("-inf"):
logprob_values.logprob = -9999.0 logprob_values.logprob = -9999.0
return prompt_logprobs return prompt_logprobs
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import asyncio import asyncio
import base64 import base64
import time import time
from collections.abc import AsyncGenerator, Sequence from collections.abc import AsyncGenerator
from typing import Final, Literal, Optional, Union, cast from typing import Final, Literal, Optional, Union, cast
import jinja2 import jinja2
...@@ -26,8 +26,9 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse, ...@@ -26,8 +26,9 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse,
PoolingRequest, PoolingResponse, PoolingRequest, PoolingResponse,
PoolingResponseData, UsageInfo) PoolingResponseData, UsageInfo)
# yapf: enable # yapf: enable
from vllm.entrypoints.openai.serving_engine import OpenAIServing, RequestPrompt from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.renderer import RenderConfig
from vllm.entrypoints.utils import _validate_truncation_size from vllm.entrypoints.utils import _validate_truncation_size
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import PoolingOutput, PoolingRequestOutput from vllm.outputs import PoolingOutput, PoolingRequestOutput
...@@ -90,7 +91,7 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -90,7 +91,7 @@ class OpenAIServingPooling(OpenAIServing):
if error_check_ret is not None: if error_check_ret is not None:
return error_check_ret return error_check_ret
model_name = self._get_model_name(request.model) model_name = self.models.model_name()
request_id = f"pool-{self._base_request_id(raw_request)}" request_id = f"pool-{self._base_request_id(raw_request)}"
created_time = int(time.time()) created_time = int(time.time())
...@@ -104,6 +105,7 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -104,6 +105,7 @@ class OpenAIServingPooling(OpenAIServing):
else: else:
tokenizer = await self.engine_client.get_tokenizer(lora_request tokenizer = await self.engine_client.get_tokenizer(lora_request
) )
renderer = self._get_renderer(tokenizer)
if getattr(request, "dimensions", None) is not None: if getattr(request, "dimensions", None) is not None:
return self.create_error_response( return self.create_error_response(
...@@ -126,14 +128,11 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -126,14 +128,11 @@ class OpenAIServingPooling(OpenAIServing):
engine_prompts = await self.io_processor.pre_process_async( engine_prompts = await self.io_processor.pre_process_async(
prompt=validated_prompt, request_id=request_id) prompt=validated_prompt, request_id=request_id)
request_prompts: Sequence[RequestPrompt] = [
""
] * len(engine_prompts)
elif isinstance(request, PoolingChatRequest): elif isinstance(request, PoolingChatRequest):
( (
_, _,
request_prompts, _,
engine_prompts, engine_prompts,
) = await self._preprocess_chat( ) = await self._preprocess_chat(
request, request,
...@@ -149,13 +148,10 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -149,13 +148,10 @@ class OpenAIServingPooling(OpenAIServing):
add_special_tokens=request.add_special_tokens, add_special_tokens=request.add_special_tokens,
) )
elif isinstance(request, PoolingCompletionRequest): elif isinstance(request, PoolingCompletionRequest):
(request_prompts, engine_prompts = await renderer.render_prompt(
engine_prompts) = await self._preprocess_completion( prompt_or_prompts=request.input,
request, config=self._build_render_config(request),
tokenizer, )
request.input,
add_special_tokens=request.add_special_tokens,
)
else: else:
raise ValueError( raise ValueError(
f"Unsupported request of type {type(request)}") f"Unsupported request of type {type(request)}")
...@@ -177,7 +173,7 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -177,7 +173,7 @@ class OpenAIServingPooling(OpenAIServing):
request_id_item = f"{request_id}-{i}" request_id_item = f"{request_id}-{i}"
self._log_inputs(request_id_item, self._log_inputs(request_id_item,
request_prompts[i], engine_prompt,
params=pooling_params, params=pooling_params,
lora_request=lora_request) lora_request=lora_request)
...@@ -272,3 +268,10 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -272,3 +268,10 @@ class OpenAIServingPooling(OpenAIServing):
data=items, data=items,
usage=usage, usage=usage,
) )
def _build_render_config(
self, request: PoolingCompletionRequest) -> RenderConfig:
return RenderConfig(
max_length=self.max_model_len,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens)
...@@ -4,6 +4,8 @@ ...@@ -4,6 +4,8 @@
import asyncio import asyncio
import json import json
import time import time
import uuid
from collections import deque
from collections.abc import AsyncGenerator, AsyncIterator, Sequence from collections.abc import AsyncGenerator, AsyncIterator, Sequence
from contextlib import AsyncExitStack from contextlib import AsyncExitStack
from copy import copy from copy import copy
...@@ -24,7 +26,8 @@ from openai.types.responses import (ResponseCreatedEvent, ...@@ -24,7 +26,8 @@ from openai.types.responses import (ResponseCreatedEvent,
ResponseOutputMessage, ResponseOutputText, ResponseOutputMessage, ResponseOutputText,
ResponseReasoningItem, ResponseReasoningItem,
ResponseReasoningTextDeltaEvent, ResponseReasoningTextDeltaEvent,
ResponseReasoningTextDoneEvent) ResponseReasoningTextDoneEvent,
response_text_delta_event)
from openai.types.responses.response_output_text import (Logprob, from openai.types.responses.response_output_text import (Logprob,
LogprobTopLogprob) LogprobTopLogprob)
# yapf: enable # yapf: enable
...@@ -41,12 +44,13 @@ from vllm.entrypoints.context import (ConversationContext, HarmonyContext, ...@@ -41,12 +44,13 @@ from vllm.entrypoints.context import (ConversationContext, HarmonyContext,
SimpleContext, StreamingHarmonyContext) SimpleContext, StreamingHarmonyContext)
from vllm.entrypoints.harmony_utils import ( from vllm.entrypoints.harmony_utils import (
get_developer_message, get_stop_tokens_for_assistant_actions, get_developer_message, get_stop_tokens_for_assistant_actions,
get_system_message, get_user_message, parse_output_message, get_system_message, get_user_message, has_custom_tools,
parse_remaining_state, parse_response_input, render_for_completion) parse_output_message, parse_remaining_state, parse_response_input,
render_for_completion)
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
from vllm.entrypoints.openai.protocol import (ErrorResponse, from vllm.entrypoints.openai.protocol import (DeltaMessage, ErrorResponse,
InputTokensDetails, InputTokensDetails,
OutputTokensDetails, OutputTokensDetails,
RequestResponseMetadata, RequestResponseMetadata,
...@@ -55,14 +59,14 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse, ...@@ -55,14 +59,14 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse,
# yapf: enable # yapf: enable
from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.tool_server import MCPToolServer, ToolServer from vllm.entrypoints.tool_server import ToolServer
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob as SampleLogprob
from vllm.logprobs import SampleLogprobs
from vllm.outputs import CompletionOutput from vllm.outputs import CompletionOutput
from vllm.reasoning import ReasoningParser, ReasoningParserManager from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import Logprob as SampleLogprob
from vllm.sequence import SampleLogprobs
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid from vllm.utils import random_uuid
...@@ -168,6 +172,11 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -168,6 +172,11 @@ class OpenAIServingResponses(OpenAIServing):
# never remove messages from the store. # never remove messages from the store.
self.msg_store: dict[str, list[ChatCompletionMessageParam]] = {} self.msg_store: dict[str, list[ChatCompletionMessageParam]] = {}
# HACK(wuhang): This is a hack. We should use a better store.
# FIXME: If enable_store=True, this may cause a memory leak since we
# never remove events from the store.
self.event_store: dict[str, tuple[deque[str], asyncio.Event]] = {}
self.background_tasks: dict[str, asyncio.Task] = {} self.background_tasks: dict[str, asyncio.Task] = {}
self.tool_server = tool_server self.tool_server = tool_server
...@@ -228,7 +237,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -228,7 +237,7 @@ class OpenAIServingResponses(OpenAIServing):
try: try:
lora_request = self._maybe_get_adapters(request) lora_request = self._maybe_get_adapters(request)
model_name = self._get_model_name(request.model, lora_request) model_name = self.models.model_name(lora_request)
tokenizer = await self.engine_client.get_tokenizer(lora_request) tokenizer = await self.engine_client.get_tokenizer(lora_request)
if self.use_harmony: if self.use_harmony:
...@@ -249,15 +258,6 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -249,15 +258,6 @@ class OpenAIServingResponses(OpenAIServing):
if raw_request: if raw_request:
raw_request.state.request_metadata = request_metadata raw_request.state.request_metadata = request_metadata
if self.tool_server is not None and isinstance(
self.tool_server,
MCPToolServer) and request.stream and request.tools and any(
tool.type in ["web_search_preview", "code_interpreter"]
for tool in request.tools):
return self.create_error_response(
"MCP tool server is not supported in background mode and "
"streaming mode")
# Schedule the request and get the result generator. # Schedule the request and get the result generator.
generators: list[AsyncGenerator[ConversationContext, None]] = [] generators: list[AsyncGenerator[ConversationContext, None]] = []
...@@ -267,6 +267,8 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -267,6 +267,8 @@ class OpenAIServingResponses(OpenAIServing):
builtin_tool_list.append("browser") builtin_tool_list.append("browser")
if self.tool_server.has_tool("python"): if self.tool_server.has_tool("python"):
builtin_tool_list.append("python") builtin_tool_list.append("python")
if self.tool_server.has_tool("container"):
builtin_tool_list.append("container")
if self.tool_server is not None: if self.tool_server is not None:
available_tools = builtin_tool_list available_tools = builtin_tool_list
...@@ -329,25 +331,44 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -329,25 +331,44 @@ class OpenAIServingResponses(OpenAIServing):
self.response_store[response.id] = response self.response_store[response.id] = response
# Run the request in the background. # Run the request in the background.
task = asyncio.create_task( if request.stream:
self._run_background_request( task = asyncio.create_task(
request, self._run_background_request_stream(
sampling_params, request,
result_generator, sampling_params,
context, result_generator,
model_name, context,
tokenizer, model_name,
request_metadata, tokenizer,
created_time, request_metadata,
), created_time,
name=f"create_{response.id}", ),
) name=f"create_{request.request_id}",
)
else:
task = asyncio.create_task(
self._run_background_request(
request,
sampling_params,
result_generator,
context,
model_name,
tokenizer,
request_metadata,
created_time,
),
name=f"create_{response.id}",
)
# For cleanup. # For cleanup.
response_id = response.id response_id = response.id
self.background_tasks[response_id] = task self.background_tasks[response_id] = task
task.add_done_callback( task.add_done_callback(
lambda _: self.background_tasks.pop(response_id, None)) lambda _: self.background_tasks.pop(response_id, None))
if request.stream:
return self.responses_background_stream_generator(
request.request_id)
return response return response
if request.stream: if request.stream:
...@@ -430,7 +451,8 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -430,7 +451,8 @@ class OpenAIServingResponses(OpenAIServing):
async with AsyncExitStack() as exit_stack: async with AsyncExitStack() as exit_stack:
try: try:
await context.init_tool_sessions(self.tool_server, exit_stack) await context.init_tool_sessions(self.tool_server, exit_stack,
request.request_id)
async for _ in result_generator: async for _ in result_generator:
pass pass
except asyncio.CancelledError: except asyncio.CancelledError:
...@@ -442,11 +464,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -442,11 +464,7 @@ class OpenAIServingResponses(OpenAIServing):
if self.use_harmony: if self.use_harmony:
assert isinstance(context, HarmonyContext) assert isinstance(context, HarmonyContext)
output = self._make_response_output_items_with_harmony(context) output = self._make_response_output_items_with_harmony(context)
# TODO: these are all 0 for now! num_tool_output_tokens = context.num_tool_output_tokens
num_prompt_tokens = context.num_prompt_tokens
num_generated_tokens = context.num_output_tokens
num_cached_tokens = context.num_cached_tokens
num_reasoning_tokens = context.num_reasoning_tokens
else: else:
assert isinstance(context, SimpleContext) assert isinstance(context, SimpleContext)
final_res = context.last_output final_res = context.last_output
...@@ -459,10 +477,13 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -459,10 +477,13 @@ class OpenAIServingResponses(OpenAIServing):
# Calculate usage. # Calculate usage.
assert final_res.prompt_token_ids is not None assert final_res.prompt_token_ids is not None
num_prompt_tokens = len(final_res.prompt_token_ids) num_tool_output_tokens = 0
num_generated_tokens = len(final_output.token_ids)
num_cached_tokens = final_res.num_cached_tokens assert isinstance(context, (SimpleContext, HarmonyContext))
num_reasoning_tokens = 0 num_prompt_tokens = context.num_prompt_tokens
num_generated_tokens = context.num_output_tokens
num_cached_tokens = context.num_cached_tokens
num_reasoning_tokens = context.num_reasoning_tokens
usage = ResponseUsage( usage = ResponseUsage(
input_tokens=num_prompt_tokens, input_tokens=num_prompt_tokens,
...@@ -471,7 +492,8 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -471,7 +492,8 @@ class OpenAIServingResponses(OpenAIServing):
input_tokens_details=InputTokensDetails( input_tokens_details=InputTokensDetails(
cached_tokens=num_cached_tokens), cached_tokens=num_cached_tokens),
output_tokens_details=OutputTokensDetails( output_tokens_details=OutputTokensDetails(
reasoning_tokens=num_reasoning_tokens), reasoning_tokens=num_reasoning_tokens,
tool_output_tokens=num_tool_output_tokens),
) )
response = ResponsesResponse.from_request( response = ResponsesResponse.from_request(
request, request,
...@@ -537,6 +559,28 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -537,6 +559,28 @@ class OpenAIServingResponses(OpenAIServing):
)) ))
return out return out
def _create_stream_response_logprobs(
self,
token_ids: Sequence[int],
logprobs: Optional[SampleLogprobs],
tokenizer: AnyTokenizer,
top_logprobs: Optional[int] = None
) -> list[response_text_delta_event.Logprob]:
lgs = self._create_response_logprobs(token_ids=token_ids,
logprobs=logprobs,
tokenizer=tokenizer,
top_logprobs=top_logprobs)
return [
response_text_delta_event.Logprob(
token=lg.token,
logprob=lg.logprob,
top_logprobs=[
response_text_delta_event.LogprobTopLogprob(
token=tl.token, logprob=tl.logprob)
for tl in lg.top_logprobs
]) for lg in lgs
]
def _make_response_output_items( def _make_response_output_items(
self, self,
request: ResponsesRequest, request: ResponsesRequest,
...@@ -670,13 +714,21 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -670,13 +714,21 @@ class OpenAIServingResponses(OpenAIServing):
# New conversation. # New conversation.
reasoning_effort = (request.reasoning.effort reasoning_effort = (request.reasoning.effort
if request.reasoning else None) if request.reasoning else None)
# Temporary: OpenAI types doesn't have container tool
# so we used MCP to cover that, up for change
tool_types = [tool.type for tool in request.tools] tool_types = [tool.type for tool in request.tools]
if envs.VLLM_GPT_OSS_USE_CONTAINER_TOOL:
tool_types.append("container")
enable_browser = ("web_search_preview" in tool_types enable_browser = ("web_search_preview" in tool_types
and self.tool_server is not None and self.tool_server is not None
and self.tool_server.has_tool("browser")) and self.tool_server.has_tool("browser"))
enable_code_interpreter = ("code_interpreter" in tool_types enable_code_interpreter = ("code_interpreter" in tool_types
and self.tool_server is not None and self.tool_server is not None
and self.tool_server.has_tool("python")) and self.tool_server.has_tool("python"))
enable_container = ("container" in tool_types
and self.tool_server is not None
and self.tool_server.has_tool("container"))
with_custom_tools = has_custom_tools(tool_types)
sys_msg = get_system_message( sys_msg = get_system_message(
reasoning_effort=reasoning_effort, reasoning_effort=reasoning_effort,
browser_description=self.tool_server.get_tool_description( browser_description=self.tool_server.get_tool_description(
...@@ -685,11 +737,17 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -685,11 +737,17 @@ class OpenAIServingResponses(OpenAIServing):
python_description=self.tool_server.get_tool_description( python_description=self.tool_server.get_tool_description(
"python") if enable_code_interpreter "python") if enable_code_interpreter
and self.tool_server is not None else None, and self.tool_server is not None else None,
container_description=self.tool_server.get_tool_description(
"container")
if enable_container and self.tool_server is not None else None,
instructions=request.instructions,
with_custom_tools=with_custom_tools,
) )
messages.append(sys_msg) messages.append(sys_msg)
dev_msg = get_developer_message(request.instructions, if with_custom_tools:
request.tools) dev_msg = get_developer_message(
messages.append(dev_msg) instructions=request.instructions, tools=request.tools)
messages.append(dev_msg)
else: else:
# Continue the previous conversation. # Continue the previous conversation.
# FIXME(woosuk): Currently, request params like reasoning and # FIXME(woosuk): Currently, request params like reasoning and
...@@ -717,7 +775,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -717,7 +775,7 @@ class OpenAIServingResponses(OpenAIServing):
prev_msgs.append(msg) prev_msgs.append(msg)
messages.extend(prev_msgs) messages.extend(prev_msgs)
# Append the new input. # Append the new input.
# Reponses API supports simple text inputs without chat format. # Responses API supports simple text inputs without chat format.
if isinstance(request.input, str): if isinstance(request.input, str):
messages.append(get_user_message(request.input)) messages.append(get_user_message(request.input))
else: else:
...@@ -728,7 +786,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -728,7 +786,7 @@ class OpenAIServingResponses(OpenAIServing):
for response_msg in request.input: for response_msg in request.input:
messages.append( messages.append(
parse_response_input(response_msg, prev_outputs)) parse_response_input(response_msg, prev_outputs))
# User passes in a a tool call request and its output. We need # User passes in a tool call request and its output. We need
# to add the tool call request to prev_outputs so that the # to add the tool call request to prev_outputs so that the
# parse_response_input can find the tool call request when # parse_response_input can find the tool call request when
# parsing the tool call output. # parsing the tool call output.
...@@ -736,6 +794,40 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -736,6 +794,40 @@ class OpenAIServingResponses(OpenAIServing):
prev_outputs.append(response_msg) prev_outputs.append(response_msg)
return messages return messages
async def _run_background_request_stream(
self,
request: ResponsesRequest,
*args,
**kwargs,
):
event_deque: deque[str] = deque()
new_event_signal = asyncio.Event()
self.event_store[request.request_id] = (event_deque, new_event_signal)
response = None
try:
generator = self.responses_stream_generator(
request, *args, **kwargs)
async for event in generator:
event_deque.append(event)
new_event_signal.set() # Signal new event available
except Exception as e:
logger.exception("Background request failed for %s",
request.request_id)
response = self.create_error_response(str(e))
finally:
# Mark as finished with a special marker
event_deque.append("__STREAM_END__")
new_event_signal.set()
if response is not None and isinstance(response, ErrorResponse):
# If the request has failed, update the status to "failed".
response_id = request.request_id
async with self.response_store_lock:
stored_response = self.response_store.get(response_id)
assert stored_response is not None
if stored_response.status not in ("completed", "cancelled"):
stored_response.status = "failed"
async def _run_background_request( async def _run_background_request(
self, self,
request: ResponsesRequest, request: ResponsesRequest,
...@@ -759,9 +851,36 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -759,9 +851,36 @@ class OpenAIServingResponses(OpenAIServing):
if stored_response.status not in ("completed", "cancelled"): if stored_response.status not in ("completed", "cancelled"):
stored_response.status = "failed" stored_response.status = "failed"
async def responses_background_stream_generator(
self,
response_id: str,
starting_after: Optional[int] = None,
):
if response_id not in self.event_store:
raise ValueError(f"Unknown response_id: {response_id}")
event_deque, new_event_signal = self.event_store[response_id]
start_index = 0 if starting_after is None else starting_after + 1
current_index = start_index
while True:
new_event_signal.clear()
# Yield existing events from start_index
while current_index < len(event_deque):
event = event_deque[current_index]
if event == "__STREAM_END__":
return
yield event
current_index += 1
await new_event_signal.wait()
async def retrieve_responses( async def retrieve_responses(
self, self,
response_id: str, response_id: str,
starting_after: Optional[int],
stream: Optional[bool],
) -> Union[ErrorResponse, ResponsesResponse]: ) -> Union[ErrorResponse, ResponsesResponse]:
if not response_id.startswith("resp_"): if not response_id.startswith("resp_"):
return self._make_invalid_id_error(response_id) return self._make_invalid_id_error(response_id)
...@@ -771,6 +890,12 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -771,6 +890,12 @@ class OpenAIServingResponses(OpenAIServing):
if response is None: if response is None:
return self._make_not_found_error(response_id) return self._make_not_found_error(response_id)
if stream:
return self.responses_background_stream_generator(
response_id,
starting_after,
)
return response return response
async def cancel_responses( async def cancel_responses(
...@@ -829,7 +954,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -829,7 +954,7 @@ class OpenAIServingResponses(OpenAIServing):
status_code=HTTPStatus.BAD_REQUEST, status_code=HTTPStatus.BAD_REQUEST,
) )
async def _process_streaming_events( async def _process_simple_streaming_events(
self, self,
request: ResponsesRequest, request: ResponsesRequest,
sampling_params: SamplingParams, sampling_params: SamplingParams,
...@@ -839,47 +964,292 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -839,47 +964,292 @@ class OpenAIServingResponses(OpenAIServing):
tokenizer: AnyTokenizer, tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata, request_metadata: RequestResponseMetadata,
created_time: int, created_time: int,
_send_event: Callable[[BaseModel], str],
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
sequence_number = 0 current_content_index = 0
current_output_index = 0
def _send_event(event: BaseModel): current_item_id = ""
nonlocal sequence_number reasoning_parser = None
# Set sequence_number if the event has this attribute if self.reasoning_parser:
if hasattr(event, 'sequence_number'): reasoning_parser = self.reasoning_parser(tokenizer)
event.sequence_number = sequence_number previous_text = ""
sequence_number += 1 previous_token_ids: list[int] = []
# Get event type from the event's type field if it exists first_delta_sent = False
event_type = getattr(event, 'type', 'unknown') previous_delta_messages: list[DeltaMessage] = []
return (f"event: {event_type}\n" async for ctx in result_generator:
f"data: {event.model_dump_json(indent=None)}\n\n") assert isinstance(ctx, SimpleContext)
if ctx.last_output is None:
continue
if ctx.last_output.outputs:
output = ctx.last_output.outputs[0]
if reasoning_parser:
delta_message = \
reasoning_parser.extract_reasoning_content_streaming(
previous_text=previous_text,
current_text=previous_text + output.text,
delta_text=output.text,
previous_token_ids=previous_token_ids,
current_token_ids=previous_token_ids +
output.token_ids,
delta_token_ids=output.token_ids,
)
else:
delta_message = DeltaMessage(content=output.text, )
previous_text += output.text
previous_token_ids += output.token_ids
if not delta_message:
continue
if not first_delta_sent:
current_item_id = str(uuid.uuid4())
if delta_message.reasoning_content:
yield _send_event(
openai_responses_types.
ResponseOutputItemAddedEvent(
type="response.output_item.added",
sequence_number=-1,
output_index=current_output_index,
item=openai_responses_types.
ResponseReasoningItem(
type="reasoning",
id=current_item_id,
summary=[],
status="in_progress",
),
))
else:
yield _send_event(
openai_responses_types.
ResponseOutputItemAddedEvent(
type="response.output_item.added",
sequence_number=-1,
output_index=current_output_index,
item=openai_responses_types.
ResponseOutputMessage(
id=current_item_id,
type="message",
role="assistant",
content=[],
status="in_progress",
),
))
yield _send_event(
openai_responses_types.ResponseContentPartAddedEvent(
type="response.content_part.added",
sequence_number=-1,
output_index=current_output_index,
item_id=current_item_id,
content_index=current_content_index,
part=openai_responses_types.ResponseOutputText(
type="output_text",
text="",
annotations=[],
logprobs=[],
),
))
current_content_index += 1
first_delta_sent = True
# todo(kebe7jun) tool call support
# check delta message and previous delta message are
# same as content or reasoning content
if (previous_delta_messages
and previous_delta_messages[-1].reasoning_content
is not None and delta_message.content is not None):
# from reasoning to normal content, send done
# event for reasoning
reason_content = ''.join(
pm.reasoning_content for pm in previous_delta_messages
if pm.reasoning_content is not None)
yield _send_event(
ResponseReasoningTextDoneEvent(
type="response.reasoning_text.done",
item_id=current_item_id,
sequence_number=-1,
output_index=current_output_index,
content_index=current_content_index,
text=reason_content,
))
current_content_index = 0
reasoning_item = ResponseReasoningItem(
type="reasoning",
content=[
ResponseReasoningTextContent(
text=reason_content,
type="reasoning_text",
),
],
status="completed",
id=current_item_id,
summary=[],
)
yield _send_event(
ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
output_index=current_output_index,
item=reasoning_item,
))
yield _send_event(
openai_responses_types.ResponseOutputItemAddedEvent(
type="response.output_item.added",
sequence_number=-1,
output_index=current_output_index,
item=openai_responses_types.ResponseOutputMessage(
id=current_item_id,
type="message",
role="assistant",
content=[],
status="in_progress",
),
))
current_output_index += 1
current_item_id = str(uuid.uuid4())
yield _send_event(
openai_responses_types.ResponseContentPartAddedEvent(
type="response.content_part.added",
sequence_number=-1,
output_index=current_output_index,
item_id=current_item_id,
content_index=current_content_index,
part=openai_responses_types.ResponseOutputText(
type="output_text",
text="",
annotations=[],
logprobs=[],
),
))
current_content_index += 1
# reset previous delta messages
previous_delta_messages = []
if delta_message.reasoning_content is not None:
yield _send_event(
ResponseReasoningTextDeltaEvent(
type="response.reasoning_text.delta",
sequence_number=-1,
content_index=current_content_index,
output_index=current_output_index,
item_id=current_item_id,
delta=delta_message.reasoning_content,
))
elif delta_message.content is not None:
yield _send_event(
openai_responses_types.ResponseTextDeltaEvent(
type="response.output_text.delta",
sequence_number=-1,
content_index=current_content_index,
output_index=current_output_index,
item_id=current_item_id,
delta=delta_message.content,
logprobs=self._create_stream_response_logprobs(
token_ids=output.token_ids,
logprobs=output.logprobs,
tokenizer=tokenizer,
top_logprobs=request.top_logprobs,
) if request.is_include_output_logprobs() else [],
))
current_content_index += 1
previous_delta_messages.append(delta_message)
if previous_delta_messages:
if previous_delta_messages[-1].reasoning_content is not None:
reason_content = ''.join(pm.reasoning_content
for pm in previous_delta_messages
if pm.reasoning_content is not None)
yield _send_event(
ResponseReasoningTextDoneEvent(
type="response.reasoning_text.done",
item_id=current_item_id,
sequence_number=-1,
output_index=current_output_index,
content_index=current_content_index,
text=reason_content,
))
current_content_index += 1
reasoning_item = ResponseReasoningItem(
type="reasoning",
content=[
ResponseReasoningTextContent(
text=reason_content,
type="reasoning_text",
),
],
status="completed",
id=current_item_id,
summary=[],
)
yield _send_event(
ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
output_index=current_output_index,
item=reasoning_item,
))
elif previous_delta_messages[-1].content is not None:
final_content = ''.join(pm.content
for pm in previous_delta_messages
if pm.content is not None)
yield _send_event(
openai_responses_types.ResponseTextDoneEvent(
type="response.output_text.done",
sequence_number=-1,
output_index=current_output_index,
content_index=current_content_index,
text=final_content,
logprobs=[],
item_id=current_item_id,
))
current_content_index += 1
part = ResponseOutputText(
text=final_content,
type="output_text",
annotations=[],
)
yield _send_event(
openai_responses_types.ResponseContentPartDoneEvent(
type="response.content_part.done",
sequence_number=-1,
item_id=current_item_id,
output_index=current_output_index,
content_index=current_content_index,
part=part,
))
current_content_index += 1
item = ResponseOutputMessage(
type="message",
role="assistant",
content=[
part,
],
status="completed",
id=current_item_id,
summary=[],
)
yield _send_event(
ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
output_index=current_output_index,
item=item,
))
async def _process_harmony_streaming_events(
self,
request: ResponsesRequest,
sampling_params: SamplingParams,
result_generator: AsyncIterator[Optional[ConversationContext]],
context: ConversationContext,
model_name: str,
tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
created_time: int,
_send_event: Callable[[BaseModel], str],
) -> AsyncGenerator[str, None]:
current_content_index = 0 # FIXME: this number is never changed current_content_index = 0 # FIXME: this number is never changed
current_output_index = 0 current_output_index = 0
current_item_id = "" # FIXME: this number is never changed current_item_id = "" # FIXME: this number is never changed
sent_output_item_added = False sent_output_item_added = False
initial_response = ResponsesResponse.from_request(
request,
sampling_params,
model_name=model_name,
created_time=created_time,
output=[],
status="in_progress",
usage=None,
).model_dump()
yield _send_event(
ResponseCreatedEvent(
type="response.created",
sequence_number=-1,
response=initial_response,
))
yield _send_event(
ResponseInProgressEvent(
type="response.in_progress",
sequence_number=-1,
response=initial_response,
))
async for ctx in result_generator: async for ctx in result_generator:
assert isinstance(ctx, StreamingHarmonyContext) assert isinstance(ctx, StreamingHarmonyContext)
...@@ -1229,29 +1599,6 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -1229,29 +1599,6 @@ class OpenAIServingResponses(OpenAIServing):
), ),
)) ))
async def empty_async_generator():
# A hack to trick Python to think this is a generator but in fact
# it immediately returns.
if False:
yield
final_response = await self.responses_full_generator(
request,
sampling_params,
empty_async_generator(),
context,
model_name,
tokenizer,
request_metadata,
created_time=created_time,
)
yield _send_event(
openai_responses_types.ResponseCompletedEvent(
type="response.completed",
sequence_number=-1,
response=final_response.model_dump(),
))
async def responses_stream_generator( async def responses_stream_generator(
self, self,
request: ResponsesRequest, request: ResponsesRequest,
...@@ -1266,16 +1613,78 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -1266,16 +1613,78 @@ class OpenAIServingResponses(OpenAIServing):
# TODO: # TODO:
# 1. Handle disconnect # 1. Handle disconnect
if not isinstance(context, StreamingHarmonyContext):
raise NotImplementedError(
"Streaming is not supported for responses API without Harmony."
)
created_time = created_time or int(time.time()) created_time = created_time or int(time.time())
sequence_number = 0
def _send_event(event: BaseModel):
nonlocal sequence_number
# Set sequence_number if the event has this attribute
if hasattr(event, 'sequence_number'):
event.sequence_number = sequence_number
sequence_number += 1
# Get event type from the event's type field if it exists
event_type = getattr(event, 'type', 'unknown')
return (f"event: {event_type}\n"
f"data: {event.model_dump_json(indent=None)}\n\n")
async with AsyncExitStack() as exit_stack: async with AsyncExitStack() as exit_stack:
await context.init_tool_sessions(self.tool_server, exit_stack) processer = None
async for event_data in self._process_streaming_events( if self.use_harmony:
request, sampling_params, result_generator, context, await context.init_tool_sessions(self.tool_server, exit_stack,
model_name, tokenizer, request_metadata, created_time): request.request_id)
processer = self._process_harmony_streaming_events
else:
processer = self._process_simple_streaming_events
initial_response = ResponsesResponse.from_request(
request,
sampling_params,
model_name=model_name,
created_time=created_time,
output=[],
status="in_progress",
usage=None,
).model_dump()
yield _send_event(
ResponseCreatedEvent(
type="response.created",
sequence_number=-1,
response=initial_response,
))
yield _send_event(
ResponseInProgressEvent(
type="response.in_progress",
sequence_number=-1,
response=initial_response,
))
async for event_data in processer(request, sampling_params,
result_generator, context,
model_name, tokenizer,
request_metadata, created_time,
_send_event):
yield event_data yield event_data
async def empty_async_generator():
# A hack to trick Python to think this is a generator but
# in fact it immediately returns.
if False:
yield
final_response = await self.responses_full_generator(
request,
sampling_params,
empty_async_generator(),
context,
model_name,
tokenizer,
request_metadata,
created_time=created_time,
)
yield _send_event(
openai_responses_types.ResponseCompletedEvent(
type="response.completed",
sequence_number=-1,
response=final_response.model_dump(),
))
...@@ -353,7 +353,7 @@ class ServingScores(OpenAIServing): ...@@ -353,7 +353,7 @@ class ServingScores(OpenAIServing):
final_res_batch, final_res_batch,
request_id, request_id,
created_time, created_time,
self._get_model_name(request.model), self.models.model_name(),
) )
except asyncio.CancelledError: except asyncio.CancelledError:
return self.create_error_response("Client disconnected") return self.create_error_response("Client disconnected")
...@@ -399,7 +399,7 @@ class ServingScores(OpenAIServing): ...@@ -399,7 +399,7 @@ class ServingScores(OpenAIServing):
return self.request_output_to_rerank_response( return self.request_output_to_rerank_response(
final_res_batch, final_res_batch,
request_id, request_id,
self._get_model_name(request.model), self.models.model_name(),
documents, documents,
top_n, top_n,
) )
......
...@@ -22,6 +22,7 @@ from vllm.entrypoints.openai.protocol import (DetokenizeRequest, ...@@ -22,6 +22,7 @@ from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
# yapf: enable # yapf: enable
from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.renderer import RenderConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
...@@ -65,13 +66,14 @@ class OpenAIServingTokenization(OpenAIServing): ...@@ -65,13 +66,14 @@ class OpenAIServingTokenization(OpenAIServing):
lora_request = self._maybe_get_adapters(request) lora_request = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer(lora_request) tokenizer = await self.engine_client.get_tokenizer(lora_request)
renderer = self._get_renderer(tokenizer)
if isinstance(request, TokenizeChatRequest): if isinstance(request, TokenizeChatRequest):
tool_dicts = (None if request.tools is None else tool_dicts = (None if request.tools is None else
[tool.model_dump() for tool in request.tools]) [tool.model_dump() for tool in request.tools])
( (
_, _,
request_prompts, _,
engine_prompts, engine_prompts,
) = await self._preprocess_chat( ) = await self._preprocess_chat(
request, request,
...@@ -87,21 +89,18 @@ class OpenAIServingTokenization(OpenAIServing): ...@@ -87,21 +89,18 @@ class OpenAIServingTokenization(OpenAIServing):
add_special_tokens=request.add_special_tokens, add_special_tokens=request.add_special_tokens,
) )
else: else:
(request_prompts, engine_prompts = await renderer.render_prompt(
engine_prompts) = await self._preprocess_completion( prompt_or_prompts=request.prompt,
request, config=self._build_render_config(request),
tokenizer, )
request.prompt,
add_special_tokens=request.add_special_tokens,
)
except (ValueError, TypeError, jinja2.TemplateError) as e: except (ValueError, TypeError, jinja2.TemplateError) as e:
logger.exception("Error in preprocessing prompt inputs") logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(f"{e} {e.__cause__}") return self.create_error_response(f"{e} {e.__cause__}")
input_ids: list[int] = [] input_ids: list[int] = []
for i, engine_prompt in enumerate(engine_prompts): for engine_prompt in engine_prompts:
self._log_inputs(request_id, self._log_inputs(request_id,
request_prompts[i], engine_prompt,
params=None, params=None,
lora_request=lora_request) lora_request=lora_request)
...@@ -158,6 +157,9 @@ class OpenAIServingTokenization(OpenAIServing): ...@@ -158,6 +157,9 @@ class OpenAIServingTokenization(OpenAIServing):
return self.create_error_response( return self.create_error_response(
f"Failed to get tokenizer info: {str(e)}") f"Failed to get tokenizer info: {str(e)}")
def _build_render_config(self, request: TokenizeRequest) -> RenderConfig:
return RenderConfig(add_special_tokens=request.add_special_tokens)
@dataclass @dataclass
class TokenizerInfo: class TokenizerInfo:
......
...@@ -89,6 +89,9 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -89,6 +89,9 @@ class OpenAISpeechToText(OpenAIServing):
) -> tuple[list[PromptType], float]: ) -> tuple[list[PromptType], float]:
# Validate request # Validate request
language = self.model_cls.validate_language(request.language) language = self.model_cls.validate_language(request.language)
# Skip to_language validation to avoid extra logging for Whisper.
to_language = self.model_cls.validate_language(request.to_language) \
if request.to_language else None
if len(audio_data) / 1024**2 > self.max_audio_filesize_mb: if len(audio_data) / 1024**2 > self.max_audio_filesize_mb:
raise ValueError("Maximum file size exceeded.") raise ValueError("Maximum file size exceeded.")
...@@ -112,7 +115,9 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -112,7 +115,9 @@ class OpenAISpeechToText(OpenAIServing):
model_config=self.model_config, model_config=self.model_config,
language=language, language=language,
task_type=self.task_type, task_type=self.task_type,
request_prompt=request.prompt) request_prompt=request.prompt,
to_language=to_language,
)
prompts.append(prompt) prompts.append(prompt)
return prompts, duration return prompts, duration
......
...@@ -16,6 +16,7 @@ from .llama4_pythonic_tool_parser import Llama4PythonicToolParser ...@@ -16,6 +16,7 @@ from .llama4_pythonic_tool_parser import Llama4PythonicToolParser
from .llama_tool_parser import Llama3JsonToolParser from .llama_tool_parser import Llama3JsonToolParser
from .minimax_tool_parser import MinimaxToolParser from .minimax_tool_parser import MinimaxToolParser
from .mistral_tool_parser import MistralToolParser from .mistral_tool_parser import MistralToolParser
from .openai_tool_parser import OpenAIToolParser
from .phi4mini_tool_parser import Phi4MiniJsonToolParser from .phi4mini_tool_parser import Phi4MiniJsonToolParser
from .pythonic_tool_parser import PythonicToolParser from .pythonic_tool_parser import PythonicToolParser
from .qwen3coder_tool_parser import Qwen3CoderToolParser from .qwen3coder_tool_parser import Qwen3CoderToolParser
...@@ -46,4 +47,5 @@ __all__ = [ ...@@ -46,4 +47,5 @@ __all__ = [
"Qwen3CoderToolParser", "Qwen3CoderToolParser",
"SeedOssToolParser", "SeedOssToolParser",
"Step3ToolParser", "Step3ToolParser",
"OpenAIToolParser",
] ]
...@@ -35,7 +35,7 @@ class Internlm2ToolParser(ToolParser): ...@@ -35,7 +35,7 @@ class Internlm2ToolParser(ToolParser):
self, request: ChatCompletionRequest) -> ChatCompletionRequest: self, request: ChatCompletionRequest) -> ChatCompletionRequest:
if request.tools and request.tool_choice != 'none': if request.tools and request.tool_choice != 'none':
# do not skip special tokens because internlm use the special # do not skip special tokens because internlm use the special
# tokens to indicated the start and end of the tool calls # tokens to indicate the start and end of the tool calls
# information. # information.
request.skip_special_tokens = False request.skip_special_tokens = False
return request return request
...@@ -60,8 +60,8 @@ class Internlm2ToolParser(ToolParser): ...@@ -60,8 +60,8 @@ class Internlm2ToolParser(ToolParser):
if '<|action_start|>' not in current_text: if '<|action_start|>' not in current_text:
self.position = len(current_text) self.position = len(current_text)
return DeltaMessage(content=delta_text) return DeltaMessage(content=delta_text)
# if the tool call is sended, return a empty delta message # if the tool call is sent, return an empty delta message
# to make sure the finish_reason will be send correctly. # to make sure the finish_reason will be sent correctly.
if self.current_tool_id > 0: if self.current_tool_id > 0:
return DeltaMessage(content='') return DeltaMessage(content='')
...@@ -89,7 +89,7 @@ class Internlm2ToolParser(ToolParser): ...@@ -89,7 +89,7 @@ class Internlm2ToolParser(ToolParser):
try: try:
parsable_arr = action parsable_arr = action
# tool calls are generated in an object in inernlm2 # tool calls are generated in an object in internlm2
# it's not support parallel tool calls # it's not support parallel tool calls
try: try:
tool_call_arr: dict = partial_json_parser.loads( tool_call_arr: dict = partial_json_parser.loads(
......
...@@ -176,7 +176,7 @@ class Llama4PythonicToolParser(ToolParser): ...@@ -176,7 +176,7 @@ class Llama4PythonicToolParser(ToolParser):
index] += delta.function.arguments index] += delta.function.arguments
# HACK: serving_chat.py inspects the internal state of tool parsers # HACK: serving_chat.py inspects the internal state of tool parsers
# when determining it's final streaming delta, automatically # when determining its final streaming delta, automatically
# adding autocompleted JSON. # adding autocompleted JSON.
# These two lines avoid that nonsense while ensuring finish_reason # These two lines avoid that nonsense while ensuring finish_reason
# is set to tool_calls when at least one tool is called. # is set to tool_calls when at least one tool is called.
......
...@@ -143,7 +143,7 @@ class MistralToolParser(ToolParser): ...@@ -143,7 +143,7 @@ class MistralToolParser(ToolParser):
except json.JSONDecodeError: except json.JSONDecodeError:
# use a regex to find the part corresponding to the tool call. # use a regex to find the part corresponding to the tool call.
# NOTE: This use case should not happen if the model is trained # NOTE: This use case should not happen if the model is trained
# correctly. It's a easy possible fix so it's included, but # correctly. It's an easy possible fix so it's included, but
# can be brittle for very complex / highly nested tool calls # can be brittle for very complex / highly nested tool calls
raw_tool_call = self.tool_call_regex.findall(tool_content)[0] raw_tool_call = self.tool_call_regex.findall(tool_content)[0]
function_call_arr = json.loads(raw_tool_call) function_call_arr = json.loads(raw_tool_call)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from collections.abc import Sequence
from typing import TYPE_CHECKING
from vllm.entrypoints.harmony_utils import parse_output_into_messages
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParserManager)
if TYPE_CHECKING:
from vllm.transformers_utils.tokenizer import AnyTokenizer
@ToolParserManager.register_module("openai")
class OpenAIToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
super().__init__(tokenizer)
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
token_ids: Sequence[int] | None = None,
) -> ExtractedToolCallInformation:
if token_ids is None:
raise NotImplementedError(
"OpenAIToolParser requires token IDs and does not support text-based extraction." # noqa: E501
)
parser = parse_output_into_messages(token_ids)
tool_calls = []
final_content = None
if len(parser.messages) > 0:
for msg in parser.messages:
if msg.recipient and msg.recipient.startswith("functions."):
tool_calls.append(
ToolCall(
type="function",
function=FunctionCall(
name=msg.recipient.split("functions.")[1],
arguments=msg.content[0].text,
),
))
elif msg.channel == "final":
final_content = msg.content[0].text
return ExtractedToolCallInformation(
tools_called=len(tool_calls) > 0,
tool_calls=tool_calls,
content=final_content,
)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
raise NotImplementedError(
"Not being used, manual parsing in serving_chat.py" # noqa: E501
)
...@@ -165,7 +165,7 @@ class PythonicToolParser(ToolParser): ...@@ -165,7 +165,7 @@ class PythonicToolParser(ToolParser):
index] += delta.function.arguments index] += delta.function.arguments
# HACK: serving_chat.py inspects the internal state of tool parsers # HACK: serving_chat.py inspects the internal state of tool parsers
# when determining it's final streaming delta, automatically # when determining its final streaming delta, automatically
# adding autocompleted JSON. # adding autocompleted JSON.
# These two lines avoid that nonsense while ensuring finish_reason # These two lines avoid that nonsense while ensuring finish_reason
# is set to tool_calls when at least one tool is called. # is set to tool_calls when at least one tool is called.
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import io
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Annotated, Optional, Union
import pybase64
import torch
from pydantic import Field
from vllm.config import ModelConfig
from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import AsyncMicrobatchTokenizer
@dataclass(frozen=True)
class RenderConfig:
"""Configuration to control how prompts are prepared."""
max_length: Optional[int] = None
"""Maximum allowable total input token length. If provided,
token inputs longer than this raise ``ValueError``."""
truncate_prompt_tokens: Optional[int] = None
"""Number of tokens to keep. ``None`` means no truncation.
``0`` yields an empty list (and skips embeds).
``-1`` maps to ``model_config.max_model_len``."""
add_special_tokens: Optional[bool] = True
"""Whether to add model-specific special tokens during tokenization."""
cache_salt: Optional[str] = None
"""String to disambiguate prefix cache entries."""
needs_detokenization: Optional[bool] = False
"""If True, detokenize IDs back to text for inclusion in outputs."""
class BaseRenderer(ABC):
"""
Base class for unified input processing and rendering.
The Renderer serves as a unified input processor that consolidates
tokenization, chat template formatting, and multimodal input handling
into a single component.
It converts high-level API requests (OpenAI-style JSON) into token IDs and
multimodal features ready for engine consumption.
Key responsibilities:
- Convert text prompts to token sequences with proper special tokens
- Apply chat templates and format conversations
- Handle multimodal inputs (images, audio, etc.) when applicable
- Manage prompt truncation and length validation
- Provide clean separation between API layer and engine core
"""
def __init__(
self,
model_config: ModelConfig,
tokenizer: Optional[AnyTokenizer] = None,
):
super().__init__()
self.model_config = model_config
self.tokenizer = tokenizer
@abstractmethod
async def render_prompt(
self,
*,
prompt_or_prompts: Union[str, list[str], list[int], list[list[int]]],
config: "RenderConfig",
) -> list[EngineTokensPrompt]:
"""
Convert text or token inputs into engine-ready TokensPrompt objects.
This method accepts text or token inputs and produces a
list of [`TokensPrompt`][vllm.inputs.data.TokensPrompt] objects
for the engine.
Args:
prompt_or_prompts: One of:
- ``str``: Single text prompt.
- ``list[str]``: Batch of text prompts.
- ``list[int]``: Single pre-tokenized sequence.
- ``list[list[int]]``: Batch of pre-tokenized sequences.
config: Render configuration controlling how prompts are prepared
(e.g., tokenization and length handling).
Returns:
list[EngineTokensPrompt]: Engine-ready token prompts.
Raises:
ValueError: If input formats are invalid or length limits exceeded.
"""
raise NotImplementedError
@abstractmethod
async def render_prompt_and_embeds(
self,
*,
prompt_or_prompts: Optional[Union[str, list[str], list[int],
list[list[int]]]] = None,
prompt_embeds: Optional[Union[bytes, list[bytes]]] = None,
config: "RenderConfig",
) -> list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
"""
Convert text/token and/or base64-encoded embeddings inputs into
engine-ready prompt objects using a unified RenderConfig.
At least one of ``prompt_or_prompts`` or ``prompt_embeds`` must be
provided and non-empty. If both are omitted or empty (e.g., empty
string and empty list), a ``ValueError`` is raised.
Args:
prompt_or_prompts: Text or token inputs to include.
prompt_embeds: Base64-encoded bytes (or list thereof) containing a
torch-saved tensor to be used as prompt embeddings.
config: Render configuration controlling how prompts are prepared
(e.g., tokenization and length handling).
Returns:
list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
Engine-ready prompt objects.
Raises:
ValueError: If both ``prompt_or_prompts`` and ``prompt_embeds``
are omitted or empty (decoder prompt cannot be empty), or if
length limits are exceeded.
"""
raise NotImplementedError
@classmethod
def load_prompt_embeds(
cls,
prompt_embeds: Union[bytes, list[bytes]],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=0)]] = None,
cache_salt: Optional[str] = None,
) -> list[EngineEmbedsPrompt]:
"""Load and validate base64-encoded embeddings into prompt objects."""
def _load_and_validate_embed(embed: bytes) -> EngineEmbedsPrompt:
tensor = torch.load(
io.BytesIO(pybase64.b64decode(embed, validate=True)),
weights_only=True,
map_location=torch.device("cpu"),
)
assert isinstance(tensor, torch.Tensor) and tensor.dtype in (
torch.float32,
torch.bfloat16,
torch.float16,
)
tensor = tensor.to_dense()
if tensor.dim() > 2:
tensor = tensor.squeeze(0)
assert tensor.dim() == 2
if truncate_prompt_tokens is not None:
tensor = tensor[-truncate_prompt_tokens:]
embeds_prompt = EngineEmbedsPrompt(prompt_embeds=tensor)
if cache_salt is not None:
embeds_prompt["cache_salt"] = cache_salt
return embeds_prompt
if isinstance(prompt_embeds, list):
return [_load_and_validate_embed(embed) for embed in prompt_embeds]
return [_load_and_validate_embed(prompt_embeds)]
class CompletionRenderer(BaseRenderer):
def __init__(
self,
model_config: ModelConfig,
tokenizer: Optional[AnyTokenizer] = None,
async_tokenizer_pool: Optional[dict[AnyTokenizer,
AsyncMicrobatchTokenizer]] = None,
):
super().__init__(model_config, tokenizer)
self.async_tokenizer_pool = async_tokenizer_pool
self.async_tokenizer: Optional[AsyncMicrobatchTokenizer] = None
async def render_prompt(
self,
*,
prompt_or_prompts: Union[str, list[str], list[int], list[list[int]]],
config: "RenderConfig",
) -> list[EngineTokensPrompt]:
"""Implementation of prompt rendering for completion-style requests.
Uses async tokenizer pooling for improved performance. See base class
for detailed parameter documentation.
"""
truncate_prompt_tokens = self._validate_and_normalize_truncate_tokens(
config.truncate_prompt_tokens, config.max_length)
if truncate_prompt_tokens == 0:
return []
# Parse and batch the input prompts
batch_inputs = parse_and_batch_prompt(prompt_or_prompts)
tasks = []
for prompt_input in batch_inputs:
if prompt_input["is_tokens"] is True:
# Token input
# Note: detokenization is needed when echo is enabled,
# where the input token IDs are decoded back to text.
task = self._maybe_detokenize(prompt_input["content"],
config.max_length,
truncate_prompt_tokens,
config.cache_salt,
config.needs_detokenization)
else:
# Text input
task = self._tokenize(prompt_input["content"],
config.max_length,
truncate_prompt_tokens,
config.add_special_tokens,
config.cache_salt)
tasks.append(task)
# Wait for all text tokenization to finish
if tasks:
tokenized_text_prompts = await asyncio.gather(*tasks)
return tokenized_text_prompts
return []
async def render_prompt_and_embeds(
self,
*,
prompt_or_prompts: Optional[Union[str, list[str], list[int],
list[list[int]]]] = None,
prompt_embeds: Optional[Union[bytes, list[bytes]]] = None,
config: "RenderConfig",
) -> list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
"""
Render text/token prompts and/or precomputed embedding prompts. At
least one of `prompt_or_prompts` or `prompt_embeds` must be provided.
"""
truncate_prompt_tokens = self._validate_and_normalize_truncate_tokens(
config.truncate_prompt_tokens, config.max_length)
if truncate_prompt_tokens == 0:
return []
rendered: list[Union[EngineTokensPrompt, EngineEmbedsPrompt]] = []
if prompt_embeds is not None:
rendered.extend(
self.load_prompt_embeds(prompt_embeds, truncate_prompt_tokens,
config.cache_salt))
if prompt_or_prompts is None or prompt_or_prompts == "":
return rendered
token_prompts = await self.render_prompt(
prompt_or_prompts=prompt_or_prompts,
config=config,
)
rendered.extend(token_prompts)
return rendered
def _validate_and_normalize_truncate_tokens(
self,
truncate_prompt_tokens: Optional[int],
max_length: Optional[int],
) -> Optional[int]:
"""Validate and normalize truncate_prompt_tokens parameter."""
if truncate_prompt_tokens is None:
return None
if truncate_prompt_tokens == 0:
return 0
if truncate_prompt_tokens < 0:
truncate_prompt_tokens = self.model_config.max_model_len
if max_length is not None and truncate_prompt_tokens > max_length:
raise ValueError(
f"truncate_prompt_tokens ({truncate_prompt_tokens}) "
f"cannot be greater than max_length ({max_length}). "
f"Please select a smaller truncation size.")
return truncate_prompt_tokens
def _maybe_apply_truncation(
self, token_ids: list[int],
truncate_prompt_tokens: Optional[int]) -> list[int]:
"""Apply truncation to token sequence."""
if truncate_prompt_tokens is None:
return token_ids
if truncate_prompt_tokens >= len(token_ids):
return token_ids
return token_ids[-truncate_prompt_tokens:]
async def _tokenize(
self,
text: str,
max_length: Optional[int],
truncate_prompt_tokens: Optional[int],
add_special_tokens: Optional[bool],
cache_salt: Optional[str],
) -> EngineTokensPrompt:
"""Tokenize text input asynchronously."""
async_tokenizer = self._get_async_tokenizer()
# Handle encoder-specific preprocessing
if (self.model_config.encoder_config is not None
and self.model_config.encoder_config.get(
"do_lower_case", False)):
text = text.lower()
# Tokenize texts
if truncate_prompt_tokens is None:
encoded = await async_tokenizer(
text, add_special_tokens=add_special_tokens)
else:
encoded = await async_tokenizer(
text,
add_special_tokens=add_special_tokens,
truncation=True,
max_length=truncate_prompt_tokens)
return self._create_tokens_prompt(encoded.input_ids, max_length,
cache_salt, text)
async def _maybe_detokenize(
self,
token_ids: list[int],
max_length: Optional[int],
truncate_prompt_tokens: Optional[int],
cache_salt: Optional[str],
needs_detokenization: Optional[bool] = False,
) -> EngineTokensPrompt:
"""Optionally detokenize token IDs and build a tokens prompt."""
token_ids = self._maybe_apply_truncation(token_ids,
truncate_prompt_tokens)
prompt = None
if needs_detokenization is True:
async_tokenizer = self._get_async_tokenizer()
prompt = await async_tokenizer.decode(token_ids)
return self._create_tokens_prompt(token_ids=token_ids,
max_length=max_length,
cache_salt=cache_salt,
prompt=prompt)
def _get_async_tokenizer(self) -> AsyncMicrobatchTokenizer:
"""Get or create async tokenizer using shared pool."""
async_tokenizer = self.async_tokenizer
if async_tokenizer is not None:
return async_tokenizer
tokenizer = self.tokenizer
if self.tokenizer is None:
raise ValueError(
"No tokenizer available for text input processing")
if self.async_tokenizer_pool is None:
async_tokenizer = AsyncMicrobatchTokenizer(tokenizer)
else:
async_tokenizer = self.async_tokenizer_pool.get(tokenizer)
if async_tokenizer is None:
async_tokenizer = AsyncMicrobatchTokenizer(tokenizer)
self.async_tokenizer_pool[tokenizer] = async_tokenizer
self.async_tokenizer = async_tokenizer
return async_tokenizer
def _create_tokens_prompt(
self,
token_ids: list[int],
max_length: Optional[int] = None,
cache_salt: Optional[str] = None,
prompt: Optional[str] = None,
) -> EngineTokensPrompt:
"""Create validated EngineTokensPrompt."""
if max_length is not None and len(token_ids) > max_length:
raise ValueError(
f"This maximum context length is {max_length} tokens. "
f"However, your request has {len(token_ids)} input tokens. "
"Please reduce the length of the input messages.")
tokens_prompt = EngineTokensPrompt(prompt_token_ids=token_ids)
if cache_salt is not None:
tokens_prompt["cache_salt"] = cache_salt
if prompt is not None:
tokens_prompt["prompt"] = prompt
return tokens_prompt
...@@ -4,6 +4,8 @@ import os ...@@ -4,6 +4,8 @@ import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from openai_harmony import Author, Message, Role, TextContent
from vllm.logger import init_logger from vllm.logger import init_logger
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -99,6 +101,28 @@ class HarmonyPythonTool(Tool): ...@@ -99,6 +101,28 @@ class HarmonyPythonTool(Tool):
return return
self.python_tool = PythonTool() self.python_tool = PythonTool()
async def validate(self):
if not self.enabled:
return
try:
message = Message(
author=Author(role=Role.ASSISTANT),
content=[TextContent(text="print('Hello, world!')")],
channel="analysis",
recipient="python",
content_type="code",
)
msgs = []
async for msg in self.python_tool.process(message):
msgs.append(msg)
assert msgs[0].content[0].text == "Hello, world!\n"
except Exception as e:
self.enabled = False
logger.warning_once(
"Code interpreter tool failed to initialize (%s), code "
"interpreter is disabled", e)
return
logger.info_once("Code interpreter tool initialized") logger.info_once("Code interpreter tool initialized")
async def get_result(self, context: "ConversationContext") -> Any: async def get_result(self, context: "ConversationContext") -> Any:
......
...@@ -86,7 +86,8 @@ class ToolServer(ABC): ...@@ -86,7 +86,8 @@ class ToolServer(ABC):
pass pass
@abstractmethod @abstractmethod
def new_session(self, tool_name: str) -> AbstractAsyncContextManager[Any]: def new_session(self, tool_name: str,
session_id: str) -> AbstractAsyncContextManager[Any]:
""" """
Create a session for the tool. Create a session for the tool.
""" """
...@@ -124,7 +125,8 @@ class MCPToolServer(ToolServer): ...@@ -124,7 +125,8 @@ class MCPToolServer(ToolServer):
description=tool.description, description=tool.description,
parameters=tool.inputSchema) parameters=tool.inputSchema)
for tool in list_tools_response.tools for tool in list_tools_response.tools
]) ],
)
self.harmony_tool_descriptions[tool_from_mcp.name] = tool_from_mcp self.harmony_tool_descriptions[tool_from_mcp.name] = tool_from_mcp
if tool_from_mcp.name not in self.urls: if tool_from_mcp.name not in self.urls:
self.urls[tool_from_mcp.name] = url self.urls[tool_from_mcp.name] = url
...@@ -142,14 +144,16 @@ class MCPToolServer(ToolServer): ...@@ -142,14 +144,16 @@ class MCPToolServer(ToolServer):
return self.harmony_tool_descriptions.get(tool_name) return self.harmony_tool_descriptions.get(tool_name)
@asynccontextmanager @asynccontextmanager
async def new_session(self, tool_name: str): async def new_session(self, tool_name: str, session_id: str):
from mcp import ClientSession from mcp import ClientSession
from mcp.client.sse import sse_client from mcp.client.sse import sse_client
url = self.urls.get(tool_name) url = self.urls.get(tool_name)
headers = {"x-session-id": session_id}
if not url: if not url:
raise KeyError(f"Tool '{tool_name}' is not supported") raise KeyError(f"Tool '{tool_name}' is not supported")
async with sse_client(url=url) as streams, ClientSession( async with sse_client(url=url,
*streams) as session: headers=headers) as streams, ClientSession(
*streams) as session:
await session.initialize() await session.initialize()
yield session yield session
...@@ -158,10 +162,13 @@ class DemoToolServer(ToolServer): ...@@ -158,10 +162,13 @@ class DemoToolServer(ToolServer):
def __init__(self): def __init__(self):
self.tools: dict[str, Tool] = {} self.tools: dict[str, Tool] = {}
async def init_and_validate(self):
browser_tool = HarmonyBrowserTool() browser_tool = HarmonyBrowserTool()
python_tool = HarmonyPythonTool()
await python_tool.validate()
if browser_tool.enabled: if browser_tool.enabled:
self.tools["browser"] = browser_tool self.tools["browser"] = browser_tool
python_tool = HarmonyPythonTool()
if python_tool.enabled: if python_tool.enabled:
self.tools["python"] = python_tool self.tools["python"] = python_tool
logger.info("DemoToolServer initialized with tools: %s", logger.info("DemoToolServer initialized with tools: %s",
...@@ -182,7 +189,7 @@ class DemoToolServer(ToolServer): ...@@ -182,7 +189,7 @@ class DemoToolServer(ToolServer):
raise ValueError(f"Unknown tool {tool_name}") raise ValueError(f"Unknown tool {tool_name}")
@asynccontextmanager @asynccontextmanager
async def new_session(self, tool_name: str): async def new_session(self, tool_name: str, session_id: str):
if tool_name not in self.tools: if tool_name not in self.tools:
raise KeyError(f"Tool '{tool_name}' is not supported") raise KeyError(f"Tool '{tool_name}' is not supported")
yield self.tools[tool_name] yield self.tools[tool_name]
...@@ -13,24 +13,6 @@ logger = init_logger(__name__) ...@@ -13,24 +13,6 @@ logger = init_logger(__name__)
# that interact with vllm workers. # that interact with vllm workers.
# they are executed whenever `import vllm` is called. # they are executed whenever `import vllm` is called.
if os.environ.get('NCCL_CUMEM_ENABLE', '0') != '0':
logger.warning(
"NCCL_CUMEM_ENABLE is set to %s, skipping override. "
"This may increase memory overhead with cudagraph+allreduce: "
"https://github.com/NVIDIA/nccl/issues/1234",
os.environ['NCCL_CUMEM_ENABLE'])
elif not os.path.exists('/dev/nvidia-caps-imex-channels'):
# NCCL requires NCCL_CUMEM_ENABLE to work with
# multi-node NVLink, typically on GB200-NVL72 systems.
# The ultimate way to detect multi-node NVLink is to use
# NVML APIs, which are too expensive to call here.
# As an approximation, we check the existence of
# /dev/nvidia-caps-imex-channels, used by
# multi-node NVLink to communicate across nodes.
# This will still cost some GPU memory, but it is worthwhile
# because we can get very fast cross-node bandwidth with NVLink.
os.environ['NCCL_CUMEM_ENABLE'] = '0'
# see https://github.com/vllm-project/vllm/pull/15951 # see https://github.com/vllm-project/vllm/pull/15951
# it avoids unintentional cuda initialization from torch.cuda.is_available() # it avoids unintentional cuda initialization from torch.cuda.is_available()
os.environ['PYTORCH_NVML_BASED_CUDA_CHECK'] = '1' os.environ['PYTORCH_NVML_BASED_CUDA_CHECK'] = '1'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment