"vscode:/vscode.git/clone" did not exist on "b002f8f99fb0918a3d04629bfdb75df5e30cf1b6"
Unverified Commit 6cc30955 authored by Chuyue Sun, committed by GitHub

Add support for OpenAI API o1 model (#3363)


Co-authored-by: Shan Yu <shanyu1@g.ucla.edu>
parent 31eec35b
"""
Usage:
export OPENAI_API_KEY=sk-******
python3 openai_example_chat.py
"""
import sglang as sgl
@sgl.function
def multi_turn_question(s, question_1, question_2):
s += sgl.system("You are a helpful assistant.")
s += sgl.user(question_1)
s += sgl.assistant(sgl.gen("answer_1", max_tokens=100))
s += sgl.user(question_2)
s += sgl.assistant(sgl.gen("answer_2"))
def single():
state = multi_turn_question.run(
question_1="What is the capital of the United States?",
question_2="List two local attractions.",
)
for m in state.messages():
print(m["role"], ":", m["content"])
print("\n-- answer_1 --\n", state["answer_1"])
def batch():
states = multi_turn_question.run_batch(
[
{
"question_1": "What is the capital of the United States?",
"question_2": "List two local attractions.",
},
{
"question_1": "What is the capital of France?",
"question_2": "What is the population of this city?",
},
]
)
for s in states:
print(s.messages())
if __name__ == "__main__":
sgl.set_default_backend(sgl.OpenAI("o1"))
# Run a single request
print("\n========== single ==========\n")
single()
# Run a batch of requests
print("\n========== batch ==========\n")
batch()
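Context for the backend change below: OpenAI's o1-series models reject the max_tokens parameter on the Chat Completions API and expect max_completion_tokens instead. A minimal sketch of the equivalent direct call with the openai Python SDK is shown here; the model name and token limit are illustrative, and it assumes an SDK version recent enough to expose max_completion_tokens.

# Sketch only, not part of this commit: calling an o1 model directly with the
# openai SDK. o1-series models take max_completion_tokens; older chat models
# take max_tokens instead.
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

resp = client.chat.completions.create(
    model="o1",  # illustrative model name
    messages=[
        {"role": "user", "content": "What is the capital of the United States?"},
    ],
    max_completion_tokens=100,  # would be max_tokens=100 for e.g. gpt-4o
)
print(resp.choices[0].message.content)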
@@ -161,6 +161,10 @@ class OpenAI(BaseBackend):
         prompt = s.text_

         kwargs = sampling_params.to_openai_kwargs()
+        if self.model_name.startswith("o1") or self.model_name.startswith("o3"):
+            kwargs.pop("max_tokens", None)
+        else:
+            kwargs.pop("max_completion_tokens", None)
         comp = openai_completion(
             client=self.client,
             token_usage=self.token_usage,
@@ -175,6 +179,7 @@ class OpenAI(BaseBackend):
             ), "constrained type not supported on chat model"
             kwargs = sampling_params.to_openai_kwargs()
             kwargs.pop("stop")
             comp = openai_completion(
                 client=self.client,
                 token_usage=self.token_usage,
@@ -63,6 +63,7 @@ class SglSamplingParams:
             warnings.warn("Regular expression is not supported in the OpenAI backend.")
         return {
             "max_tokens": self.max_new_tokens,
+            "max_completion_tokens": self.max_new_tokens,
             "stop": self.stop or None,
             "temperature": self.temperature,
             "top_p": self.top_p,
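In effect, to_openai_kwargs now emits both max_tokens and max_completion_tokens with the same value, and the backend drops whichever key the selected model does not accept. A standalone sketch of that selection logic follows; the helper name and example values are illustrative, not part of the diff.

# Sketch of the kwargs pruning introduced above (illustrative only).
def prune_token_limit_kwargs(model_name: str, kwargs: dict) -> dict:
    """Keep max_completion_tokens for o1/o3 models, max_tokens otherwise."""
    kwargs = dict(kwargs)  # avoid mutating the caller's dict
    if model_name.startswith("o1") or model_name.startswith("o3"):
        kwargs.pop("max_tokens", None)
    else:
        kwargs.pop("max_completion_tokens", None)
    return kwargs

sampling_kwargs = {"max_tokens": 128, "max_completion_tokens": 128, "temperature": 1.0}
print(prune_token_limit_kwargs("o1", sampling_kwargs))
# -> {'max_completion_tokens': 128, 'temperature': 1.0}
print(prune_token_limit_kwargs("gpt-4o", sampling_kwargs))
# -> {'max_tokens': 128, 'temperature': 1.0}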