Unverified commit 03c039c4, authored by ybyang, committed by GitHub
Browse files

[OAI] patch origin request_id logic (#7508)

parent 57ab7769
...@@ -196,6 +196,9 @@ class CompletionRequest(BaseModel): ...@@ -196,6 +196,9 @@ class CompletionRequest(BaseModel):
bootstrap_port: Optional[int] = None bootstrap_port: Optional[int] = None
bootstrap_room: Optional[int] = None bootstrap_room: Optional[int] = None
# For request id
rid: Optional[Union[List[str], str]] = None
@field_validator("max_tokens") @field_validator("max_tokens")
@classmethod @classmethod
def validate_max_tokens_positive(cls, v): def validate_max_tokens_positive(cls, v):
...@@ -430,8 +433,8 @@ class ChatCompletionRequest(BaseModel): ...@@ -430,8 +433,8 @@ class ChatCompletionRequest(BaseModel):
stream_reasoning: bool = True stream_reasoning: bool = True
chat_template_kwargs: Optional[Dict] = None chat_template_kwargs: Optional[Dict] = None
# The request id. # For request id
rid: Optional[str] = None rid: Optional[Union[List[str], str]] = None
# For PD disaggregation # For PD disaggregation
bootstrap_host: Optional[str] = None bootstrap_host: Optional[str] = None
...@@ -529,7 +532,7 @@ class EmbeddingRequest(BaseModel): ...@@ -529,7 +532,7 @@ class EmbeddingRequest(BaseModel):
user: Optional[str] = None user: Optional[str] = None
# The request id. # The request id.
rid: Optional[str] = None rid: Optional[Union[List[str], str]] = None
class EmbeddingObject(BaseModel): class EmbeddingObject(BaseModel):
......
...@@ -95,6 +95,7 @@ class OpenAIServingChat(OpenAIServingBase): ...@@ -95,6 +95,7 @@ class OpenAIServingChat(OpenAIServingBase):
bootstrap_port=request.bootstrap_port, bootstrap_port=request.bootstrap_port,
bootstrap_room=request.bootstrap_room, bootstrap_room=request.bootstrap_room,
return_hidden_states=request.return_hidden_states, return_hidden_states=request.return_hidden_states,
rid=request.rid,
) )
return adapted_request, request return adapted_request, request
......
...@@ -87,6 +87,7 @@ class OpenAIServingCompletion(OpenAIServingBase): ...@@ -87,6 +87,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
bootstrap_port=request.bootstrap_port, bootstrap_port=request.bootstrap_port,
bootstrap_room=request.bootstrap_room, bootstrap_room=request.bootstrap_room,
return_hidden_states=request.return_hidden_states, return_hidden_states=request.return_hidden_states,
rid=request.rid,
) )
return adapted_request, request return adapted_request, request
......
...@@ -119,6 +119,7 @@ class OpenAIServingEmbedding(OpenAIServingBase): ...@@ -119,6 +119,7 @@ class OpenAIServingEmbedding(OpenAIServingBase):
adapted_request = EmbeddingReqInput( adapted_request = EmbeddingReqInput(
**prompt_kwargs, **prompt_kwargs,
rid=request.rid,
) )
return adapted_request, request return adapted_request, request
......
...@@ -319,8 +319,16 @@ class GenerateReqInput: ...@@ -319,8 +319,16 @@ class GenerateReqInput:
"""Normalize request IDs for batch processing.""" """Normalize request IDs for batch processing."""
if self.rid is None: if self.rid is None:
self.rid = [uuid.uuid4().hex for _ in range(num)] self.rid = [uuid.uuid4().hex for _ in range(num)]
elif not isinstance(self.rid, list): elif isinstance(self.rid, str):
raise ValueError("The rid should be a list for batch processing.") new_rids = [f"{self.rid}_{i}" for i in range(num)]
self.rid = new_rids
elif isinstance(self.rid, list):
if len(self.rid) != num:
raise ValueError(
"The specified rids length mismatch with the batch_size for batch processing."
)
else:
raise ValueError("The rid should be a string or a list of strings.")
def _normalize_logprob_params(self, num): def _normalize_logprob_params(self, num):
"""Normalize logprob-related parameters for batch processing.""" """Normalize logprob-related parameters for batch processing."""
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment