"vscode:/vscode.git/clone" did not exist on "e97f3a327c2e51eae56cd6886c19475bcd6d6839"
Unverified Commit ec380dfd authored by LiviaSun's avatar LiviaSun Committed by GitHub
Browse files

openai chat speculative execution (#250)


Co-authored-by: default avatarYing Sheng <sqy1415@gmail.com>
parent 5b647543
......@@ -14,10 +14,25 @@ def gen_character_spec(s):
s += "\nJob:" + gen("job", stop="\n") + "\n"
@function(api_num_spec_tokens=64)
def gen_character_spec_no_few_shot(s):
# s += "Construct a character with name, birthday, and job:\n"
s += "Construct a character:\n"
s += "Name:" + gen("name", stop="\n") + "\nBirthday:" + gen("birthday", stop="\n")
s += "\nJob:" + gen("job", stop="\n") + "\n"
set_default_backend(OpenAI("gpt-3.5-turbo-instruct"))
state = gen_character_spec.run()
print("name:", state["name"])
print("birthday:", state["birthday"])
print("job:", state["job"])
print("...name:", state["name"])
print("...birthday:", state["birthday"])
print("...job:", state["job"])
state = gen_character_spec_no_few_shot.run()
print("\n...name:", state["name"])
print("...birthday:", state["birthday"])
print("...job:", state["job"])
"""
Usage:
***Note: for speculative execution to work, user must put all "gen" in "assistant". Show in "assistant" the desired answer format. Each "gen" term should have a stop token. The stream mode is not supported in speculative execution.
E.g.
correct:
sgl.assistant("\nName:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n"))
incorrect:
s += sgl.assistant("\nName:" + sgl.gen("name", stop="\n"))
s += sgl.assistant("\nBirthday:" + sgl.gen("birthday", stop="\n"))
s += sgl.assistant("\nJob:" + sgl.gen("job", stop="\n"))
export OPENAI_API_KEY=sk-******
python3 openaichat_speculative.py
"""
import sglang as sgl
from sglang import function, gen, set_default_backend, OpenAI
@function(api_num_spec_tokens=512)
def gen_character_spec(s):
s += sgl.system("You are a helpful assistant.")
s += sgl.user("Construct a character within the following format:")
s += sgl.assistant("Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n")
s += sgl.user("Please generate new Name, Birthday and Job.\n")
s += sgl.assistant("Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n"))
@function(api_num_spec_tokens=512)
def gen_character_spec_no_few_shot(s):
s += sgl.user("Construct a character. For each field stop with a newline\n")
s += sgl.assistant("Name:" + sgl.gen("name", stop="\n") + "\nAge:" + sgl.gen("age", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n"))
@function
def gen_character_normal(s):
s += sgl.system("You are a helpful assistant.")
s += sgl.user("What's the answer of 23 + 8?")
s += sgl.assistant(sgl.gen("answer", max_tokens=64))
@function(api_num_spec_tokens=1024)
def multi_turn_question(s, question_1, question_2):
s += sgl.system("You are a helpful assistant.")
s += sgl.user("Answer questions in the following format:")
s += sgl.user("Question 1: What is the capital of France?\nQuestion 2: What is the population of this city?\n")
s += sgl.assistant("Answer 1: The capital of France is Paris.\nAnswer 2: The population of Paris in 2024 is estimated to be around 2.1 million for the city proper.\n")
s += sgl.user("Question 1: "+question_1+"\nQuestion 2: "+question_2)
s += sgl.assistant("Answer 1: "+sgl.gen("answer_1", stop="\n") + "\nAnswer 2: "+ sgl.gen("answer_2", stop="\n"))
def test_spec_single_turn():
state = gen_character_spec.run()
for m in state.messages():
print(m["role"], ":", m["content"])
print("\n-- name:", state["name"])
print("\n-- birthday:", state["birthday"])
print("\n-- job:", state["job"])
def test_inaccurate_spec_single_turn():
state = gen_character_spec_no_few_shot.run()
for m in state.messages():
print(m["role"], ":", m["content"])
print("\n-- name:", state["name"])
print("\n-- age:", state["age"])
print("\n-- job:", state["job"])
def test_normal_single_turn():
state = gen_character_normal.run()
for m in state.messages():
print(m["role"], ":", m["content"])
def test_spec_multi_turn():
state = multi_turn_question.run(
question_1="What is the capital of the United States?",
question_2="List two local attractions in the capital of the United States.",
)
for m in state.messages():
print(m["role"], ":", m["content"])
print("\n-- answer_1 --\n", state["answer_1"])
print("\n-- answer_2 --\n", state["answer_2"])
def test_spec_multi_turn_stream():
state = multi_turn_question.run(
question_1="What is the capital of the United States?",
question_2="List two local attractions.",
stream=True
)
for out in state.text_iter():
print(out, end="", flush=True)
if __name__ == "__main__":
set_default_backend(OpenAI("gpt-4-turbo"))
print("\n========== test spec single turn ==========\n")
# expect reasonable answer for each field
test_spec_single_turn()
print("\n========== test inaccurate spec single turn ==========\n")
# expect incomplete or unreasonable answers
test_inaccurate_spec_single_turn()
print("\n========== test normal single turn ==========\n")
# expect reasonable answer
test_normal_single_turn()
print("\n========== test spec multi turn ==========\n")
# expect answer with same format as in the few shot
test_spec_multi_turn()
print("\n========== test spec multi turn stream ==========\n")
# expect error in stream_executor: stream is not supported...
test_spec_multi_turn_stream()
......@@ -9,6 +9,7 @@ class BaseBackend:
def __init__(self) -> None:
self.support_concate_and_append = False
self.chat_template = get_chat_template("default")
self.api_num_spec_tokens = None
def get_model_name(self):
raise NotImplementedError()
......
import logging
import time
import warnings
from typing import Callable, List, Optional, Union
import numpy as np
......@@ -80,7 +81,15 @@ class OpenAI(BaseBackend):
else:
self.is_chat_model = True
self.chat_begin_str = self.chat_template.role_prefix_and_suffix["assistant"][0]
self.chat_prefix = self.chat_template.role_prefix_and_suffix["assistant"][0]
self.spec_kwargs = {}
self.spec_format = []
self.spec_max_num_tries = 3
self.api_num_spec_tokens = None
def set_api_num_spec_tokens(self, num):
self.api_num_spec_tokens = num
def get_chat_template(self):
return self.chat_template
......@@ -89,15 +98,45 @@ class OpenAI(BaseBackend):
self,
s: StreamExecutor,
sampling_params: SglSamplingParams,
name=None,
):
if sampling_params.dtype is None:
if self.is_chat_model:
if not s.text_.endswith(self.chat_begin_str):
if self.api_num_spec_tokens is None:
if not s.text_.endswith(self.chat_prefix):
raise RuntimeError(
"This use case is not supported. "
"For OpenAI chat models, sgl.gen must be right after sgl.assistant"
"This use case is not supported if api speculative execution is off. "
"For OpenAI chat models, sgl.gen must be right after sgl.assistant."
"Example of adding api speculative execution: @function(api_num_spec_tokens=128)."
)
prompt = s.messages_
else:
# collect assistant answer format
if "max_tokens" not in self.spec_kwargs:
self.spec_kwargs["max_tokens"] = self.api_num_spec_tokens
else:
assert (
self.spec_kwargs["max_tokens"] == self.api_num_spec_tokens
)
params = sampling_params.to_openai_kwargs()
for key, value in params.items():
if key in ["stop"]:
continue
if key in ["max_tokens"]:
warnings.warn(
"The parameter max_tokens will be overwritten by speculated number of tokens."
)
continue
if key not in self.spec_kwargs:
self.spec_kwargs[key] = value
else:
assert (
value == self.spec_kwargs[key]
), "sampling parameters should be consistent if turn on api speculative execution."
self.spec_format.append(
{"text": "", "stop": params["stop"], "name": name}
)
return "", {}
else:
prompt = s.text_
......@@ -110,6 +149,9 @@ class OpenAI(BaseBackend):
**kwargs,
)
elif sampling_params.dtype in [str, "str", "string"]:
assert (
not self.is_chat_model
), "constrained type not supported on chat model"
kwargs = sampling_params.to_openai_kwargs()
kwargs.pop("stop")
comp = openai_completion(
......@@ -122,6 +164,9 @@ class OpenAI(BaseBackend):
)
comp = '"' + comp + '"'
elif sampling_params.dtype in [int, "int"]:
assert (
not self.is_chat_model
), "constrained type not supported on chat model"
kwargs = sampling_params.to_openai_kwargs()
kwargs.pop("stop")
comp = openai_completion(
......@@ -138,6 +183,62 @@ class OpenAI(BaseBackend):
return comp, {}
def spec_fill(self, value: str):
assert self.is_chat_model
self.spec_format.append({"text": value, "stop": None, "name": None})
def spec_pattern_match(self, comp):
for i, term in enumerate(self.spec_format):
text = term["text"]
if text != "":
if comp.startswith(text):
comp = comp[len(text) :]
else:
return False
else:
pos = comp.find(term["stop"])
if pos != -1:
term["text"] = comp[:pos]
comp = comp[pos:]
else:
if i == len(self.spec_format) - 1:
term["text"] = comp
else:
return False
return True
def role_end_generate(
self,
s: StreamExecutor,
):
if self.api_num_spec_tokens is None or not s.text_.endswith(self.chat_prefix):
return
comp = ""
if not all(x["name"] is None for x in self.spec_format):
for i in range(self.spec_max_num_tries):
comp = openai_completion(
client=self.client,
is_chat=self.is_chat_model,
model=self.model_name,
prompt=s.messages_,
**self.spec_kwargs,
)
if self.spec_pattern_match(comp):
break
for term in self.spec_format:
stop = term["stop"] if term["stop"] is not None else ""
s.text_ += term["text"]
name = term["name"]
if name is not None:
s.variables[name] = term["text"]
s.meta_info[name] = {}
s.variable_event[name].set()
self.spec_kwargs = {}
self.spec_format = []
def generate_stream(
self,
s: StreamExecutor,
......@@ -145,7 +246,7 @@ class OpenAI(BaseBackend):
):
if sampling_params.dtype is None:
if self.is_chat_model:
if not s.text_.endswith(self.chat_begin_str):
if not s.text_.endswith(self.chat_prefix):
raise RuntimeError(
"This use case is not supported. "
"For OpenAI chat models, sgl.gen must be right after sgl.assistant"
......
......@@ -6,6 +6,7 @@ import multiprocessing
import queue
import threading
import uuid
import warnings
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from typing import Any, Callable, Dict, List, Optional, Union
......@@ -185,7 +186,6 @@ class StreamExecutor:
self.arguments: Dict[str, Any] = arguments
self.default_sampling_para = default_sampling_para
self.stream = stream
self.api_num_spec_tokens = api_num_spec_tokens
self.variables = {} # Dict[name: str -> value: str]
self.variable_event = {} # Dict[name: str -> event: threading.Event]
......@@ -197,6 +197,9 @@ class StreamExecutor:
self.text_ = "" # The full text
# For speculative execution
from sglang.backend.openai import OpenAI
if isinstance(backend, OpenAI):
self.backend.set_api_num_spec_tokens(api_num_spec_tokens)
self.speculated_text = ""
# For chat
......@@ -322,7 +325,7 @@ class StreamExecutor:
try:
self._execute(expr)
except Exception as e:
# print(f"Error in stream_executor: {get_exception_traceback()}")
warnings.warn(f"Error in stream_executor: {get_exception_traceback()}")
error = e
break
self.queue.task_done()
......@@ -391,12 +394,23 @@ class StreamExecutor:
else:
raise ValueError(f"Unknown type: {type(other)}")
def _execute_fill(self, value: str):
def _execute_fill(self, value: str, prefix=False):
value = str(value)
if (
self.cur_role == "assistant"
and self.backend.api_num_spec_tokens is not None
and self.backend.is_chat_model
and not prefix
):
self.backend.spec_fill(value)
return
if self.speculated_text.startswith(value):
self.speculated_text = self.speculated_text[len(value) :]
else:
self.speculated_text = ""
self.text_ += value
def _execute_image(self, expr: SglImage):
......@@ -426,14 +440,29 @@ class StreamExecutor:
name = expr.name
if not self.stream:
if self.api_num_spec_tokens is not None:
if self.backend.api_num_spec_tokens is None:
comp, meta_info = self.backend.generate(
self,
sampling_params=sampling_params,
)
elif self.backend.is_chat_model:
# spec on model with only chat interface
comp, meta_info = self.backend.generate(
self,
sampling_params=sampling_params,
name=name,
)
return
else: # spec on model with completion
stop = sampling_params.stop
max_new_tokens = sampling_params.max_new_tokens
meta_info = {}
def regen():
sampling_params.max_new_tokens = max(
sampling_params.max_new_tokens, self.api_num_spec_tokens
sampling_params.max_new_tokens, self.backend.api_num_spec_tokens
)
sampling_params.stop = None
self.speculated_text, meta_info = self.backend.generate(
......@@ -442,16 +471,14 @@ class StreamExecutor:
def find_stop():
if isinstance(stop, str):
return self.speculated_text.find(stop), len(stop)
return self.speculated_text.find(stop)
elif isinstance(stop, (tuple, list)):
pos = -1
stop_len = 0
for stop_str in stop:
stop_pos = self.speculated_text.find(stop_str)
if stop_pos != -1 and (pos == -1 or stop_pos < pos):
pos = stop_pos
stop_len = len(stop_str)
return pos, stop_len
return pos
else:
raise Exception("Wrong type of stop in sampling parameters.")
......@@ -463,23 +490,16 @@ class StreamExecutor:
elif isinstance(stop, (str, list, tuple)):
if self.speculated_text == "":
regen()
stop_pos, stop_len = find_stop()
stop_pos = find_stop()
if stop_pos == -1:
stop_pos, stop_len = (
min(
stop_pos = min(
sampling_params.max_new_tokens,
len(self.speculated_text),
),
0,
)
comp = self.speculated_text[:stop_pos]
self.speculated_text = self.speculated_text[stop_pos:]
else:
raise ValueError("Wrong type of stop in sampling parameters.")
else:
comp, meta_info = self.backend.generate(
self, sampling_params=sampling_params
)
self.text_ += comp
......@@ -487,6 +507,9 @@ class StreamExecutor:
self.meta_info[name] = meta_info
self.variable_event[name].set()
else:
assert (
self.backend.api_num_spec_tokens is None
), "stream is not supported with api speculative execution"
generator = self.backend.generate_stream(
self, sampling_params=sampling_params
)
......@@ -542,10 +565,18 @@ class StreamExecutor:
prefix, _ = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)
self._execute_fill(prefix)
self._execute_fill(prefix, prefix=True)
self.cur_role_begin_pos = len(self.text_)
def _execute_role_end(self, expr: SglRoleEnd):
if (
self.cur_role == "assistant"
and self.backend.api_num_spec_tokens is not None
and self.backend.is_chat_model
):
self.backend.role_end_generate(self)
self.cur_role = None
new_text = self.text_[self.cur_role_begin_pos :].lstrip()
_, suffix = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)
......@@ -572,8 +603,6 @@ class StreamExecutor:
# OpenAI chat API format
self.messages_.append({"role": expr.role, "content": new_text})
self.cur_role = None
def _execute_var_scope_begin(self, expr: SglVarScopeBegin):
self.variables[expr.name] = int(len(self.text_))
......
......@@ -31,8 +31,9 @@ class GenerateReqInput:
def post_init(self):
if ((self.text is None and self.input_ids is None) or
(self.text is not None and self.input_ids is not None)):
if (self.text is None and self.input_ids is None) or (
self.text is not None and self.input_ids is not None
):
raise ValueError("Either text or input_ids should be provided.")
if self.text is not None:
......
......@@ -38,7 +38,6 @@ from sglang.srt.utils import (
)
from sglang.utils import get_exception_traceback
logger = logging.getLogger("model_rpc")
vllm_default_logger.setLevel(logging.WARN)
logging.getLogger("vllm.utils").setLevel(logging.WARN)
......
......@@ -341,7 +341,6 @@ class TokenizerManager:
return top_logprobs
global global_processor
......
......@@ -9,8 +9,8 @@ import os
import sys
import threading
import time
from typing import List, Optional, Union
from http import HTTPStatus
from typing import List, Optional, Union
# Fix a bug of Python threading
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
......@@ -45,7 +45,6 @@ from sglang.srt.utils import (
)
from sglang.utils import get_exception_traceback
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
......@@ -84,6 +83,7 @@ async def flush_cache():
async def generate_request(obj: GenerateReqInput, request: Request):
if obj.stream:
async def stream_results():
try:
async for out in tokenizer_manager.generate_request(obj, request):
......@@ -99,8 +99,10 @@ async def generate_request(obj: GenerateReqInput, request: Request):
ret = await tokenizer_manager.generate_request(obj, request).__anext__()
return ret
except ValueError as e:
return JSONResponse({"error": {"message": str(e)}},
status_code=HTTPStatus.BAD_REQUEST)
return JSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
)
app.post("/generate")(generate_request)
app.put("/generate")(generate_request)
......
......@@ -19,7 +19,6 @@ from packaging import version as pkg_version
from pydantic import BaseModel
from starlette.middleware.base import BaseHTTPMiddleware
logger = logging.getLogger(__name__)
......@@ -157,7 +156,9 @@ def allocate_init_ports(
cur_port += 1
if port and ret_ports[0] != port:
logger.warn(f"WARNING: Port {port} is not available. Use port {ret_ports[0]} instead.")
logger.warn(
f"WARNING: Port {port} is not available. Use port {ret_ports[0]} instead."
)
return ret_ports[0], ret_ports[1:]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment