Unverified commit fad86a68, authored by Chuyue Sun, committed by GitHub

Support `n` in OpenAI API completions (#3446)
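With `n > 1`, generation on the OpenAI backend now returns a list of `n`
completions instead of a single string; `n == 1` keeps the old string-valued
behavior.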


Co-authored-by: Shan Yu <shanyu1@g.ucla.edu>
Co-authored-by: Yineng Zhang <me@zhyncs.com>
Co-authored-by: Chuyue Sun <chuyue@lmsys.us-northcentral1-a.compute.internal>
parent df7014a8
"""
Usage:
export OPENAI_API_KEY=sk-******
python3 openai_example_chat.py
"""
import json
import sglang as sgl
@sgl.function
def multi_turn_question(s, question_1, question_2):
    s += sgl.system("You are a helpful assistant.")
    s += sgl.user(question_1)
    s += sgl.assistant(sgl.gen("answer_1", max_tokens=1024, n=2))
    s += sgl.user(question_2)
    s += sgl.assistant(
        sgl.gen(
            "answer_2",
            max_tokens=1024,
        )
    )
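
# "answer_1" is sampled with n=2, so the state stores a list of two strings;
# "answer_2" uses the default n=1 and remains a single string.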


def single():
    state = multi_turn_question.run(
        question_1="What is the capital of the United States?",
        question_2="List two local attractions.",
    )

    for m in state.messages():
        print(m["role"], ":", m["content"])

    print("\n-- answer_1 --\n", state["answer_1"])
    print("\n-- answer_2 --\n", state["answer_2"])

    assert isinstance(state["answer_1"], list)
    assert len(state["answer_1"]) == 2
    assert isinstance(state["answer_2"], str)


def batch():
    states = multi_turn_question.run_batch(
        [
            {
                "question_1": "What is the capital of the United States?",
                "question_2": "List two local attractions.",
            },
            {
                "question_1": "What is the capital of France?",
                "question_2": "What is the population of this city?",
            },
        ]
    )

    for s in states:
        print(s.messages())
        print("\n-- answer_1 --\n", s["answer_1"])
        print("\n-- answer_2 --\n", s["answer_2"])

        assert isinstance(s["answer_1"], list)
        assert len(s["answer_1"]) == 2
        assert isinstance(s["answer_2"], str)


if __name__ == "__main__":
    sgl.set_default_backend(sgl.OpenAI("o1"))

    # Run a single request
    print("\n========== single ==========\n")
    single()

    # Run a batch of requests
    print("\n========== batch ==========\n")
    batch()
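
# For reference: with the OpenAI backend, `n` is forwarded verbatim as the
# OpenAI API's `n` parameter (see SglSamplingParams.to_openai_kwargs below),
# so one request returns `n` sampled choices. A minimal raw-client sketch
# (the model name here is only illustrative):
#
#   from openai import OpenAI
#
#   client = OpenAI()
#   ret = client.chat.completions.create(
#       model="gpt-4o-mini",
#       messages=[{"role": "user", "content": "Hello"}],
#       n=2,
#   )
#   samples = [choice.message.content for choice in ret.choices]  # 2 items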
@@ -75,6 +75,7 @@ def gen(
    name: Optional[str] = None,
    max_tokens: Optional[int] = None,
    min_tokens: Optional[int] = None,
    n: Optional[int] = None,
    stop: Optional[Union[str, List[str]]] = None,
    stop_token_ids: Optional[List[int]] = None,
    temperature: Optional[float] = None,
@@ -115,6 +116,7 @@ def gen(
        name,
        max_tokens,
        min_tokens,
        n,
        stop,
        stop_token_ids,
        temperature,
@@ -137,6 +139,7 @@ def gen(
def gen_int(
    name: Optional[str] = None,
    max_tokens: Optional[int] = None,
    n: Optional[int] = None,
    stop: Optional[Union[str, List[str]]] = None,
    stop_token_ids: Optional[List[int]] = None,
    temperature: Optional[float] = None,
@@ -155,6 +158,7 @@ def gen_int(
        name,
        max_tokens,
        None,  # min_tokens
        n,
        stop,
        stop_token_ids,
        temperature,
@@ -176,6 +180,7 @@ def gen_int(
def gen_string(
    name: Optional[str] = None,
    max_tokens: Optional[int] = None,
    n: Optional[int] = None,
    stop: Optional[Union[str, List[str]]] = None,
    stop_token_ids: Optional[List[int]] = None,
    temperature: Optional[float] = None,
@@ -194,6 +199,7 @@ def gen_string(
        name,
        max_tokens,
        None,  # min_tokens
        n,
        stop,
        stop_token_ids,
        temperature,
@@ -165,6 +165,7 @@ class OpenAI(BaseBackend):
                kwargs.pop("max_tokens", None)
            else:
                kwargs.pop("max_completion_tokens", None)
            comp = openai_completion(
                client=self.client,
                token_usage=self.token_usage,
@@ -173,13 +174,13 @@ class OpenAI(BaseBackend):
                prompt=prompt,
                **kwargs,
            )
            # Keep the returned list (or string) as is.
        elif sampling_params.dtype in [str, "str", "string"]:
            assert (
                not self.is_chat_model
            ), "constrained type not supported on chat model"
            kwargs = sampling_params.to_openai_kwargs()
            kwargs.pop("stop")
            comp = openai_completion(
                client=self.client,
                token_usage=self.token_usage,
@@ -189,7 +190,11 @@ class OpenAI(BaseBackend):
                stop='"',
                **kwargs,
            )
            comp = '"' + comp + '"'
            # Wrap each element in quotes if we have a list.
            if isinstance(comp, list):
                comp = ['"' + x + '"' for x in comp]
            else:
                comp = '"' + comp + '"'
        elif sampling_params.dtype in [int, "int"]:
            assert (
                not self.is_chat_model
@@ -206,6 +211,7 @@ class OpenAI(BaseBackend):
                stop=[" "],
                **kwargs,
            )
            # Leave as a list if that's what is returned.
        else:
            raise ValueError(f"Unknown dtype: {sampling_params.dtype}")
@@ -254,7 +260,9 @@ class OpenAI(BaseBackend):
                prompt=s.messages_,
                **self.spec_kwargs,
            )
            if self.spec_pattern_match(comp):
            # Use a single string (the first sample when n > 1) for pattern matching.
            comp_for_match = comp[0] if isinstance(comp, list) else comp
            if self.spec_pattern_match(comp_for_match):
                break
        for term in self.spec_format:
@@ -370,7 +378,7 @@ class OpenAI(BaseBackend):
def openai_completion(
    client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
):
) -> Union[str, List[str]]:
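    # Returns a plain string when the response carries a single choice, or a
    # list of strings when it carries several (n > 1, or a batched prompt).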
# if "ebnf" is in kwargs, warn and remove
if "ebnf" in kwargs:
warnings.warn("EBNF is not officially supported by OpenAI endpoints. Ignoring.")
@@ -382,13 +390,18 @@ def openai_completion(
        if "stop" in kwargs and kwargs["stop"] is None:
            kwargs.pop("stop")
        ret = client.chat.completions.create(messages=prompt, **kwargs)
        comp = ret.choices[0].message.content
        if len(ret.choices) == 1:
            comp = ret.choices[0].message.content
        else:
            comp = [c.message.content for c in ret.choices]
    else:
        ret = client.completions.create(prompt=prompt, **kwargs)
        if isinstance(prompt, (list, tuple)):
            comp = [c.text for c in ret.choices]
        else:
            comp = ret.choices[0].text
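            # A single string prompt can still yield multiple choices when
            # n > 1; surface them as a list in that case.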
            if len(ret.choices) > 1:
                comp = [c.text for c in ret.choices]
    token_usage.prompt_tokens += ret.usage.prompt_tokens
    token_usage.completion_tokens += ret.usage.completion_tokens
@@ -566,13 +566,13 @@ class StreamExecutor:
    def _execute_gen(self, expr: SglGen):
        sampling_params = self._resolve_sampling_params(expr.sampling_params)
        name = expr.name
        if not self.stream:
            if self.num_api_spec_tokens is None:
                comp, meta_info = self.backend.generate(
                    self,
                    sampling_params=sampling_params,
                )
            else:
                if self.backend.is_chat_model:
                    # Speculative execution on models with only chat interface.
@@ -587,8 +587,11 @@ class StreamExecutor:
                else:  # Speculative execution on models with completion interface
                    comp, meta_info = self._spec_gen(sampling_params)
            self.text_ += comp
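            # With n > 1, comp is a list of samples: only the first sample
            # extends the running prompt text, while the full list is stored
            # under the variable name below.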
            if isinstance(comp, list):
                self.text_ += comp[0]
            else:
                assert isinstance(comp, str)
                self.text_ += comp
            self.variables[name] = comp
            self.meta_info[name] = meta_info
@@ -747,6 +750,7 @@ class StreamExecutor:
        for item in [
            "max_new_tokens",
            "min_new_tokens",
            "n",
            "stop",
            "stop_token_ids",
            "temperature",
@@ -18,6 +18,7 @@ REGEX_STR = r"\"[\w\d\s]*\""  # bugs with regex r"\".*\"" in interegular pkg
class SglSamplingParams:
    max_new_tokens: int = 128
    min_new_tokens: int = 0
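    # Number of samples to return per generation call (OpenAI-style `n`).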
    n: int = 1
    stop: Union[str, List[str]] = ()
    stop_token_ids: Optional[List[int]] = ()
    temperature: float = 1.0
@@ -41,6 +42,7 @@ class SglSamplingParams:
        return SglSamplingParams(
            self.max_new_tokens,
            self.min_new_tokens,
            self.n,
            self.stop,
            self.stop_token_ids,
            self.temperature,
@@ -64,6 +66,7 @@ class SglSamplingParams:
        return {
            "max_tokens": self.max_new_tokens,
            "max_completion_tokens": self.max_new_tokens,
            "n": self.n,
            "stop": self.stop or None,
            "temperature": self.temperature,
            "top_p": self.top_p,
@@ -117,6 +120,7 @@ class SglSamplingParams:
        return {
            "max_new_tokens": self.max_new_tokens,
            "min_new_tokens": self.min_new_tokens,
            "n": self.n,
            "stop": self.stop,
            "stop_token_ids": self.stop_token_ids,
            "temperature": self.temperature,
@@ -154,6 +158,7 @@ class SglFunction:
        self,
        *args,
        max_new_tokens: int = 128,
        n: int = 1,
        stop: Optional[Union[str, List[str]]] = None,
        stop_token_ids: Optional[List[int]] = None,
        temperature: float = 1.0,
@@ -182,6 +187,7 @@ class SglFunction:
        default_sampling_para = SglSamplingParams(
            max_new_tokens=max_new_tokens,
            n=n,
            stop=stop,
            stop_token_ids=stop_token_ids,
            temperature=temperature,
@@ -212,6 +218,7 @@ class SglFunction:
        batch_kwargs,
        *,
        max_new_tokens: int = 128,
        n: int = 1,
        stop: Optional[Union[str, List[str]]] = None,
        stop_token_ids: Optional[List[int]] = None,
        temperature: float = 1.0,
@@ -257,6 +264,7 @@ class SglFunction:
        default_sampling_para = SglSamplingParams(
            max_new_tokens=max_new_tokens,
            n=n,
            stop=stop,
            stop_token_ids=stop_token_ids,
            temperature=temperature,
@@ -440,6 +448,7 @@ class SglGen(SglExpr):
        name: Optional[str] = None,
        max_new_tokens: Optional[int] = None,
        min_new_tokens: Optional[int] = None,
        n: Optional[int] = None,
        stop: Optional[Union[str, List[str]]] = None,
        stop_token_ids: Optional[List[int]] = None,
        temperature: Optional[float] = None,
@@ -463,6 +472,7 @@ class SglGen(SglExpr):
        self.sampling_params = SglSamplingParams(
            max_new_tokens=max_new_tokens,
            min_new_tokens=min_new_tokens,
            n=n,
            stop=stop,
            stop_token_ids=stop_token_ids,
            temperature=temperature,