Unverified commit fad86a68 authored by Chuyue Sun, committed by GitHub

Support `n` in OpenAI API completions (#3446)


Co-authored-by: Shan Yu <shanyu1@g.ucla.edu>
Co-authored-by: Yineng Zhang <me@zhyncs.com>
Co-authored-by: Chuyue Sun <chuyue@lmsys.us-northcentral1-a.compute.internal>
parent df7014a8
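
This change threads the OpenAI `n` sampling parameter (the number of choices generated per request) through the SGLang frontend: `gen`, `run`, and `run_batch` accept `n`, `SglSamplingParams` carries it, and the OpenAI backend returns a list of strings when `n > 1` (a plain string otherwise). The new example below exercises `n=2` through the chat interface.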
"""
Usage:
export OPENAI_API_KEY=sk-******
python3 openai_example_chat.py
"""
import json
import sglang as sgl
@sgl.function
def multi_turn_question(s, question_1, question_2):
s += sgl.system("You are a helpful assistant.")
s += sgl.user(question_1)
s += sgl.assistant(sgl.gen("answer_1", max_tokens=1024, n=2))
s += sgl.user(question_2)
s += sgl.assistant(
sgl.gen(
"answer_2",
max_tokens=1024,
)
)
def single():
state = multi_turn_question.run(
question_1="What is the capital of the United States?",
question_2="List two local attractions.",
)
for m in state.messages():
print(m["role"], ":", m["content"])
print("\n-- answer_1 --\n", state["answer_1"])
print("\n-- answer_2 --\n", state["answer_2"])
assert isinstance(state["answer_1"], list)
assert len(state["answer_1"]) == 2
assert isinstance(state["answer_2"], str)
def batch():
states = multi_turn_question.run_batch(
[
{
"question_1": "What is the capital of the United States?",
"question_2": "List two local attractions.",
},
{
"question_1": "What is the capital of France?",
"question_2": "What is the population of this city?",
},
]
)
for s in states:
print(s.messages())
print("\n-- answer_1 --\n", s["answer_1"])
print("\n-- answer_2 --\n", s["answer_2"])
assert isinstance(s["answer_1"], list)
assert len(s["answer_1"]) == 2
assert isinstance(s["answer_2"], str)
if __name__ == "__main__":
sgl.set_default_backend(sgl.OpenAI("o1"))
# Run a single request
print("\n========== single ==========\n")
single()
# Run a batch of requests
print("\n========== batch ==========\n")
batch()
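
For reference, `n` is the standard OpenAI parameter for requesting several completions in one call. Below is a minimal sketch of the raw API behavior the backend builds on, assuming the official `openai` v1 client and a placeholder model name:

from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

ret = client.chat.completions.create(
    model="gpt-4o-mini",  # placeholder model, not part of this commit
    messages=[{"role": "user", "content": "Name a prime number."}],
    n=2,  # request two independent samples for the same prompt
)

# One entry in ret.choices per sample; usage is accounted once per request.
texts = [choice.message.content for choice in ret.choices]
assert len(texts) == 2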
@@ -75,6 +75,7 @@ def gen(
     name: Optional[str] = None,
     max_tokens: Optional[int] = None,
     min_tokens: Optional[int] = None,
+    n: Optional[int] = None,
     stop: Optional[Union[str, List[str]]] = None,
     stop_token_ids: Optional[List[int]] = None,
     temperature: Optional[float] = None,
@@ -115,6 +116,7 @@ def gen(
         name,
         max_tokens,
         min_tokens,
+        n,
         stop,
         stop_token_ids,
         temperature,
@@ -137,6 +139,7 @@ def gen(
 def gen_int(
     name: Optional[str] = None,
     max_tokens: Optional[int] = None,
+    n: Optional[int] = None,
     stop: Optional[Union[str, List[str]]] = None,
     stop_token_ids: Optional[List[int]] = None,
     temperature: Optional[float] = None,
@@ -155,6 +158,7 @@ def gen_int(
         name,
         max_tokens,
         None,
+        n,
         stop,
         stop_token_ids,
         temperature,
@@ -176,6 +180,7 @@ def gen_int(
 def gen_string(
     name: Optional[str] = None,
     max_tokens: Optional[int] = None,
+    n: Optional[int] = None,
     stop: Optional[Union[str, List[str]]] = None,
     stop_token_ids: Optional[List[int]] = None,
     temperature: Optional[float] = None,
@@ -194,6 +199,7 @@ def gen_string(
         name,
         max_tokens,
         None,
+        n,
         stop,
         stop_token_ids,
         temperature,
...
@@ -165,6 +165,7 @@ class OpenAI(BaseBackend):
                 kwargs.pop("max_tokens", None)
             else:
                 kwargs.pop("max_completion_tokens", None)
+
             comp = openai_completion(
                 client=self.client,
                 token_usage=self.token_usage,
@@ -173,13 +174,13 @@ class OpenAI(BaseBackend):
                 prompt=prompt,
                 **kwargs,
             )
-
+            # Keep the returned list (or string) as is.
         elif sampling_params.dtype in [str, "str", "string"]:
             assert (
                 not self.is_chat_model
             ), "constrained type not supported on chat model"
             kwargs = sampling_params.to_openai_kwargs()
             kwargs.pop("stop")
             comp = openai_completion(
                 client=self.client,
                 token_usage=self.token_usage,
@@ -189,7 +190,11 @@ class OpenAI(BaseBackend):
                 stop='"',
                 **kwargs,
             )
-            comp = '"' + comp + '"'
+            # Wrap each element in quotes if we have a list.
+            if isinstance(comp, list):
+                comp = ['"' + x + '"' for x in comp]
+            else:
+                comp = '"' + comp + '"'
         elif sampling_params.dtype in [int, "int"]:
             assert (
                 not self.is_chat_model
@@ -206,6 +211,7 @@ class OpenAI(BaseBackend):
                 stop=[" "],
                 **kwargs,
             )
+            # Leave as a list if that's what is returned.
         else:
             raise ValueError(f"Unknown dtype: {sampling_params.dtype}")
@@ -254,7 +260,9 @@ class OpenAI(BaseBackend):
                 prompt=s.messages_,
                 **self.spec_kwargs,
            )
-            if self.spec_pattern_match(comp):
+            # Use a string for pattern matching.
+            comp_for_match = comp[0] if isinstance(comp, list) else comp
+            if self.spec_pattern_match(comp_for_match):
                 break
             for term in self.spec_format:
@@ -370,7 +378,7 @@ class OpenAI(BaseBackend):
 def openai_completion(
     client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
-):
+) -> Union[str, List[str]]:
     # if "ebnf" is in kwargs, warn and remove
     if "ebnf" in kwargs:
         warnings.warn("EBNF is not officially supported by OpenAI endpoints. Ignoring.")
@@ -382,13 +390,18 @@ def openai_completion(
                 if "stop" in kwargs and kwargs["stop"] is None:
                     kwargs.pop("stop")
                 ret = client.chat.completions.create(messages=prompt, **kwargs)
-                comp = ret.choices[0].message.content
+                if len(ret.choices) == 1:
+                    comp = ret.choices[0].message.content
+                else:
+                    comp = [c.message.content for c in ret.choices]
             else:
                 ret = client.completions.create(prompt=prompt, **kwargs)
                 if isinstance(prompt, (list, tuple)):
                     comp = [c.text for c in ret.choices]
                 else:
                     comp = ret.choices[0].text
+                    if len(ret.choices) > 1:
+                        comp = [c.text for c in ret.choices]

             token_usage.prompt_tokens += ret.usage.prompt_tokens
             token_usage.completion_tokens += ret.usage.completion_tokens
...
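
The return-shape convention above (`str` for a single choice, `List[str]` for several) is the invariant the rest of the diff depends on. Here is a standalone sketch of that rule, with `normalize_choices` as a hypothetical helper name, not part of the commit:

from typing import List, Union

def normalize_choices(texts: List[str]) -> Union[str, List[str]]:
    # Mirrors openai_completion: one choice stays a plain string;
    # multiple choices (n > 1, or a batched prompt) become a list.
    return texts[0] if len(texts) == 1 else texts

assert normalize_choices(["only"]) == "only"
assert normalize_choices(["a", "b"]) == ["a", "b"]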
@@ -566,13 +566,13 @@ class StreamExecutor:
     def _execute_gen(self, expr: SglGen):
         sampling_params = self._resolve_sampling_params(expr.sampling_params)
         name = expr.name

         if not self.stream:
             if self.num_api_spec_tokens is None:
                 comp, meta_info = self.backend.generate(
                     self,
                     sampling_params=sampling_params,
                 )
             else:
                 if self.backend.is_chat_model:
                     # Speculative execution on models with only chat interface.
@@ -587,8 +587,11 @@ class StreamExecutor:
                 else:  # Speculative execution on models with completion interface
                     comp, meta_info = self._spec_gen(sampling_params)

-            self.text_ += comp
+            if isinstance(comp, list):
+                self.text_ += comp[0]
+            else:
+                assert isinstance(comp, str)
+                self.text_ += comp

             self.variables[name] = comp
             self.meta_info[name] = meta_info
@@ -747,6 +750,7 @@ class StreamExecutor:
         for item in [
             "max_new_tokens",
             "min_new_tokens",
+            "n",
             "stop",
             "stop_token_ids",
             "temperature",
...
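
Note the design choice in `_execute_gen`: when `n > 1`, only the first sample is appended to the running prompt `self.text_`, so later turns condition on sample 0, while `self.variables[name]` keeps the full list for the user. A toy sketch of that behavior, with assumed values:

comp = ["first sample", "second sample"]  # what gen(..., n=2) now yields

text_ = "dialogue so far: "
if isinstance(comp, list):
    text_ += comp[0]  # only the first sample extends the dialogue
else:
    text_ += comp

variables = {"answer_1": comp}  # all samples stay addressable via state["answer_1"]
assert "second sample" not in text_
assert variables["answer_1"][1] == "second sample"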
@@ -18,6 +18,7 @@ REGEX_STR = r"\"[\w\d\s]*\""  # bugs with regex r"\".*\"" in interegular pkg
 class SglSamplingParams:
     max_new_tokens: int = 128
     min_new_tokens: int = 0
+    n: int = 1
     stop: Union[str, List[str]] = ()
     stop_token_ids: Optional[List[int]] = ()
     temperature: float = 1.0
@@ -41,6 +42,7 @@ class SglSamplingParams:
         return SglSamplingParams(
             self.max_new_tokens,
             self.min_new_tokens,
+            self.n,
             self.stop,
             self.stop_token_ids,
             self.temperature,
@@ -64,6 +66,7 @@ class SglSamplingParams:
         return {
             "max_tokens": self.max_new_tokens,
             "max_completion_tokens": self.max_new_tokens,
+            "n": self.n,
             "stop": self.stop or None,
             "temperature": self.temperature,
             "top_p": self.top_p,
@@ -117,6 +120,7 @@ class SglSamplingParams:
         return {
             "max_new_tokens": self.max_new_tokens,
             "min_new_tokens": self.min_new_tokens,
+            "n": self.n,
             "stop": self.stop,
             "stop_token_ids": self.stop_token_ids,
             "temperature": self.temperature,
@@ -154,6 +158,7 @@ class SglFunction:
         self,
         *args,
         max_new_tokens: int = 128,
+        n: int = 1,
         stop: Optional[Union[str, List[str]]] = None,
         stop_token_ids: Optional[List[int]] = None,
         temperature: float = 1.0,
@@ -182,6 +187,7 @@ class SglFunction:
         default_sampling_para = SglSamplingParams(
             max_new_tokens=max_new_tokens,
+            n=n,
             stop=stop,
             stop_token_ids=stop_token_ids,
             temperature=temperature,
@@ -212,6 +218,7 @@ class SglFunction:
         batch_kwargs,
         *,
         max_new_tokens: int = 128,
+        n: int = 1,
         stop: Optional[Union[str, List[str]]] = None,
         stop_token_ids: Optional[List[int]] = None,
         temperature: float = 1.0,
@@ -257,6 +264,7 @@ class SglFunction:
         default_sampling_para = SglSamplingParams(
             max_new_tokens=max_new_tokens,
+            n=n,
             stop=stop,
             stop_token_ids=stop_token_ids,
             temperature=temperature,
@@ -440,6 +448,7 @@ class SglGen(SglExpr):
         name: Optional[str] = None,
         max_new_tokens: Optional[int] = None,
         min_new_tokens: Optional[int] = None,
+        n: Optional[int] = None,
         stop: Optional[Union[str, List[str]]] = None,
         stop_token_ids: Optional[List[int]] = None,
         temperature: Optional[float] = None,
@@ -463,6 +472,7 @@ class SglGen(SglExpr):
         self.sampling_params = SglSamplingParams(
             max_new_tokens=max_new_tokens,
             min_new_tokens=min_new_tokens,
+            n=n,
             stop=stop,
             stop_token_ids=stop_token_ids,
             temperature=temperature,
...
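
One subtlety in the `SglSamplingParams` changes: `clone()` constructs the dataclass positionally, so the new `n` field has to sit in the same slot in the field list and in every positional call site. A self-contained sketch of why, using a stand-in dataclass rather than the real class:

from dataclasses import dataclass

@dataclass
class Params:  # stand-in for SglSamplingParams, not the real class
    max_new_tokens: int = 128
    min_new_tokens: int = 0
    n: int = 1  # must occupy the same position everywhere it is passed positionally

    def clone(self):
        # Positional construction: argument order must match the field order above.
        return Params(self.max_new_tokens, self.min_new_tokens, self.n)

p = Params(n=2)
assert p.clone().n == 2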