Unverified commit 3e684be7, authored by Ying Sheng, committed by GitHub
Browse files

Fix openai speculative execution (#456)

parent ec380dfd
"""
Usage:
***Note: for speculative execution to work, user must put all "gen" in "assistant". Show in "assistant" the desired answer format. Each "gen" term should have a stop token. The stream mode is not supported in speculative execution.
***Note: for speculative execution to work, every "gen" call must be placed inside an "assistant" message.
The "assistant" message should show the desired answer format, and each "gen" term must have a stop token.
Stream mode is not supported in speculative execution.
E.g.
correct:
sgl.assistant("\nName:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n"))
......@@ -10,13 +13,13 @@ incorrect:
s += sgl.assistant("\nJob:" + sgl.gen("job", stop="\n"))
export OPENAI_API_KEY=sk-******
python3 openaichat_speculative.py
python3 openai_chat_speculative.py
"""
import sglang as sgl
from sglang import function, gen, set_default_backend, OpenAI
from sglang import function, set_default_backend, OpenAI
@function(api_num_spec_tokens=512)
@function(api_num_spec_tokens=256)
def gen_character_spec(s):
s += sgl.system("You are a helpful assistant.")
s += sgl.user("Construct a character within the following format:")
......@@ -25,7 +28,7 @@ def gen_character_spec(s):
s += sgl.assistant("Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n"))
@function(api_num_spec_tokens=512)
@function(api_num_spec_tokens=256)
def gen_character_spec_no_few_shot(s):
s += sgl.user("Construct a character. For each field stop with a newline\n")
s += sgl.assistant("Name:" + sgl.gen("name", stop="\n") + "\nAge:" + sgl.gen("age", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n"))
......@@ -44,18 +47,21 @@ def multi_turn_question(s, question_1, question_2):
s += sgl.user("Answer questions in the following format:")
s += sgl.user("Question 1: What is the capital of France?\nQuestion 2: What is the population of this city?\n")
s += sgl.assistant("Answer 1: The capital of France is Paris.\nAnswer 2: The population of Paris in 2024 is estimated to be around 2.1 million for the city proper.\n")
s += sgl.user("Question 1: "+question_1+"\nQuestion 2: "+question_2)
s += sgl.assistant("Answer 1: "+sgl.gen("answer_1", stop="\n") + "\nAnswer 2: "+ sgl.gen("answer_2", stop="\n"))
s += sgl.user("Question 1: " + question_1+"\nQuestion 2: " + question_2)
s += sgl.assistant("Answer 1: " + sgl.gen("answer_1", stop="\n") + "\nAnswer 2: " + sgl.gen("answer_2", stop="\n"))
def test_spec_single_turn():
    """Run ``gen_character_spec`` once and print the transcript, fields and token usage.

    Uses the module-level ``backend`` object, which the ``__main__`` section
    of this script creates before calling the test functions.
    """
    # Start from zero so the printed usage covers only this run.
    backend.token_usage.reset()
    state = gen_character_spec.run()
    for m in state.messages():
        print(m["role"], ":", m["content"])

    # Each generated field is printed exactly once.
    print("\n-- name:", state["name"])
    print("-- birthday:", state["birthday"])
    print("-- job:", state["job"])
    print(backend.token_usage)
def test_inaccurate_spec_single_turn():
......@@ -99,7 +105,8 @@ def test_spec_multi_turn_stream():
if __name__ == "__main__":
set_default_backend(OpenAI("gpt-4-turbo"))
backend = OpenAI("gpt-4-turbo")
set_default_backend(backend)
print("\n========== test spec single turn ==========\n")
# expect reasonable answer for each field
......@@ -120,4 +127,3 @@ if __name__ == "__main__":
print("\n========== test spec multi turn stream ==========\n")
# expect error in stream_executor: stream is not supported...
test_spec_multi_turn_stream()
\ No newline at end of file
......@@ -5,7 +5,7 @@ python3 openai_speculative.py
from sglang import function, gen, set_default_backend, OpenAI
@function(api_num_spec_tokens=512)
@function(api_num_spec_tokens=64)
def gen_character_spec(s):
s += "Construct a character within the following format:\n"
s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
......@@ -14,6 +14,15 @@ def gen_character_spec(s):
s += "\nJob:" + gen("job", stop="\n") + "\n"
@function
def gen_character_no_spec(s):
    """Few-shot character prompt declared without ``api_num_spec_tokens``.

    Fills the "name", "birthday" and "job" variables, each stopping at a newline.
    """
    s += "Construct a character within the following format:\n"
    s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
    s += "\nPlease generate new Name, Birthday and Job.\n"
    name_and_birthday = (
        "Name:"
        + gen("name", stop="\n")
        + "\nBirthday:"
        + gen("birthday", stop="\n")
    )
    s += name_and_birthday
    job_line = "\nJob:" + gen("job", stop="\n") + "\n"
    s += job_line
@function(api_num_spec_tokens=64)
def gen_character_spec_no_few_shot(s):
# s += "Construct a character with name, birthday, and job:\n"
......@@ -22,17 +31,19 @@ def gen_character_spec_no_few_shot(s):
s += "\nJob:" + gen("job", stop="\n") + "\n"
set_default_backend(OpenAI("gpt-3.5-turbo-instruct"))
state = gen_character_spec.run()
if __name__ == "__main__":
backend = OpenAI("gpt-3.5-turbo-instruct")
set_default_backend(backend)
print("...name:", state["name"])
print("...birthday:", state["birthday"])
print("...job:", state["job"])
for function in [gen_character_spec, gen_character_no_spec, gen_character_spec_no_few_shot]:
backend.token_usage.reset()
state = gen_character_spec_no_few_shot.run()
print(f"function: {function.func.__name__}")
print("\n...name:", state["name"])
print("...birthday:", state["birthday"])
print("...job:", state["job"])
state = function.run()
print("...name:", state["name"])
print("...birthday:", state["birthday"])
print("...job:", state["job"])
print(backend.token_usage)
print()
\ No newline at end of file
......@@ -9,7 +9,6 @@ class BaseBackend:
def __init__(self) -> None:
self.support_concate_and_append = False
self.chat_template = get_chat_template("default")
self.api_num_spec_tokens = None
def get_model_name(self):
    """Not implemented on this class; always raises ``NotImplementedError``.

    NOTE(review): presumably concrete backends override this to return
    their model identifier — confirm against the subclasses.
    """
    raise NotImplementedError()
......
import logging
import time
import warnings
import dataclasses
from typing import Callable, List, Optional, Union
import numpy as np
......@@ -42,6 +43,15 @@ INSTRUCT_MODEL_NAMES = [
]
@dataclasses.dataclass
class TokenUsage:
    """Running totals of prompt and completion tokens.

    Both counters default to zero, so ``TokenUsage()`` yields a fresh
    accumulator; the positional form ``TokenUsage(0, 0)`` keeps working.
    """

    # Accumulated prompt-token count.
    prompt_tokens: int = 0
    # Accumulated completion-token count.
    completion_tokens: int = 0

    def reset(self):
        """Set both counters back to zero."""
        self.prompt_tokens = 0
        self.completion_tokens = 0
class OpenAI(BaseBackend):
def __init__(
self,
......@@ -83,41 +93,27 @@ class OpenAI(BaseBackend):
self.chat_prefix = self.chat_template.role_prefix_and_suffix["assistant"][0]
# Usage
self.token_usage = TokenUsage(0, 0)
# API speculative execution
# TODO(ying): This does not support multi-threading (run_batch)
self.spec_kwargs = {}
self.spec_format = []
self.spec_max_num_tries = 3
self.api_num_spec_tokens = None
def set_api_num_spec_tokens(self, num):
self.api_num_spec_tokens = num
def get_chat_template(self):
    """Return the chat template stored on this backend (``self.chat_template``)."""
    return self.chat_template
def generate(
self,
s: StreamExecutor,
sampling_params: SglSamplingParams,
name=None,
):
if sampling_params.dtype is None:
if self.is_chat_model:
if self.api_num_spec_tokens is None:
if not s.text_.endswith(self.chat_prefix):
raise RuntimeError(
"This use case is not supported if api speculative execution is off. "
"For OpenAI chat models, sgl.gen must be right after sgl.assistant."
"Example of adding api speculative execution: @function(api_num_spec_tokens=128)."
)
prompt = s.messages_
else:
# collect assistant answer format
def _prepare_spec_execution(self, sampling_params: SglSamplingParams,
api_num_spec_tokens: int, spec_var_name: str):
if "max_tokens" not in self.spec_kwargs:
self.spec_kwargs["max_tokens"] = self.api_num_spec_tokens
self.spec_kwargs["max_tokens"] = api_num_spec_tokens
else:
assert (
self.spec_kwargs["max_tokens"] == self.api_num_spec_tokens
self.spec_kwargs["max_tokens"] == api_num_spec_tokens
)
params = sampling_params.to_openai_kwargs()
for key, value in params.items():
if key in ["stop"]:
......@@ -134,15 +130,36 @@ class OpenAI(BaseBackend):
value == self.spec_kwargs[key]
), "sampling parameters should be consistent if turn on api speculative execution."
self.spec_format.append(
{"text": "", "stop": params["stop"], "name": name}
{"text": "", "stop": params["stop"], "name": spec_var_name}
)
return "", {}
def generate(
self,
s: StreamExecutor,
sampling_params: SglSamplingParams,
spec_var_name: str = None,
):
if sampling_params.dtype is None:
if self.is_chat_model:
if s.api_num_spec_tokens is None:
if not s.text_.endswith(self.chat_prefix):
raise RuntimeError(
"This use case is not supported if api speculative execution is off. "
"For OpenAI chat models, sgl.gen must be right after sgl.assistant. "
"Example of adding api speculative execution: @function(api_num_spec_tokens=128)."
)
prompt = s.messages_
else:
return self._prepare_spec_execution(sampling_params,
s.api_num_spec_tokens, spec_var_name)
else:
prompt = s.text_
kwargs = sampling_params.to_openai_kwargs()
comp = openai_completion(
client=self.client,
token_usage=self.token_usage,
is_chat=self.is_chat_model,
model=self.model_name,
prompt=prompt,
......@@ -156,6 +173,7 @@ class OpenAI(BaseBackend):
kwargs.pop("stop")
comp = openai_completion(
client=self.client,
token_usage=self.token_usage,
is_chat=self.is_chat_model,
model=self.model_name,
prompt=s.text_ + '"',
......@@ -171,6 +189,7 @@ class OpenAI(BaseBackend):
kwargs.pop("stop")
comp = openai_completion(
client=self.client,
token_usage=self.token_usage,
is_chat=self.is_chat_model,
model=self.model_name,
prompt=s.text_,
......@@ -211,14 +230,16 @@ class OpenAI(BaseBackend):
self,
s: StreamExecutor,
):
if self.api_num_spec_tokens is None or not s.text_.endswith(self.chat_prefix):
if s.api_num_spec_tokens is None or not s.text_.endswith(self.chat_prefix):
return
comp = ""
if not all(x["name"] is None for x in self.spec_format):
# TODO(ying): throw errors or warnings
for i in range(self.spec_max_num_tries):
comp = openai_completion(
client=self.client,
token_usage=self.token_usage,
is_chat=self.is_chat_model,
model=self.model_name,
prompt=s.messages_,
......@@ -228,7 +249,6 @@ class OpenAI(BaseBackend):
break
for term in self.spec_format:
stop = term["stop"] if term["stop"] is not None else ""
s.text_ += term["text"]
name = term["name"]
if name is not None:
......@@ -258,6 +278,7 @@ class OpenAI(BaseBackend):
kwargs = sampling_params.to_openai_kwargs()
generator = openai_completion_stream(
client=self.client,
token_usage=self.token_usage,
is_chat=self.is_chat_model,
model=self.model_name,
prompt=prompt,
......@@ -303,6 +324,8 @@ class OpenAI(BaseBackend):
)
ret_str = ret.choices[0].text
ret_token = self.tokenizer.encode(ret_str)[0]
# Accumulate usage for this API call. Both counters must use ``+=`` so the
# totals recorded by earlier calls are added to, not overwritten.
self.token_usage.prompt_tokens += ret.usage.prompt_tokens
self.token_usage.completion_tokens += ret.usage.completion_tokens
# TODO:
# 1. return logits as the scores
......@@ -332,7 +355,7 @@ class OpenAI(BaseBackend):
return decision, scores, None, None
def openai_completion(client, retries=3, is_chat=None, prompt=None, **kwargs):
def openai_completion(client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs):
for attempt in range(retries):
try:
if is_chat:
......@@ -346,6 +369,9 @@ def openai_completion(client, retries=3, is_chat=None, prompt=None, **kwargs):
comp = [c.text for c in ret.choices]
else:
comp = ret.choices[0].text
token_usage.prompt_tokens += ret.usage.prompt_tokens
token_usage.completion_tokens += ret.usage.completion_tokens
break
except (openai.APIError, openai.APIConnectionError, openai.RateLimitError) as e:
logger.error(f"OpenAI Error: {e}. Waiting 5 seconds...")
......@@ -359,16 +385,19 @@ def openai_completion(client, retries=3, is_chat=None, prompt=None, **kwargs):
return comp
def openai_completion_stream(client, retries=3, is_chat=None, prompt=None, **kwargs):
def openai_completion_stream(client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs):
for attempt in range(retries):
try:
if is_chat:
if "stop" in kwargs and kwargs["stop"] is None:
kwargs.pop("stop")
generator = client.chat.completions.create(
messages=prompt, stream=True, **kwargs
messages=prompt, stream=True, stream_options={"include_usage": True},
**kwargs
)
for ret in generator:
if len(ret.choices) == 0:
continue
try:
content = ret.choices[0].delta.content
except IndexError:
......@@ -376,11 +405,17 @@ def openai_completion_stream(client, retries=3, is_chat=None, prompt=None, **kwa
yield content or "", {}
else:
generator = client.completions.create(
prompt=prompt, stream=True, **kwargs
prompt=prompt, stream=True, stream_options={"include_usage": True},
**kwargs
)
for ret in generator:
if len(ret.choices) == 0:
continue
content = ret.choices[0].text
yield content or "", {}
token_usage.prompt_tokens += ret.usage.prompt_tokens
token_usage.completion_tokens += ret.usage.completion_tokens
break
except (openai.APIError, openai.APIConnectionError, openai.RateLimitError) as e:
logger.error(f"OpenAI Error: {e}. Waiting 5 seconds...")
......
......@@ -196,12 +196,6 @@ class StreamExecutor:
# For completion
self.text_ = "" # The full text
# For speculative execution
from sglang.backend.openai import OpenAI
if isinstance(backend, OpenAI):
self.backend.set_api_num_spec_tokens(api_num_spec_tokens)
self.speculated_text = ""
# For chat
self.messages_ = [] # The messages in the OpenAI API format
self.chat_template = chat_template or self.backend.get_chat_template()
......@@ -215,6 +209,10 @@ class StreamExecutor:
# For fork/join
self.fork_start_text_pos = None
# For speculative execution
self.api_num_spec_tokens = api_num_spec_tokens
self.speculated_text = ""
# Worker thread
self.use_thread = use_thread
if self.use_thread:
......@@ -293,6 +291,8 @@ class StreamExecutor:
exes[i].fork_start_text_pos = len(self.text_)
exes[i].images_ = list(self.images_)
# TODO(ying): handle API speculative execution
return exes
def text(self):
......@@ -399,7 +399,7 @@ class StreamExecutor:
if (
self.cur_role == "assistant"
and self.backend.api_num_spec_tokens is not None
and self.api_num_spec_tokens is not None
and self.backend.is_chat_model
and not prefix
):
......@@ -435,34 +435,16 @@ class StreamExecutor:
# if global_config.eager_fill_image:
# self.backend.fill_image(self)
def _execute_gen(self, expr: SglGen):
sampling_params = self._resolve_sampling_params(expr.sampling_params)
name = expr.name
if not self.stream:
if self.backend.api_num_spec_tokens is None:
comp, meta_info = self.backend.generate(
self,
sampling_params=sampling_params,
)
elif self.backend.is_chat_model:
# spec on model with only chat interface
comp, meta_info = self.backend.generate(
self,
sampling_params=sampling_params,
name=name,
)
return
else: # spec on model with completion
def _spec_gen(self, sampling_params):
stop = sampling_params.stop
max_new_tokens = sampling_params.max_new_tokens
meta_info = {}
def regen():
nonlocal meta_info
sampling_params.max_new_tokens = max(
sampling_params.max_new_tokens, self.backend.api_num_spec_tokens
sampling_params.max_new_tokens, self.api_num_spec_tokens
)
sampling_params.stop = None
self.speculated_text, meta_info = self.backend.generate(
......@@ -501,6 +483,33 @@ class StreamExecutor:
else:
raise ValueError("Wrong type of stop in sampling parameters.")
return comp, meta_info
def _execute_gen(self, expr: SglGen):
sampling_params = self._resolve_sampling_params(expr.sampling_params)
name = expr.name
if not self.stream:
if self.api_num_spec_tokens is None:
comp, meta_info = self.backend.generate(
self,
sampling_params=sampling_params,
)
else:
if self.backend.is_chat_model:
# Speculative execution on models with only chat interface.
# Store the calls into a temporary list.
# They will be lazily executed later.
comp, meta_info = self.backend.generate(
self,
sampling_params=sampling_params,
spec_var_name=name,
)
return
else: # Speculative execution on models with completion interface
comp, meta_info = self._spec_gen(sampling_params)
self.text_ += comp
self.variables[name] = comp
......@@ -508,7 +517,7 @@ class StreamExecutor:
self.variable_event[name].set()
else:
assert (
self.backend.api_num_spec_tokens is None
self.api_num_spec_tokens is None
), "stream is not supported with api speculative execution"
generator = self.backend.generate_stream(
self, sampling_params=sampling_params
......@@ -571,9 +580,10 @@ class StreamExecutor:
def _execute_role_end(self, expr: SglRoleEnd):
if (
self.cur_role == "assistant"
and self.backend.api_num_spec_tokens is not None
and self.backend.is_chat_model
and self.api_num_spec_tokens is not None
):
# Execute the stored lazy generation calls
self.backend.role_end_generate(self)
self.cur_role = None
......
......@@ -304,6 +304,7 @@ def test_image_qa():
temperature=0,
max_new_tokens=64,
)
assert (
"taxi" in state.messages()[-1]["content"]
or "car" in state.messages()[-1]["content"]
......@@ -349,3 +350,46 @@ def test_regex():
state = regex_gen.run()
answer = state["answer"]
assert re.match(regex, answer)
def test_completion_speculative():
    """Assert that the speculative program uses fewer prompt tokens than the plain one.

    Both programs share the same few-shot prompt; only the first is declared
    with ``api_num_spec_tokens``. Prompt-token counts are read from the
    default backend's ``token_usage`` after each run.
    """

    @sgl.function(api_num_spec_tokens=64)
    def gen_character_spec(s):
        s += "Construct a character within the following format:\n"
        s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
        s += "\nPlease generate new Name, Birthday and Job.\n"
        s += (
            "Name:"
            + sgl.gen("name", stop="\n")
            + "\nBirthday:"
            + sgl.gen("birthday", stop="\n")
        )
        s += "\nJob:" + sgl.gen("job", stop="\n") + "\n"

    @sgl.function
    def gen_character_no_spec(s):
        s += "Construct a character within the following format:\n"
        s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
        s += "\nPlease generate new Name, Birthday and Job.\n"
        s += (
            "Name:"
            + sgl.gen("name", stop="\n")
            + "\nBirthday:"
            + sgl.gen("birthday", stop="\n")
        )
        s += "\nJob:" + sgl.gen("job", stop="\n") + "\n"

    counter = sgl.global_config.default_backend.token_usage

    def prompt_tokens_for(program):
        # Zero the shared counter, run the program to completion, read the total.
        counter.reset()
        program().sync()
        return counter.prompt_tokens

    usage_with_spec = prompt_tokens_for(gen_character_spec)
    usage_with_no_spec = prompt_tokens_for(gen_character_no_spec)

    assert usage_with_spec < usage_with_no_spec, f"{usage_with_spec} vs {usage_with_no_spec}"
def test_chat_completion_speculative():
    """Run a chat program with three ``gen`` calls inside one assistant turn.

    The program is declared with ``api_num_spec_tokens=256``; the test only
    checks that running it to completion does not raise.
    """

    @sgl.function(api_num_spec_tokens=256)
    def gen_character_spec(s):
        s += sgl.system("You are a helpful assistant.")
        s += sgl.user("Construct a character within the following format:")
        s += sgl.assistant("Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n")
        s += sgl.user("Please generate new Name, Birthday and Job.\n")
        answer_format = (
            "Name:"
            + sgl.gen("name", stop="\n")
            + "\nBirthday:"
            + sgl.gen("birthday", stop="\n")
            + "\nJob:"
            + sgl.gen("job", stop="\n")
        )
        s += sgl.assistant(answer_format)

    program_state = gen_character_spec()
    program_state.sync()
\ No newline at end of file
......@@ -14,6 +14,8 @@ from sglang.test.test_programs import (
test_select,
test_stream,
test_tool_use,
test_completion_speculative,
test_chat_completion_speculative
)
......@@ -78,6 +80,14 @@ class TestOpenAIBackend(unittest.TestCase):
set_default_backend(self.backend)
test_stream()
def test_completion_speculative(self):
    """Make ``self.backend`` the default backend, then run the imported
    ``test_completion_speculative`` program test."""
    set_default_backend(self.backend)
    test_completion_speculative()
def test_chat_completion_speculative(self):
    """Make ``self.chat_backend`` the default backend, then run the imported
    ``test_chat_completion_speculative`` program test."""
    set_default_backend(self.chat_backend)
    test_chat_completion_speculative()
if __name__ == "__main__":
unittest.main(warnings="ignore")
......@@ -87,4 +97,4 @@ if __name__ == "__main__":
# global_config.verbosity = 2
# t = TestOpenAIBackend()
# t.setUp()
# t.test_few_shot_qa()
# t.test_chat_completion_speculative()
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment