Unverified Commit 94e01151 authored by Aidan Cooper's avatar Aidan Cooper Committed by GitHub
Browse files

Feat: add alternative choices selection methods (#835)

parent b216a545
# Choices Methods in SGLang
This doc describes the choices methods supported by SGLang.
The optional `choices_method` arg determines how options supplied to SGLang's `choices` primitive are selected. Only the `RuntimeEndpoint` backend supports the `choices_method` arg. Other backends, such as `OpenAI`, have bespoke selection implementations due to API limitations.
## Methods
### Token Length Normalized
Token length normalized is the default SGLang choices method. It selects the option with the highest average logprob across all of its tokens.
Usage example (alternatively, simply omit the `choices_method` arg):
```python
@sgl.function
def example(s):
s += sgl.user("What is the capital of France?")
s += sgl.assistant(
sgl.gen(
"answer",
choices=["London", "Paris", "Berlin"],
choices_method=sgl.token_length_normalized,
)
)
```
This can perform poorly if an option contains many tokens, where its later tokens are predicted with high confidence based on its earlier tokens. For instance, even strong models will fail the above example if the specified options are `["Paris", "Antidisestablishmentarianism"]`.
### Greedy Token Selection
Greedy token selection simply selects the option with the highest logprob for its initial token. For overlapping options where one option is a subset of a longer option, the logprobs of the shorter option are extended using its average logprob for comparison against the longer option.
Usage example:
```python
@sgl.function
def example(s):
s += sgl.user("What is the capital of France?")
s += sgl.assistant(
sgl.gen(
"answer",
choices=["London", "Paris", "Berlin"],
choices_method=sgl.greedy_token_selection,
)
)
```
This can perform poorly if an option misleads the model down a bad path based on an attractive initial token. For instance, greedy selection will result in an incorrect response for this example:
```python
@sgl.function
def us_president_example(s):
s += sgl.user("Name a US president.")
s += sgl.assistant(
sgl.gen(
"answer",
choices=["Donald Duck", "Millard Fillmore"],
choices_method=sgl.greedy_token_selection,
)
)
```
### Unconditional Likelihood Normalized
Unconditional likelihood normalized selects the option with the highest average token logprob once normalized by the unconditional token logprobs, as described in [this EleutherAI blogpost](https://blog.eleuther.ai/multiple-choice-normalization/). This method incurs an additional LLM call to obtain the unconditional likelihoods.
Usage example:
```python
@sgl.function
def example(s):
s += sgl.user("What is the capital of France?")
s += sgl.assistant(
sgl.gen(
"answer",
choices=["London", "Paris", "Berlin"],
choices_method=sgl.unconditional_likelihood_normalized,
)
)
```
\ No newline at end of file
......@@ -22,6 +22,11 @@ from sglang.api import (
user_end,
video,
)
from sglang.lang.choices import (
greedy_token_selection,
token_length_normalized,
unconditional_likelihood_normalized,
)
# SGLang DSL APIs
__all__ = [
......@@ -45,6 +50,9 @@ __all__ = [
"user_begin",
"user_end",
"video",
"greedy_token_selection",
"token_length_normalized",
"unconditional_likelihood_normalized",
]
# Global Configurations
......
......@@ -6,6 +6,7 @@ from typing import Callable, List, Optional, Union
from sglang.global_config import global_config
from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.choices import ChoicesSamplingMethod, token_length_normalized
from sglang.lang.ir import (
SglExpr,
SglExprList,
......@@ -73,12 +74,18 @@ def gen(
return_text_in_logprobs: Optional[bool] = None,
dtype: Optional[type] = None,
choices: Optional[List[str]] = None,
choices_method: Optional[ChoicesSamplingMethod] = None,
regex: Optional[str] = None,
):
"""Call the model to generate. See the meaning of the arguments in docs/en/sampling_params.md"""
if choices:
return SglSelect(name, choices, 0.0 if temperature is None else temperature)
return SglSelect(
name,
choices,
0.0 if temperature is None else temperature,
token_length_normalized if choices_method is None else choices_method,
)
# check regex is valid
if regex is not None:
......@@ -186,9 +193,10 @@ def select(
name: Optional[str] = None,
choices: Optional[List[str]] = None,
temperature: float = 0.0,
choices_method: ChoicesSamplingMethod = token_length_normalized,
):
assert choices is not None
return SglSelect(name, choices, temperature)
return SglSelect(name, choices, temperature, choices_method)
def _role_common(name: str, expr: Optional[SglExpr] = None):
......
from typing import Callable, List, Optional, Union
from sglang.lang.chat_template import get_chat_template
from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglSamplingParams
......@@ -64,7 +65,8 @@ class BaseBackend:
s: StreamExecutor,
choices: List[str],
temperature: float,
):
choices_method: Optional[ChoicesSamplingMethod] = None,
) -> ChoicesDecision:
raise NotImplementedError()
def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
......
......@@ -8,6 +8,7 @@ import numpy as np
from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import ChatTemplate, get_chat_template_by_model_path
from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglSamplingParams
......@@ -296,7 +297,9 @@ class OpenAI(BaseBackend):
s: StreamExecutor,
choices: List[str],
temperature: float,
):
choices_method: ChoicesSamplingMethod,
) -> ChoicesDecision:
"""Note: `choices_method` is not used by the OpenAI backend."""
if self.is_chat_model:
raise NotImplementedError(
"select/choices is not supported for chat models. "
......@@ -354,8 +357,10 @@ class OpenAI(BaseBackend):
prompt_tokens.append(ret_token)
decision = choices[np.argmax(scores)]
return decision, scores, None, None
return ChoicesDecision(
decision=choices[np.argmax(scores)],
meta_info={"scores": scores},
)
def openai_completion(
......
import json
from typing import List, Optional
import numpy as np
from sglang.global_config import global_config
from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template_by_model_path
from sglang.lang.choices import (
ChoicesDecision,
ChoicesSamplingMethod,
token_length_normalized,
)
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglSamplingParams
from sglang.utils import http_request
class RuntimeEndpoint(BaseBackend):
def __init__(
self,
base_url: str,
......@@ -208,20 +212,14 @@ class RuntimeEndpoint(BaseBackend):
s: StreamExecutor,
choices: List[str],
temperature: float,
):
choices_method: ChoicesSamplingMethod,
) -> ChoicesDecision:
assert temperature <= 1e-5
# Cache common prefix
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
self._add_images(s, data)
res = http_request(
self.base_url + "/generate",
json=data,
api_key=self.api_key,
verify=self.verify,
)
self._assert_success(res)
prompt_len = res.json()["meta_info"]["prompt_tokens"]
obj = self._generate_http_request(s, data)
prompt_len = obj["meta_info"]["prompt_tokens"]
# Compute logprob
data = {
......@@ -230,27 +228,35 @@ class RuntimeEndpoint(BaseBackend):
"return_logprob": True,
"logprob_start_len": max(prompt_len - 2, 0),
}
self._add_images(s, data)
res = http_request(
self.base_url + "/generate",
json=data,
api_key=self.api_key,
verify=self.verify,
)
self._assert_success(res)
obj = res.json()
obj = self._generate_http_request(s, data)
normalized_prompt_logprobs = [
r["meta_info"]["normalized_prompt_logprob"] for r in obj
]
decision = choices[np.argmax(normalized_prompt_logprobs)]
input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj]
output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj]
return (
decision,
normalized_prompt_logprobs,
input_token_logprobs,
output_token_logprobs,
# Compute unconditional logprobs if required
if choices_method.requires_unconditional_logprobs:
input_ids = [[el[1] for el in subl] for subl in input_token_logprobs]
data = {
"input_ids": input_ids,
"sampling_params": {"max_new_tokens": 0},
"return_logprob": True,
}
obj = self._generate_http_request(s, data)
unconditional_token_logprobs = [
r["meta_info"]["input_token_logprobs"] for r in obj
]
else:
unconditional_token_logprobs = None
return choices_method(
choices=choices,
normalized_prompt_logprobs=normalized_prompt_logprobs,
input_token_logprobs=input_token_logprobs,
output_token_logprobs=output_token_logprobs,
unconditional_token_logprobs=unconditional_token_logprobs,
)
def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
......@@ -262,6 +268,17 @@ class RuntimeEndpoint(BaseBackend):
)
self._assert_success(res)
def _generate_http_request(self, s: StreamExecutor, data):
    """POST `data` to the runtime's /generate endpoint and return the parsed JSON.

    Attaches any image from the stream executor to the payload first, and
    raises (via `_assert_success`) on a non-success HTTP response.
    """
    self._add_images(s, data)
    response = http_request(
        self.base_url + "/generate",
        json=data,
        api_key=self.api_key,
        verify=self.verify,
    )
    self._assert_success(response)
    return response.json()
def _add_images(self, s: StreamExecutor, data):
if s.images_:
assert len(s.images_) == 1, "Only support one image."
......
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
import numpy as np
@dataclass
class ChoicesDecision:
    """Result of a choices selection: the chosen option plus diagnostics."""

    # The selected choice string.
    decision: str
    # Method-specific diagnostics (logprobs, score matrices, ...); None when absent.
    # NOTE: `Optional[...]` instead of `Dict[str, Any] | None` — the PEP 604 union
    # requires Python >= 3.10, while the rest of this module sticks to the
    # `typing.Optional`/`typing.Dict` style; keep the annotations consistent.
    meta_info: Optional[Dict[str, Any]] = None
class ChoicesSamplingMethod(ABC):
    """Strategy interface for selecting one option from a list of choices."""

    @property
    def requires_unconditional_logprobs(self) -> bool:
        """Whether this method needs an extra LLM call for unconditional logprobs.

        Backends check this flag to decide whether to score the options without
        the prompt; defaults to False.
        """
        return False

    @abstractmethod
    def __call__(
        self,
        *,
        choices: List[str],
        normalized_prompt_logprobs: List[float],
        input_token_logprobs: List[List[Any]],
        output_token_logprobs: List[List[Any]],
        unconditional_token_logprobs: Optional[List[List[Any]]] = None,
    ) -> ChoicesDecision:
        """Return a ChoicesDecision for `choices` given the supplied logprobs."""
        ...
class TokenLengthNormalized(ChoicesSamplingMethod):
    """Choose the option with the best token-length-normalized prompt logprob."""

    def __call__(
        self,
        *,
        choices: List[str],
        normalized_prompt_logprobs: List[float],
        input_token_logprobs: List[List[Any]],
        output_token_logprobs: List[List[Any]],
        unconditional_token_logprobs: Optional[List[List[Any]]] = None,
    ) -> ChoicesDecision:
        """Select the option with the highest token length normalized prompt logprob."""
        winner_idx = np.argmax(normalized_prompt_logprobs)
        diagnostics = {
            "normalized_prompt_logprobs": normalized_prompt_logprobs,
            "input_token_logprobs": input_token_logprobs,
            "output_token_logprobs": output_token_logprobs,
        }
        return ChoicesDecision(decision=choices[winner_idx], meta_info=diagnostics)
token_length_normalized = TokenLengthNormalized()
class GreedyTokenSelection(ChoicesSamplingMethod):
    """Choose the option with the highest initial-token logprob, breaking ties
    greedily at each subsequent token position."""

    def __call__(
        self,
        *,
        choices: List[str],
        normalized_prompt_logprobs: List[float],
        input_token_logprobs: List[List[Any]],
        output_token_logprobs: List[List[Any]],
        unconditional_token_logprobs: Optional[List[List[Any]]] = None,
    ) -> ChoicesDecision:
        """Select the option based on greedy logprob selection. For overlapping options
        where one option is a subset of a longer option, extend the shorter option using
        its average logprob for comparison against the longer option."""
        num_options = len(choices)
        # Options may tokenize to different lengths; pad rows to the longest.
        max_tokens = max(len(option) for option in input_token_logprobs)
        logprob_matrix = self._build_logprob_matrix(
            input_token_logprobs, max_tokens, num_options
        )
        # `remaining` holds the option indices still tied after greedy filtering;
        # the first surviving index is the winner.
        remaining = self._greedy_selection(logprob_matrix, num_options, max_tokens)
        best_choice = choices[remaining[0]]
        meta_info = {
            "normalized_prompt_logprobs": normalized_prompt_logprobs,
            "input_token_logprobs": input_token_logprobs,
            "output_token_logprobs": output_token_logprobs,
            "greedy_logprob_matrix": logprob_matrix.tolist(),
        }
        return ChoicesDecision(decision=best_choice, meta_info=meta_info)

    def _build_logprob_matrix(self, input_token_logprobs, max_tokens, num_options):
        # One row per option, one column per token position.
        logprob_matrix = np.zeros((num_options, max_tokens))
        for i, option in enumerate(input_token_logprobs):
            # Each token entry is a [logprob, ...] sequence; take the logprob.
            actual_logprobs = [token[0] for token in option]
            avg_logprob = np.mean(actual_logprobs)
            logprob_matrix[i, : len(option)] = actual_logprobs
            if len(option) < max_tokens:
                # Pad shorter options with their own average logprob so they stay
                # comparable against longer options at later positions.
                logprob_matrix[i, len(option) :] = avg_logprob
        return logprob_matrix

    def _greedy_selection(self, logprob_matrix, num_options, max_tokens):
        # Walk token positions left to right, keeping only the options that attain
        # the column maximum, and stop as soon as a single option remains.
        remaining = np.arange(num_options)
        for j in range(max_tokens):
            max_logprob = np.max(logprob_matrix[remaining, j])
            remaining = remaining[logprob_matrix[remaining, j] == max_logprob]
            if len(remaining) == 1:
                break
        return remaining
greedy_token_selection = GreedyTokenSelection()
class UnconditionalLikelihoodNormalized(ChoicesSamplingMethod):
    """Choose the option with the best mean token logprob after subtracting each
    token's unconditional (prompt-free) logprob."""

    @property
    def requires_unconditional_logprobs(self) -> bool:
        """This method needs the extra LLM call that scores options unconditionally."""
        return True

    def __call__(
        self,
        *,
        choices: List[str],
        normalized_prompt_logprobs: List[float],
        input_token_logprobs: List[List[Any]],
        output_token_logprobs: List[List[Any]],
        unconditional_token_logprobs: Optional[List[List[Any]]] = None,
    ) -> ChoicesDecision:
        """Select the option with the highest average token logprob once normalized by
        the unconditional token logprobs.

        The first unconditional token logprob is assumed to be None. If so, it is
        replaced with 0 for the purposes of normalization."""
        if unconditional_token_logprobs is None:
            raise ValueError(
                "Unconditional token logprobs are required for this method."
            )
        scores = self._normalize_logprobs(
            input_token_logprobs, unconditional_token_logprobs
        )
        winner_idx = int(np.argmax(scores))
        return ChoicesDecision(
            decision=choices[winner_idx],
            meta_info={
                "normalized_prompt_logprobs": normalized_prompt_logprobs,
                "input_token_logprobs": input_token_logprobs,
                "output_token_logprobs": output_token_logprobs,
                "unconditional_token_logprobs": unconditional_token_logprobs,
                "normalized_unconditional_prompt_logprobs": scores,
            },
        )

    def _normalize_logprobs(self, input_token_logprobs, unconditional_token_logprobs):
        # For each option, average the per-token difference between conditional
        # and unconditional logprobs.
        scores = []
        for cond_tokens, uncond_tokens in zip(
            input_token_logprobs, unconditional_token_logprobs
        ):
            cond = np.array([tok[0] for tok in cond_tokens])
            uncond = np.array([tok[0] for tok in uncond_tokens])
            # The leading unconditional logprob is typically None; treat it as 0.
            uncond[0] = uncond[0] or 0
            scores.append(float(np.mean(cond - uncond)))
        return scores
