Unverified Commit 4aae6676 authored by Aviv Keshet's avatar Aviv Keshet Committed by GitHub
Browse files

[core] add `extra_args` to `SamplingParams` (#13300)


Signed-off-by: default avatarAviv Keshet <akeshet@scaledcognition.com>
parent 9f3bc0f5
...@@ -184,6 +184,9 @@ class SamplingParams( ...@@ -184,6 +184,9 @@ class SamplingParams(
allowed_token_ids: If provided, the engine will construct a logits allowed_token_ids: If provided, the engine will construct a logits
processor which only retains scores for the given token ids. processor which only retains scores for the given token ids.
Defaults to None. Defaults to None.
extra_args: Arbitrary additional args, that can be used by custom
sampling implementations. Not used by any in-tree sampling
implementations.
""" """
n: int = 1 n: int = 1
...@@ -227,6 +230,7 @@ class SamplingParams( ...@@ -227,6 +230,7 @@ class SamplingParams(
guided_decoding: Optional[GuidedDecodingParams] = None guided_decoding: Optional[GuidedDecodingParams] = None
logit_bias: Optional[dict[int, float]] = None logit_bias: Optional[dict[int, float]] = None
allowed_token_ids: Optional[list[int]] = None allowed_token_ids: Optional[list[int]] = None
extra_args: Optional[dict[str, Any]] = None
@staticmethod @staticmethod
def from_optional( def from_optional(
...@@ -259,6 +263,7 @@ class SamplingParams( ...@@ -259,6 +263,7 @@ class SamplingParams(
guided_decoding: Optional[GuidedDecodingParams] = None, guided_decoding: Optional[GuidedDecodingParams] = None,
logit_bias: Optional[Union[dict[int, float], dict[str, float]]] = None, logit_bias: Optional[Union[dict[int, float], dict[str, float]]] = None,
allowed_token_ids: Optional[list[int]] = None, allowed_token_ids: Optional[list[int]] = None,
extra_args: Optional[dict[str, Any]] = None,
) -> "SamplingParams": ) -> "SamplingParams":
if logit_bias is not None: if logit_bias is not None:
# Convert token_id to integer # Convert token_id to integer
...@@ -300,6 +305,7 @@ class SamplingParams( ...@@ -300,6 +305,7 @@ class SamplingParams(
guided_decoding=guided_decoding, guided_decoding=guided_decoding,
logit_bias=logit_bias, logit_bias=logit_bias,
allowed_token_ids=allowed_token_ids, allowed_token_ids=allowed_token_ids,
extra_args=extra_args,
) )
def __post_init__(self) -> None: def __post_init__(self) -> None:
...@@ -509,7 +515,8 @@ class SamplingParams( ...@@ -509,7 +515,8 @@ class SamplingParams(
"spaces_between_special_tokens=" "spaces_between_special_tokens="
f"{self.spaces_between_special_tokens}, " f"{self.spaces_between_special_tokens}, "
f"truncate_prompt_tokens={self.truncate_prompt_tokens}, " f"truncate_prompt_tokens={self.truncate_prompt_tokens}, "
f"guided_decoding={self.guided_decoding})") f"guided_decoding={self.guided_decoding}, "
f"extra_args={self.extra_args})")
class BeamSearchParams( class BeamSearchParams(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment