Unverified commit 3e684be7, authored by Ying Sheng, committed by GitHub

Fix openai speculative execution (#456)

parent ec380dfd
""" """
Usage: Usage:
***Note: for speculative execution to work, user must put all "gen" in "assistant". Show in "assistant" the desired answer format. Each "gen" term should have a stop token. The stream mode is not supported in speculative execution. ***Note: for speculative execution to work, user must put all "gen" in "assistant".
Show in "assistant" the desired answer format. Each "gen" term should have a stop token.
The stream mode is not supported in speculative execution.
E.g. E.g.
correct: correct:
sgl.assistant("\nName:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n")) sgl.assistant("\nName:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n"))
incorrect: incorrect:
s += sgl.assistant("\nName:" + sgl.gen("name", stop="\n")) s += sgl.assistant("\nName:" + sgl.gen("name", stop="\n"))
s += sgl.assistant("\nBirthday:" + sgl.gen("birthday", stop="\n")) s += sgl.assistant("\nBirthday:" + sgl.gen("birthday", stop="\n"))
s += sgl.assistant("\nJob:" + sgl.gen("job", stop="\n")) s += sgl.assistant("\nJob:" + sgl.gen("job", stop="\n"))
export OPENAI_API_KEY=sk-****** export OPENAI_API_KEY=sk-******
python3 openaichat_speculative.py python3 openai_chat_speculative.py
""" """
import sglang as sgl import sglang as sgl
from sglang import function, gen, set_default_backend, OpenAI from sglang import function, set_default_backend, OpenAI
@function(api_num_spec_tokens=512) @function(api_num_spec_tokens=256)
def gen_character_spec(s): def gen_character_spec(s):
s += sgl.system("You are a helpful assistant.") s += sgl.system("You are a helpful assistant.")
s += sgl.user("Construct a character within the following format:") s += sgl.user("Construct a character within the following format:")
@@ -25,7 +28,7 @@ def gen_character_spec(s):
     s += sgl.assistant("Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n"))


-@function(api_num_spec_tokens=512)
+@function(api_num_spec_tokens=256)
 def gen_character_spec_no_few_shot(s):
     s += sgl.user("Construct a character. For each field stop with a newline\n")
     s += sgl.assistant("Name:" + sgl.gen("name", stop="\n") + "\nAge:" + sgl.gen("age", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n"))
@@ -44,18 +47,21 @@ def multi_turn_question(s, question_1, question_2):
     s += sgl.user("Answer questions in the following format:")
     s += sgl.user("Question 1: What is the capital of France?\nQuestion 2: What is the population of this city?\n")
     s += sgl.assistant("Answer 1: The capital of France is Paris.\nAnswer 2: The population of Paris in 2024 is estimated to be around 2.1 million for the city proper.\n")
-    s += sgl.user("Question 1: "+question_1+"\nQuestion 2: "+question_2)
-    s += sgl.assistant("Answer 1: "+sgl.gen("answer_1", stop="\n") + "\nAnswer 2: "+ sgl.gen("answer_2", stop="\n"))
+    s += sgl.user("Question 1: " + question_1 + "\nQuestion 2: " + question_2)
+    s += sgl.assistant("Answer 1: " + sgl.gen("answer_1", stop="\n") + "\nAnswer 2: " + sgl.gen("answer_2", stop="\n"))


 def test_spec_single_turn():
+    backend.token_usage.reset()
+
     state = gen_character_spec.run()
     for m in state.messages():
         print(m["role"], ":", m["content"])

     print("\n-- name:", state["name"])
-    print("\n-- birthday:", state["birthday"])
-    print("\n-- job:", state["job"])
+    print("-- birthday:", state["birthday"])
+    print("-- job:", state["job"])
+    print(backend.token_usage)


 def test_inaccurate_spec_single_turn():
@@ -99,7 +105,8 @@ def test_spec_multi_turn_stream():

 if __name__ == "__main__":
-    set_default_backend(OpenAI("gpt-4-turbo"))
+    backend = OpenAI("gpt-4-turbo")
+    set_default_backend(backend)

     print("\n========== test spec single turn ==========\n")
     # expect reasonable answer for each field
@@ -119,5 +126,4 @@ if __name__ == "__main__":
     print("\n========== test spec multi turn stream ==========\n")
     # expect error in stream_executor: stream is not supported...
     test_spec_multi_turn_stream()
-
\ No newline at end of file
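
For context, the docstring's constraint exists because the backend sends the whole assistant turn as a single speculative completion and then splits the result at the stop tokens. A minimal runnable sketch of a correct program (the function name and prompts here are illustrative, not part of this commit):

import sglang as sgl
from sglang import OpenAI, function, set_default_backend


@function(api_num_spec_tokens=256)
def character_sketch(s):
    s += sgl.system("You are a helpful assistant.")
    s += sgl.user("Construct a character. For each field stop with a newline.\n")
    # All gen calls sit in one assistant turn, each with a stop token,
    # so the backend can speculate the whole answer in one API call.
    s += sgl.assistant(
        "Name:" + sgl.gen("name", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n")
    )


if __name__ == "__main__":
    set_default_backend(OpenAI("gpt-4-turbo"))
    state = character_sketch.run()
    print(state["name"], "/", state["job"])
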
@@ -5,7 +5,7 @@ python3 openai_speculative.py
 from sglang import function, gen, set_default_backend, OpenAI


-@function(api_num_spec_tokens=512)
+@function(api_num_spec_tokens=64)
 def gen_character_spec(s):
     s += "Construct a character within the following format:\n"
     s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
@@ -14,6 +14,15 @@ def gen_character_spec(s):
     s += "\nJob:" + gen("job", stop="\n") + "\n"


+@function
+def gen_character_no_spec(s):
+    s += "Construct a character within the following format:\n"
+    s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
+    s += "\nPlease generate new Name, Birthday and Job.\n"
+    s += "Name:" + gen("name", stop="\n") + "\nBirthday:" + gen("birthday", stop="\n")
+    s += "\nJob:" + gen("job", stop="\n") + "\n"
+
+
 @function(api_num_spec_tokens=64)
 def gen_character_spec_no_few_shot(s):
     # s += "Construct a character with name, birthday, and job:\n"
@@ -22,17 +31,19 @@ def gen_character_spec_no_few_shot(s):
     s += "\nJob:" + gen("job", stop="\n") + "\n"


-set_default_backend(OpenAI("gpt-3.5-turbo-instruct"))
-
-state = gen_character_spec.run()
-print("...name:", state["name"])
-print("...birthday:", state["birthday"])
-print("...job:", state["job"])
-
-state = gen_character_spec_no_few_shot.run()
-print("\n...name:", state["name"])
-print("...birthday:", state["birthday"])
-print("...job:", state["job"])
+if __name__ == "__main__":
+    backend = OpenAI("gpt-3.5-turbo-instruct")
+    set_default_backend(backend)
+
+    for function in [gen_character_spec, gen_character_no_spec, gen_character_spec_no_few_shot]:
+        backend.token_usage.reset()
+
+        print(f"function: {function.func.__name__}")
+        state = function.run()
+
+        print("...name:", state["name"])
+        print("...birthday:", state["birthday"])
+        print("...job:", state["job"])
+        print(backend.token_usage)
+        print()
\ No newline at end of file
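
The token_usage counter introduced in this commit is what makes the savings measurable. A sketch of the comparison the loop above performs, written out explicitly (assumes OPENAI_API_KEY is set; backend and the two functions are the ones defined in this example):

backend = OpenAI("gpt-3.5-turbo-instruct")
set_default_backend(backend)

backend.token_usage.reset()
gen_character_spec.run()
prompt_tokens_spec = backend.token_usage.prompt_tokens

backend.token_usage.reset()
gen_character_no_spec.run()
prompt_tokens_no_spec = backend.token_usage.prompt_tokens

# Without speculation, each of the three gen calls issues its own request,
# so the shared prefix is billed as prompt tokens three times; with
# speculation the whole assistant answer comes from one request.
print(prompt_tokens_spec, "<", prompt_tokens_no_spec)
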
@@ -9,7 +9,6 @@ class BaseBackend:
     def __init__(self) -> None:
         self.support_concate_and_append = False
         self.chat_template = get_chat_template("default")
-        self.api_num_spec_tokens = None

     def get_model_name(self):
         raise NotImplementedError()
...
 import logging
 import time
 import warnings
+import dataclasses
 from typing import Callable, List, Optional, Union

 import numpy as np
@@ -42,6 +43,15 @@ INSTRUCT_MODEL_NAMES = [
 ]


+@dataclasses.dataclass
+class TokenUsage:
+    prompt_tokens: int
+    completion_tokens: int
+
+    def reset(self):
+        self.prompt_tokens = self.completion_tokens = 0
+
+
 class OpenAI(BaseBackend):
     def __init__(
         self,
@@ -83,66 +93,73 @@ class OpenAI(BaseBackend):
         self.chat_prefix = self.chat_template.role_prefix_and_suffix["assistant"][0]

+        # Usage
+        self.token_usage = TokenUsage(0, 0)
+
+        # API speculative execution
+        # TODO(ying): This does not support multi-threading (run_batch)
         self.spec_kwargs = {}
         self.spec_format = []
         self.spec_max_num_tries = 3
-        self.api_num_spec_tokens = None
-
-    def set_api_num_spec_tokens(self, num):
-        self.api_num_spec_tokens = num

     def get_chat_template(self):
         return self.chat_template
+    def _prepare_spec_execution(self, sampling_params: SglSamplingParams,
+                                api_num_spec_tokens: int, spec_var_name: str):
+        if "max_tokens" not in self.spec_kwargs:
+            self.spec_kwargs["max_tokens"] = api_num_spec_tokens
+        else:
+            assert self.spec_kwargs["max_tokens"] == api_num_spec_tokens
+
+        params = sampling_params.to_openai_kwargs()
+        for key, value in params.items():
+            if key in ["stop"]:
+                continue
+            if key in ["max_tokens"]:
+                warnings.warn(
+                    "The parameter max_tokens will be overwritten by the speculated number of tokens."
+                )
+                continue
+            if key not in self.spec_kwargs:
+                self.spec_kwargs[key] = value
+            else:
+                assert (
+                    value == self.spec_kwargs[key]
+                ), "Sampling parameters should be consistent if API speculative execution is turned on."
+        self.spec_format.append(
+            {"text": "", "stop": params["stop"], "name": spec_var_name}
+        )
+        return "", {}
+
     def generate(
         self,
         s: StreamExecutor,
         sampling_params: SglSamplingParams,
-        name=None,
+        spec_var_name: str = None,
     ):
         if sampling_params.dtype is None:
             if self.is_chat_model:
-                if self.api_num_spec_tokens is None:
+                if s.api_num_spec_tokens is None:
                     if not s.text_.endswith(self.chat_prefix):
                         raise RuntimeError(
                             "This use case is not supported if api speculative execution is off. "
                             "For OpenAI chat models, sgl.gen must be right after sgl.assistant. "
                             "Example of adding api speculative execution: @function(api_num_spec_tokens=128)."
                         )
                     prompt = s.messages_
                 else:
-                    # collect assistant answer format
-                    if "max_tokens" not in self.spec_kwargs:
-                        self.spec_kwargs["max_tokens"] = self.api_num_spec_tokens
-                    else:
-                        assert (
-                            self.spec_kwargs["max_tokens"] == self.api_num_spec_tokens
-                        )
-                    params = sampling_params.to_openai_kwargs()
-                    for key, value in params.items():
-                        if key in ["stop"]:
-                            continue
-                        if key in ["max_tokens"]:
-                            warnings.warn(
-                                "The parameter max_tokens will be overwritten by speculated number of tokens."
-                            )
-                            continue
-                        if key not in self.spec_kwargs:
-                            self.spec_kwargs[key] = value
-                        else:
-                            assert (
-                                value == self.spec_kwargs[key]
-                            ), "sampling parameters should be consistent if turn on api speculative execution."
-                    self.spec_format.append(
-                        {"text": "", "stop": params["stop"], "name": name}
-                    )
-                    return "", {}
+                    return self._prepare_spec_execution(
+                        sampling_params, s.api_num_spec_tokens, spec_var_name
+                    )
             else:
                 prompt = s.text_
         kwargs = sampling_params.to_openai_kwargs()
         comp = openai_completion(
             client=self.client,
+            token_usage=self.token_usage,
             is_chat=self.is_chat_model,
             model=self.model_name,
             prompt=prompt,
@@ -156,6 +173,7 @@ class OpenAI(BaseBackend):
             kwargs.pop("stop")
             comp = openai_completion(
                 client=self.client,
+                token_usage=self.token_usage,
                 is_chat=self.is_chat_model,
                 model=self.model_name,
                 prompt=s.text_ + '"',
@@ -171,6 +189,7 @@ class OpenAI(BaseBackend):
             kwargs.pop("stop")
             comp = openai_completion(
                 client=self.client,
+                token_usage=self.token_usage,
                 is_chat=self.is_chat_model,
                 model=self.model_name,
                 prompt=s.text_,
@@ -211,14 +230,16 @@ class OpenAI(BaseBackend):
         self,
         s: StreamExecutor,
     ):
-        if self.api_num_spec_tokens is None or not s.text_.endswith(self.chat_prefix):
+        if s.api_num_spec_tokens is None or not s.text_.endswith(self.chat_prefix):
             return

         comp = ""
         if not all(x["name"] is None for x in self.spec_format):
+            # TODO(ying): throw errors or warnings
             for i in range(self.spec_max_num_tries):
                 comp = openai_completion(
                     client=self.client,
+                    token_usage=self.token_usage,
                     is_chat=self.is_chat_model,
                     model=self.model_name,
                     prompt=s.messages_,
@@ -228,7 +249,6 @@ class OpenAI(BaseBackend):
                     break

         for term in self.spec_format:
-            stop = term["stop"] if term["stop"] is not None else ""
             s.text_ += term["text"]
             name = term["name"]
             if name is not None:
@@ -258,6 +278,7 @@ class OpenAI(BaseBackend):
         kwargs = sampling_params.to_openai_kwargs()
         generator = openai_completion_stream(
             client=self.client,
+            token_usage=self.token_usage,
             is_chat=self.is_chat_model,
             model=self.model_name,
             prompt=prompt,
@@ -303,6 +324,8 @@ class OpenAI(BaseBackend):
         )
         ret_str = ret.choices[0].text
         ret_token = self.tokenizer.encode(ret_str)[0]
+        self.token_usage.prompt_tokens += ret.usage.prompt_tokens
+        self.token_usage.completion_tokens += ret.usage.completion_tokens

         # TODO:
         # 1. return logits as the scores
@@ -332,7 +355,7 @@ class OpenAI(BaseBackend):
         return decision, scores, None, None


-def openai_completion(client, retries=3, is_chat=None, prompt=None, **kwargs):
+def openai_completion(client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs):
     for attempt in range(retries):
         try:
             if is_chat:
@@ -346,6 +369,9 @@ def openai_completion(client, retries=3, is_chat=None, prompt=None, **kwargs):
                 comp = [c.text for c in ret.choices]
             else:
                 comp = ret.choices[0].text
+
+            token_usage.prompt_tokens += ret.usage.prompt_tokens
+            token_usage.completion_tokens += ret.usage.completion_tokens
             break
         except (openai.APIError, openai.APIConnectionError, openai.RateLimitError) as e:
             logger.error(f"OpenAI Error: {e}. Waiting 5 seconds...")
@@ -359,16 +385,19 @@ def openai_completion(client, retries=3, is_chat=None, prompt=None, **kwargs):
     return comp


-def openai_completion_stream(client, retries=3, is_chat=None, prompt=None, **kwargs):
+def openai_completion_stream(client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs):
     for attempt in range(retries):
         try:
             if is_chat:
                 if "stop" in kwargs and kwargs["stop"] is None:
                     kwargs.pop("stop")
                 generator = client.chat.completions.create(
-                    messages=prompt, stream=True, **kwargs
+                    messages=prompt, stream=True, stream_options={"include_usage": True},
+                    **kwargs
                 )
                 for ret in generator:
+                    if len(ret.choices) == 0:
+                        continue
                     try:
                         content = ret.choices[0].delta.content
                     except IndexError:
@@ -376,11 +405,17 @@ def openai_completion_stream(client, retries=3, is_chat=None, prompt=None, **kwargs):
                     yield content or "", {}
             else:
                 generator = client.completions.create(
-                    prompt=prompt, stream=True, **kwargs
+                    prompt=prompt, stream=True, stream_options={"include_usage": True},
+                    **kwargs
                 )
                 for ret in generator:
+                    if len(ret.choices) == 0:
+                        continue
                     content = ret.choices[0].text
                     yield content or "", {}
+
+            token_usage.prompt_tokens += ret.usage.prompt_tokens
+            token_usage.completion_tokens += ret.usage.completion_tokens
             break
         except (openai.APIError, openai.APIConnectionError, openai.RateLimitError) as e:
             logger.error(f"OpenAI Error: {e}. Waiting 5 seconds...")
...
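
For intuition about the new flow in this file: when speculation is on, generate() returns immediately with an empty string while _prepare_spec_execution records one {"text", "stop", "name"} entry per gen call; role_end_generate later issues a single completion and matches it against the recorded format, retrying up to spec_max_num_tries times. A simplified, illustrative version of that matching step (a hypothetical helper, not the library code):

def split_speculated(comp, spec_format):
    # spec_format entries mirror what _prepare_spec_execution appends,
    # e.g. {"text": "", "stop": "\n", "name": "name"}; literal text
    # between gens is stored as entries whose name is None.
    values, pos = {}, 0
    for term in spec_format:
        if term["name"] is None:
            pos += len(term["text"])  # skip over the literal text
            continue
        stop = term["stop"]
        if stop:
            end = comp.find(stop, pos)
            if end == -1:
                return None  # format not followed; the caller retries
            values[term["name"]] = comp[pos:end]
            pos = end + len(stop)
        else:
            values[term["name"]] = comp[pos:]
            pos = len(comp)
    return values
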
@@ -196,12 +196,6 @@ class StreamExecutor:
         # For completion
         self.text_ = ""  # The full text

-        # For speculative execution
-        from sglang.backend.openai import OpenAI
-
-        if isinstance(backend, OpenAI):
-            self.backend.set_api_num_spec_tokens(api_num_spec_tokens)
-        self.speculated_text = ""
-
         # For chat
         self.messages_ = []  # The messages in the OpenAI API format
         self.chat_template = chat_template or self.backend.get_chat_template()
@@ -215,6 +209,10 @@ class StreamExecutor:
         # For fork/join
         self.fork_start_text_pos = None

+        # For speculative execution
+        self.api_num_spec_tokens = api_num_spec_tokens
+        self.speculated_text = ""
+
         # Worker thread
         self.use_thread = use_thread
         if self.use_thread:
@@ -293,6 +291,8 @@ class StreamExecutor:
             exes[i].fork_start_text_pos = len(self.text_)
             exes[i].images_ = list(self.images_)

+        # TODO(ying): handle API speculative execution
+
         return exes

     def text(self):
@@ -399,7 +399,7 @@ class StreamExecutor:
         if (
             self.cur_role == "assistant"
-            and self.backend.api_num_spec_tokens is not None
+            and self.api_num_spec_tokens is not None
             and self.backend.is_chat_model
             and not prefix
         ):
@@ -435,71 +435,80 @@ class StreamExecutor:
         # if global_config.eager_fill_image:
         #     self.backend.fill_image(self)

+    def _spec_gen(self, sampling_params):
+        stop = sampling_params.stop
+        max_new_tokens = sampling_params.max_new_tokens
+        meta_info = {}
+
+        def regen():
+            nonlocal meta_info
+
+            sampling_params.max_new_tokens = max(
+                sampling_params.max_new_tokens, self.api_num_spec_tokens
+            )
+            sampling_params.stop = None
+            self.speculated_text, meta_info = self.backend.generate(
+                self, sampling_params=sampling_params
+            )
+
+        def find_stop():
+            if isinstance(stop, str):
+                return self.speculated_text.find(stop)
+            elif isinstance(stop, (tuple, list)):
+                pos = -1
+                for stop_str in stop:
+                    stop_pos = self.speculated_text.find(stop_str)
+                    if stop_pos != -1 and (pos == -1 or stop_pos < pos):
+                        pos = stop_pos
+                return pos
+            else:
+                raise Exception("Wrong type of stop in sampling parameters.")
+
+        if stop is None:
+            if len(self.speculated_text) < max_new_tokens:
+                regen()
+            comp = self.speculated_text[:max_new_tokens]
+            self.speculated_text = self.speculated_text[max_new_tokens:]
+        elif isinstance(stop, (str, list, tuple)):
+            if self.speculated_text == "":
+                regen()
+            stop_pos = find_stop()
+            if stop_pos == -1:
+                stop_pos = min(
+                    sampling_params.max_new_tokens,
+                    len(self.speculated_text),
+                )
+            comp = self.speculated_text[:stop_pos]
+            self.speculated_text = self.speculated_text[stop_pos:]
+        else:
+            raise ValueError("Wrong type of stop in sampling parameters.")
+
+        return comp, meta_info
+
     def _execute_gen(self, expr: SglGen):
         sampling_params = self._resolve_sampling_params(expr.sampling_params)
         name = expr.name

         if not self.stream:
-            if self.backend.api_num_spec_tokens is None:
+            if self.api_num_spec_tokens is None:
                 comp, meta_info = self.backend.generate(
                     self,
                     sampling_params=sampling_params,
                 )
-            elif self.backend.is_chat_model:
-                # spec on model with only chat interface
-                comp, meta_info = self.backend.generate(
-                    self,
-                    sampling_params=sampling_params,
-                    name=name,
-                )
-                return
-
-            else:  # spec on model with completion
-                stop = sampling_params.stop
-                max_new_tokens = sampling_params.max_new_tokens
-                meta_info = {}
-
-                def regen():
-                    sampling_params.max_new_tokens = max(
-                        sampling_params.max_new_tokens, self.backend.api_num_spec_tokens
-                    )
-                    sampling_params.stop = None
-                    self.speculated_text, meta_info = self.backend.generate(
-                        self, sampling_params=sampling_params
-                    )
-
-                def find_stop():
-                    if isinstance(stop, str):
-                        return self.speculated_text.find(stop)
-                    elif isinstance(stop, (tuple, list)):
-                        pos = -1
-                        for stop_str in stop:
-                            stop_pos = self.speculated_text.find(stop_str)
-                            if stop_pos != -1 and (pos == -1 or stop_pos < pos):
-                                pos = stop_pos
-                        return pos
-                    else:
-                        raise Exception("Wrong type of stop in sampling parameters.")
-
-                if stop is None:
-                    if len(self.speculated_text) < max_new_tokens:
-                        regen()
-                    comp = self.speculated_text[:max_new_tokens]
-                    self.speculated_text = self.speculated_text[max_new_tokens:]
-                elif isinstance(stop, (str, list, tuple)):
-                    if self.speculated_text == "":
-                        regen()
-                    stop_pos = find_stop()
-                    if stop_pos == -1:
-                        stop_pos = min(
-                            sampling_params.max_new_tokens,
-                            len(self.speculated_text),
-                        )
-                    comp = self.speculated_text[:stop_pos]
-                    self.speculated_text = self.speculated_text[stop_pos:]
-                else:
-                    raise ValueError("Wrong type of stop in sampling parameters.")
+            else:
+                if self.backend.is_chat_model:
+                    # Speculative execution on models with only chat interface.
+                    # Store the calls into a temporary list.
+                    # They will be lazily executed later.
+                    comp, meta_info = self.backend.generate(
+                        self,
+                        sampling_params=sampling_params,
+                        spec_var_name=name,
+                    )
+                    return
+                else:  # Speculative execution on models with completion interface
+                    comp, meta_info = self._spec_gen(sampling_params)

         self.text_ += comp
@@ -508,7 +517,7 @@ class StreamExecutor:
             self.variable_event[name].set()
         else:
             assert (
-                self.backend.api_num_spec_tokens is None
+                self.api_num_spec_tokens is None
             ), "stream is not supported with api speculative execution"
             generator = self.backend.generate_stream(
                 self, sampling_params=sampling_params
@@ -571,9 +580,10 @@ class StreamExecutor:
     def _execute_role_end(self, expr: SglRoleEnd):
         if (
             self.cur_role == "assistant"
-            and self.backend.api_num_spec_tokens is not None
             and self.backend.is_chat_model
+            and self.api_num_spec_tokens is not None
         ):
+            # Execute the stored lazy generation calls
             self.backend.role_end_generate(self)
         self.cur_role = None
...
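
The completion-interface path extracted into _spec_gen takes a different route: it over-generates once with max(max_new_tokens, api_num_spec_tokens) and no stop string, caches the output in speculated_text, and lets each subsequent gen consume a slice up to its own stop string. A standalone toy of that slicing logic (the cached text and names here are made up):

speculated_text = "Grace Hopper\nBirthday: December 9, 1906\nJob: Computer scientist\n"

def spec_take(stop, max_new_tokens=64):
    # Consume the cached speculation up to the next stop string, falling
    # back to a length cap when the stop is absent, as _spec_gen does.
    global speculated_text
    pos = speculated_text.find(stop)
    if pos == -1:
        pos = min(max_new_tokens, len(speculated_text))
    piece, speculated_text = speculated_text[:pos], speculated_text[pos:]
    return piece

print(spec_take("\n"))  # -> "Grace Hopper", which would fill gen("name")
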
@@ -304,6 +304,7 @@ def test_image_qa():
         temperature=0,
         max_new_tokens=64,
     )
+
     assert (
         "taxi" in state.messages()[-1]["content"]
        or "car" in state.messages()[-1]["content"]

@@ -349,3 +350,46 @@ def test_regex():
     state = regex_gen.run()
     answer = state["answer"]
     assert re.match(regex, answer)
+
+
+def test_completion_speculative():
+    @sgl.function(api_num_spec_tokens=64)
+    def gen_character_spec(s):
+        s += "Construct a character within the following format:\n"
+        s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
+        s += "\nPlease generate new Name, Birthday and Job.\n"
+        s += "Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n")
+        s += "\nJob:" + sgl.gen("job", stop="\n") + "\n"
+
+    @sgl.function
+    def gen_character_no_spec(s):
+        s += "Construct a character within the following format:\n"
+        s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
+        s += "\nPlease generate new Name, Birthday and Job.\n"
+        s += "Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n")
+        s += "\nJob:" + sgl.gen("job", stop="\n") + "\n"
+
+    token_usage = sgl.global_config.default_backend.token_usage
+
+    token_usage.reset()
+    gen_character_spec().sync()
+    usage_with_spec = token_usage.prompt_tokens
+
+    token_usage.reset()
+    gen_character_no_spec().sync()
+    usage_with_no_spec = token_usage.prompt_tokens
+
+    assert usage_with_spec < usage_with_no_spec, f"{usage_with_spec} vs {usage_with_no_spec}"
+
+
+def test_chat_completion_speculative():
+    @sgl.function(api_num_spec_tokens=256)
+    def gen_character_spec(s):
+        s += sgl.system("You are a helpful assistant.")
+        s += sgl.user("Construct a character within the following format:")
+        s += sgl.assistant("Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n")
+        s += sgl.user("Please generate new Name, Birthday and Job.\n")
+        s += sgl.assistant("Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n"))
+
+    gen_character_spec().sync()
\ No newline at end of file
@@ -14,6 +14,8 @@ from sglang.test.test_programs import (
     test_select,
     test_stream,
     test_tool_use,
+    test_completion_speculative,
+    test_chat_completion_speculative,
 )
@@ -78,6 +80,14 @@ class TestOpenAIBackend(unittest.TestCase):
         set_default_backend(self.backend)
         test_stream()

+    def test_completion_speculative(self):
+        set_default_backend(self.backend)
+        test_completion_speculative()
+
+    def test_chat_completion_speculative(self):
+        set_default_backend(self.chat_backend)
+        test_chat_completion_speculative()
+

 if __name__ == "__main__":
     unittest.main(warnings="ignore")
@@ -87,4 +97,4 @@ if __name__ == "__main__":
     # global_config.verbosity = 2
     # t = TestOpenAIBackend()
     # t.setUp()
-    # t.test_few_shot_qa()
+    # t.test_chat_completion_speculative()
\ No newline at end of file
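
To exercise just the new cases, unittest's standard command-line test selection works with this file's main block (invocation shown for illustration; run from the directory containing the test file):

python3 test_openai_backend.py TestOpenAIBackend.test_completion_speculative
python3 test_openai_backend.py TestOpenAIBackend.test_chat_completion_speculative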