unconditional_likelihood_normalized = UnconditionalLikelihoodNormalized()
......@@ -538,24 +538,17 @@ class StreamExecutor:
self.stream_var_event[name].set()
def _execute_select(self, expr: SglSelect):
(
decision,
normalized_prompt_logprobs,
input_token_logprobs,
output_token_logprobs,
) = self.backend.select(self, expr.choices, expr.temperature)
choices_decision = self.backend.select(
self, expr.choices, expr.temperature, expr.choices_method
)
if expr.name is not None:
name = expr.name
self.variables[name] = decision
self.meta_info[name] = {
"normalized_prompt_logprobs": normalized_prompt_logprobs,
"input_token_logprobs": input_token_logprobs,
"output_token_logprobs": output_token_logprobs,
}
self.variables[name] = choices_decision.decision
self.meta_info[name] = choices_decision.meta_info
self.variable_event[name].set()
if self.stream_var_event:
self.stream_var_event[name].set()
self.text_ += decision
self.text_ += choices_decision.decision
def _execute_variable(self, expr: SglVariable):
src_executor = expr.source_stream_executor
......
......@@ -6,6 +6,7 @@ import warnings
from typing import List, Optional, Union
from sglang.global_config import global_config
from sglang.lang.choices import ChoicesSamplingMethod
REGEX_INT = r"[-+]?[0-9]+"
REGEX_FLOAT = r"[-+]?[0-9]*\.?[0-9]+"
......@@ -461,14 +462,22 @@ class SglRoleEnd(SglExpr):
class SglSelect(SglExpr):
def __init__(self, name: str, choices: List[str], temperature: float):
def __init__(
self,
name: str,
choices: List[str],
temperature: float,
choices_method: ChoicesSamplingMethod,
):
super().__init__()
self.name = name
self.choices = choices
self.temperature = temperature
self.choices_method = choices_method
def __repr__(self):
return f"Select({self.name}, choices={self.choices})"
return f"Select({self.name}, choices={self.choices}, choices_method={self.choices_method})"
class SglFork(SglExpr):
......
import unittest
import numpy as np
from sglang.lang.choices import (
greedy_token_selection,
token_length_normalized,
unconditional_likelihood_normalized,
)
# Fixture mimicking the logprob payloads a backend passes to a
# ChoicesSamplingMethod. Each token entry is a [logprob, token_id, text] triple
# (text is None here); the trailing comments show the presumed tokenization of
# each choice.
MOCK_CHOICES_INPUT_DATA = {
    "choices": [
        "organ",  # ["organ"]
        "organism",  # ["organ", "ism"]
        "antidisestablishmentarianism",  # ["ant", "id", "is", "est", "ablish", "ment", "arian", "ism"]
    ],
    "normalized_prompt_logprobs": [-0.1, -0.2, -0.05],
    "input_token_logprobs": [
        [[-0.1, 1, None]],
        [[-0.1, 1, None], [-0.3, 2, None]],
        [
            [-0.4, 3, None],
            [-0.25, 4, None],
            [-0.1, 5, None],
            [-0.01, 6, None],
            [-0.01, 7, None],
            [-0.01, 8, None],
            [-0.01, 9, None],
            [-0.01, 2, None],
        ],
    ],
    "output_token_logprobs": [
        [[-0.1, 10, None]],
        [[-0.1, 10, None]],
        [[-0.1, 10, None]],
    ],
    "unconditional_token_logprobs": [
        [[None, 1, None]],
        [[None, 1, None], [-1.4, 2, None]],
        [
            [None, 3, None],
            [-0.25, 4, None],
            [-0.1, 5, None],
            [-0.01, 6, None],
            [-0.01, 7, None],
            [-0.01, 8, None],
            [-0.01, 9, None],
            [-0.01, 2, None],
        ],
    ],
}
class TestChoices(unittest.TestCase):
    """Exercises each choices sampling method against MOCK_CHOICES_INPUT_DATA."""

    def test_token_length_normalized(self):
        """Confirm 'antidisestablishmentarianism' is selected due to high confidences
        for its later tokens resulting in the highest token length normalized prompt
        logprob."""
        decision = token_length_normalized(**MOCK_CHOICES_INPUT_DATA)
        assert decision.decision == "antidisestablishmentarianism"

    def test_greedy_token_selection(self):
        """Confirm 'organ' is selected due to it having the joint highest initial
        token logprob, and a higher average logprob than organism's second token."""
        decision = greedy_token_selection(**MOCK_CHOICES_INPUT_DATA)
        assert decision.decision == "organ"
        # Rows are the options; shorter options are padded with their average logprob.
        assert np.allclose(
            decision.meta_info["greedy_logprob_matrix"],
            [
                [-0.1, -0.1, -0.1, -0.1, -0.1, -0.1, -0.1, -0.1],
                [-0.1, -0.3, -0.2, -0.2, -0.2, -0.2, -0.2, -0.2],
                [-0.4, -0.25, -0.1, -0.01, -0.01, -0.01, -0.01, -0.01],
            ],
            atol=0.01,
        )

    def test_unconditional_likelihood_normalized(self):
        """Confirm 'organism' is selected due to it having the highest average token
        logprob once normalized by the unconditional token logprobs."""
        decision = unconditional_likelihood_normalized(**MOCK_CHOICES_INPUT_DATA)
        assert decision.decision == "organism"
        assert np.allclose(
            decision.meta_info["normalized_unconditional_prompt_logprobs"],
            [-0.1, 0.5, -0.05],
            atol=0.01,
        )
if __name__ == "__main__":
    # Run the suite directly; removed the stale commented-out manual test calls.
    unittest.main(warnings="ignore")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment