Unverified Commit 17ee641c authored by Bongwoo Bak's avatar Bongwoo Bak Committed by GitHub
Browse files

[Responses API] Add kv_transfer_params for PD disaggregation (#37424)


Signed-off-by: default avatarbongwoobak <bongwoobak@gmail.com>
Co-authored-by: default avatarChauncey <chaunceyjiang@gmail.com>
parent 0d50fa1d
...@@ -9,7 +9,7 @@ from abc import ABC, abstractmethod ...@@ -9,7 +9,7 @@ from abc import ABC, abstractmethod
from collections.abc import Callable from collections.abc import Callable
from contextlib import AsyncExitStack from contextlib import AsyncExitStack
from dataclasses import replace from dataclasses import replace
from typing import TYPE_CHECKING, Final, Union from typing import TYPE_CHECKING, Any, Final, Union
from openai.types.responses.response_function_tool_call_output_item import ( from openai.types.responses.response_function_tool_call_output_item import (
ResponseFunctionToolCallOutputItem, ResponseFunctionToolCallOutputItem,
...@@ -182,6 +182,7 @@ class SimpleContext(ConversationContext): ...@@ -182,6 +182,7 @@ class SimpleContext(ConversationContext):
self.all_turn_metrics = [] self.all_turn_metrics = []
self.input_messages: list[ResponseRawMessageAndToken] = [] self.input_messages: list[ResponseRawMessageAndToken] = []
self.kv_transfer_params: dict[str, Any] | None = None
def append_output(self, output) -> None: def append_output(self, output) -> None:
self.last_output = output self.last_output = output
...@@ -190,6 +191,8 @@ class SimpleContext(ConversationContext): ...@@ -190,6 +191,8 @@ class SimpleContext(ConversationContext):
self.num_prompt_tokens = len(output.prompt_token_ids or []) self.num_prompt_tokens = len(output.prompt_token_ids or [])
self.num_cached_tokens = output.num_cached_tokens or 0 self.num_cached_tokens = output.num_cached_tokens or 0
self.num_output_tokens += len(output.outputs[0].token_ids or []) self.num_output_tokens += len(output.outputs[0].token_ids or [])
if output.kv_transfer_params is not None:
self.kv_transfer_params = output.kv_transfer_params
# Accumulate text, token_ids, and logprobs for streaming mode # Accumulate text, token_ids, and logprobs for streaming mode
delta_output = output.outputs[0] delta_output = output.outputs[0]
...@@ -308,11 +311,14 @@ class ParsableContext(ConversationContext): ...@@ -308,11 +311,14 @@ class ParsableContext(ConversationContext):
self.input_messages: list[ResponseRawMessageAndToken] = [] self.input_messages: list[ResponseRawMessageAndToken] = []
self.output_messages: list[ResponseRawMessageAndToken] = [] self.output_messages: list[ResponseRawMessageAndToken] = []
self._accumulated_token_ids: list[int] = [] self._accumulated_token_ids: list[int] = []
self.kv_transfer_params: dict[str, Any] | None = None
def append_output(self, output: RequestOutput) -> None: def append_output(self, output: RequestOutput) -> None:
self.num_prompt_tokens = len(output.prompt_token_ids or []) self.num_prompt_tokens = len(output.prompt_token_ids or [])
self.num_cached_tokens = output.num_cached_tokens or 0 self.num_cached_tokens = output.num_cached_tokens or 0
self.num_output_tokens += len(output.outputs[0].token_ids or []) self.num_output_tokens += len(output.outputs[0].token_ids or [])
if output.kv_transfer_params is not None:
self.kv_transfer_params = output.kv_transfer_params
self.parser.process(output.outputs[0]) self.parser.process(output.outputs[0])
output_token_ids = output.outputs[0].token_ids or [] output_token_ids = output.outputs[0].token_ids or []
self._accumulated_token_ids.extend(output_token_ids) self._accumulated_token_ids.extend(output_token_ids)
...@@ -538,6 +544,7 @@ class HarmonyContext(ConversationContext): ...@@ -538,6 +544,7 @@ class HarmonyContext(ConversationContext):
self.all_turn_metrics: list[TurnMetrics] = [] self.all_turn_metrics: list[TurnMetrics] = []
self.is_first_turn = True self.is_first_turn = True
self.first_tok_of_message = True # For streaming support self.first_tok_of_message = True # For streaming support
self.kv_transfer_params: dict[str, Any] | None = None
def _update_num_reasoning_tokens(self): def _update_num_reasoning_tokens(self):
channel = self.parser.current_channel channel = self.parser.current_channel
...@@ -557,6 +564,8 @@ class HarmonyContext(ConversationContext): ...@@ -557,6 +564,8 @@ class HarmonyContext(ConversationContext):
self._update_num_reasoning_tokens() self._update_num_reasoning_tokens()
self._update_prefill_token_usage(output) self._update_prefill_token_usage(output)
self._update_decode_token_usage(output) self._update_decode_token_usage(output)
if output.kv_transfer_params is not None:
self.kv_transfer_params = output.kv_transfer_params
# Append current turn to all turn list for next turn's calculations # Append current turn to all turn list for next turn's calculations
self.all_turn_metrics.append(self.current_turn_metrics.copy()) self.all_turn_metrics.append(self.current_turn_metrics.copy())
self.current_turn_metrics.reset() self.current_turn_metrics.reset()
...@@ -868,6 +877,8 @@ class StreamingHarmonyContext(HarmonyContext): ...@@ -868,6 +877,8 @@ class StreamingHarmonyContext(HarmonyContext):
if last_delta_text: if last_delta_text:
self.last_content_delta = last_delta_text self.last_content_delta = last_delta_text
self._update_decode_token_usage(output) self._update_decode_token_usage(output)
if output.kv_transfer_params is not None:
self.kv_transfer_params = output.kv_transfer_params
# For streaming, update previous turn when message is complete # For streaming, update previous turn when message is complete
if output.finished: if output.finished:
......
...@@ -252,6 +252,10 @@ class ResponsesRequest(OpenAIBaseModel): ...@@ -252,6 +252,10 @@ class ResponsesRequest(OpenAIBaseModel):
"numeric values, used by custom extensions." "numeric values, used by custom extensions."
), ),
) )
kv_transfer_params: dict[str, Any] | None = Field(
default=None,
description="KVTransfer parameters used for disaggregated serving.",
)
# --8<-- [end:responses-extra-params] # --8<-- [end:responses-extra-params]
def build_chat_params( def build_chat_params(
...@@ -351,6 +355,10 @@ class ResponsesRequest(OpenAIBaseModel): ...@@ -351,6 +355,10 @@ class ResponsesRequest(OpenAIBaseModel):
if isinstance(stop, str): if isinstance(stop, str):
stop = [stop] stop = [stop]
extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {}
if self.kv_transfer_params:
extra_args["kv_transfer_params"] = self.kv_transfer_params
return SamplingParams.from_optional( return SamplingParams.from_optional(
temperature=temperature, temperature=temperature,
top_p=top_p, top_p=top_p,
...@@ -367,7 +375,7 @@ class ResponsesRequest(OpenAIBaseModel): ...@@ -367,7 +375,7 @@ class ResponsesRequest(OpenAIBaseModel):
), ),
structured_outputs=structured_outputs, structured_outputs=structured_outputs,
logit_bias=self.logit_bias, logit_bias=self.logit_bias,
extra_args=self.vllm_xargs or {}, extra_args=extra_args,
skip_clone=True, # Created fresh per request, safe to skip clone skip_clone=True, # Created fresh per request, safe to skip clone
skip_special_tokens=self.skip_special_tokens, skip_special_tokens=self.skip_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output, include_stop_str_in_output=self.include_stop_str_in_output,
...@@ -488,6 +496,11 @@ class ResponsesResponse(OpenAIBaseModel): ...@@ -488,6 +496,11 @@ class ResponsesResponse(OpenAIBaseModel):
usage: ResponseUsage | None = None usage: ResponseUsage | None = None
user: str | None = None user: str | None = None
# vLLM-specific fields that are not in OpenAI spec
kv_transfer_params: dict[str, Any] | None = Field(
default=None, description="KVTransfer parameters."
)
# --8<-- [start:responses-response-extra-params] # --8<-- [start:responses-response-extra-params]
# These are populated when enable_response_messages is set to True # These are populated when enable_response_messages is set to True
# NOTE: custom serialization is needed # NOTE: custom serialization is needed
...@@ -531,6 +544,7 @@ class ResponsesResponse(OpenAIBaseModel): ...@@ -531,6 +544,7 @@ class ResponsesResponse(OpenAIBaseModel):
usage: ResponseUsage | None = None, usage: ResponseUsage | None = None,
input_messages: ResponseInputOutputMessage | None = None, input_messages: ResponseInputOutputMessage | None = None,
output_messages: ResponseInputOutputMessage | None = None, output_messages: ResponseInputOutputMessage | None = None,
kv_transfer_params: dict[str, Any] | None = None,
) -> "ResponsesResponse": ) -> "ResponsesResponse":
incomplete_details: IncompleteDetails | None = None incomplete_details: IncompleteDetails | None = None
if status == "incomplete": if status == "incomplete":
...@@ -566,6 +580,7 @@ class ResponsesResponse(OpenAIBaseModel): ...@@ -566,6 +580,7 @@ class ResponsesResponse(OpenAIBaseModel):
truncation=request.truncation, truncation=request.truncation,
user=request.user, user=request.user,
usage=usage, usage=usage,
kv_transfer_params=kv_transfer_params,
) )
......
...@@ -873,6 +873,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -873,6 +873,7 @@ class OpenAIServingResponses(OpenAIServing):
output=output, output=output,
status=status, status=status,
usage=usage, usage=usage,
kv_transfer_params=context.kv_transfer_params,
) )
if request.store: if request.store:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment