"vscode:/vscode.git/clone" did not exist on "beae085a7b8281d828f1cfb6cfeb4506a9d0cc91"
Unverified Commit bc5ef333 authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

[Perf] Add skip_clone to SamplingParams for internal request handling (#31041)


Signed-off-by: default avatarmgoin <mgoin64@gmail.com>
parent 09dc7c69
...@@ -60,7 +60,8 @@ async def generate(request: Request) -> Response: ...@@ -60,7 +60,8 @@ async def generate(request: Request) -> Response:
async def _generate(request_dict: dict, raw_request: Request) -> Response: async def _generate(request_dict: dict, raw_request: Request) -> Response:
prompt = request_dict.pop("prompt") prompt = request_dict.pop("prompt")
stream = request_dict.pop("stream", False) stream = request_dict.pop("stream", False)
sampling_params = SamplingParams(**request_dict) # Since SamplingParams is created fresh per request, safe to skip clone
sampling_params = SamplingParams(**request_dict, skip_clone=True)
request_id = random_uuid() request_id = random_uuid()
assert engine is not None assert engine is not None
......
...@@ -642,7 +642,10 @@ class LLM: ...@@ -642,7 +642,10 @@ class LLM:
# following the huggingface transformers implementation # following the huggingface transformers implementation
# at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
beam_search_params = SamplingParams( beam_search_params = SamplingParams(
logprobs=2 * beam_width, max_tokens=1, temperature=temperature logprobs=2 * beam_width,
max_tokens=1,
temperature=temperature,
skip_clone=True, # Internal beam search, safe to skip clone
) )
instances: list[BeamSearchInstance] = [] instances: list[BeamSearchInstance] = []
......
...@@ -474,6 +474,7 @@ class ResponsesRequest(OpenAIBaseModel): ...@@ -474,6 +474,7 @@ class ResponsesRequest(OpenAIBaseModel):
), ),
structured_outputs=structured_outputs, structured_outputs=structured_outputs,
logit_bias=self.logit_bias, logit_bias=self.logit_bias,
skip_clone=True, # Created fresh per request, safe to skip clone
) )
def is_include_output_logprobs(self) -> bool: def is_include_output_logprobs(self) -> bool:
...@@ -876,6 +877,7 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -876,6 +877,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
bad_words=self.bad_words, bad_words=self.bad_words,
allowed_token_ids=self.allowed_token_ids, allowed_token_ids=self.allowed_token_ids,
extra_args=extra_args or None, extra_args=extra_args or None,
skip_clone=True, # Created fresh per request, safe to skip clone
) )
@model_validator(mode="before") @model_validator(mode="before")
...@@ -1316,6 +1318,7 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -1316,6 +1318,7 @@ class CompletionRequest(OpenAIBaseModel):
logit_bias=self.logit_bias, logit_bias=self.logit_bias,
allowed_token_ids=self.allowed_token_ids, allowed_token_ids=self.allowed_token_ids,
extra_args=extra_args or None, extra_args=extra_args or None,
skip_clone=True, # Created fresh per request, safe to skip clone
) )
@model_validator(mode="before") @model_validator(mode="before")
...@@ -2182,6 +2185,7 @@ class TranscriptionRequest(OpenAIBaseModel): ...@@ -2182,6 +2185,7 @@ class TranscriptionRequest(OpenAIBaseModel):
if self.stream if self.stream
else RequestOutputKind.FINAL_ONLY, else RequestOutputKind.FINAL_ONLY,
extra_args=self.vllm_xargs, extra_args=self.vllm_xargs,
skip_clone=True, # Created fresh per request, safe to skip clone
) )
@model_validator(mode="before") @model_validator(mode="before")
...@@ -2409,6 +2413,7 @@ class TranslationRequest(OpenAIBaseModel): ...@@ -2409,6 +2413,7 @@ class TranslationRequest(OpenAIBaseModel):
output_kind=RequestOutputKind.DELTA output_kind=RequestOutputKind.DELTA
if self.stream if self.stream
else RequestOutputKind.FINAL_ONLY, else RequestOutputKind.FINAL_ONLY,
skip_clone=True, # Created fresh per request, safe to skip clone
) )
@model_validator(mode="before") @model_validator(mode="before")
......
...@@ -219,6 +219,7 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -219,6 +219,7 @@ class OpenAISpeechToText(OpenAIServing):
dummy_params = SamplingParams( dummy_params = SamplingParams(
max_tokens=1, max_tokens=1,
temperature=0.0, temperature=0.0,
skip_clone=True, # Internal warmup, safe to skip clone
) )
# Process the dummy input through the input processor # Process the dummy input through the input processor
......
...@@ -211,6 +211,12 @@ class SamplingParams( ...@@ -211,6 +211,12 @@ class SamplingParams(
set to an integer k, will use only the last k tokens from the prompt set to an integer k, will use only the last k tokens from the prompt
(i.e., left truncation). If set to `None`, truncation is disabled.""" (i.e., left truncation). If set to `None`, truncation is disabled."""
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE
skip_clone: bool = False
"""Internal flag indicating that this SamplingParams instance is safe to
reuse without cloning. When True, clone() will return self without
performing a deep copy. This should only be set when the params object
is guaranteed to be dedicated to a single request and won't be modified
in ways that would affect other uses."""
# The below fields are not supposed to be used as an input. # The below fields are not supposed to be used as an input.
# They are set in post_init. # They are set in post_init.
...@@ -270,6 +276,7 @@ class SamplingParams( ...@@ -270,6 +276,7 @@ class SamplingParams(
logit_bias: dict[int, float] | dict[str, float] | None = None, logit_bias: dict[int, float] | dict[str, float] | None = None,
allowed_token_ids: list[int] | None = None, allowed_token_ids: list[int] | None = None,
extra_args: dict[str, Any] | None = None, extra_args: dict[str, Any] | None = None,
skip_clone: bool = False,
) -> "SamplingParams": ) -> "SamplingParams":
if logit_bias is not None: if logit_bias is not None:
# Convert token_id to integer # Convert token_id to integer
...@@ -310,6 +317,7 @@ class SamplingParams( ...@@ -310,6 +317,7 @@ class SamplingParams(
logit_bias=logit_bias, logit_bias=logit_bias,
allowed_token_ids=allowed_token_ids, allowed_token_ids=allowed_token_ids,
extra_args=extra_args, extra_args=extra_args,
skip_clone=skip_clone,
) )
def __post_init__(self) -> None: def __post_init__(self) -> None:
...@@ -540,8 +548,13 @@ class SamplingParams( ...@@ -540,8 +548,13 @@ class SamplingParams(
data that is expensive to copy. However, if not copied, the processor data that is expensive to copy. However, if not copied, the processor
needs to support parallel decoding for multiple sequences needs to support parallel decoding for multiple sequences
See https://github.com/vllm-project/vllm/issues/3087 See https://github.com/vllm-project/vllm/issues/3087
If skip_clone is True, uses shallow copy instead of deep copy.
""" """
if self.skip_clone:
return copy.copy(self)
logit_processor_refs = ( logit_processor_refs = (
None None
if self.logits_processors is None if self.logits_processors is None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment