"vscode:/vscode.git/clone" did not exist on "d0d3e24ec19daca42129afb89f1031d7e7c9995a"
Unverified commit 23950056 authored by parasol-aser, committed by GitHub
Browse files

support speculative execution for openai API (#48)


Co-authored-by: Ying Sheng <sqy1415@gmail.com>
parent 93414c82
from sglang import function, gen, set_default_backend, OpenAI
# Example: speculative execution with an OpenAI completion backend.
# `api_num_spec_tokens=512` lets sglang request up to 512 tokens from the API
# in one call and then match subsequent fills/gens against that speculated
# text locally, saving API round-trips.
@function(api_num_spec_tokens=512)
def gen_character_spec(s):
    # Few-shot prompt: show the desired "Name/Birthday/Job" format once.
    s += "Construct a character within the following format:\n"
    s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
    s += "\nPlease generate new Name, Birthday and Job.\n"
    # Each gen() captures one field into a named variable, stopping at newline.
    s += "Name:" + gen("name", stop="\n") + "\nBirthday:" + gen("birthday", stop="\n")
    s += "\nJob:" + gen("job", stop="\n") + "\n"
# Run the speculative prompt against the OpenAI completion model and
# echo the three captured variables.
backend = OpenAI("gpt-3.5-turbo-instruct")
set_default_backend(backend)

state = gen_character_spec.run()
for field in ("name", "birthday", "job"):
    print(f"{field}:", state[field])
...@@ -20,8 +20,16 @@ from sglang.lang.ir import ( ...@@ -20,8 +20,16 @@ from sglang.lang.ir import (
) )
def function(
    func: Optional[Callable] = None, api_num_spec_tokens: Optional[int] = None
):
    """Decorator that wraps a plain Python function into an ``SglFunction``.

    Supports both decorator forms:

    * ``@function`` (bare) — ``func`` is the decorated callable and the
      wrapped ``SglFunction`` is returned directly.
    * ``@function(api_num_spec_tokens=...)`` — ``func`` is ``None``, so a
      decorator closure is returned instead.

    Args:
        func: The function to wrap when used without call parentheses.
        api_num_spec_tokens: Maximum number of tokens to speculatively
            request per backend call; enables speculative execution for
            API backends. ``None`` disables speculation.
    """
    # Explicit None-check: truthiness of a callable is always True, but
    # `is not None` states the intent (bare vs. parameterized decorator).
    if func is not None:
        return SglFunction(func, api_num_spec_tokens=api_num_spec_tokens)

    def decorator(inner_func):
        return SglFunction(inner_func, api_num_spec_tokens=api_num_spec_tokens)

    return decorator
def Runtime(*args, **kwargs): def Runtime(*args, **kwargs):
......
...@@ -51,10 +51,14 @@ def run_program( ...@@ -51,10 +51,14 @@ def run_program(
if hasattr(backend, "endpoint"): if hasattr(backend, "endpoint"):
backend = backend.endpoint backend = backend.endpoint
assert backend is not None, "Please specify a backend" assert backend is not None, "Please specify a backend"
func_kwargs.update(program.bind_arguments) func_kwargs.update(program.bind_arguments)
stream_executor = StreamExecutor( stream_executor = StreamExecutor(
backend, func_kwargs, default_sampling_para, chat_template=None, stream=stream backend,
func_kwargs,
default_sampling_para,
chat_template=None,
stream=stream,
api_num_spec_tokens=program.api_num_spec_tokens,
) )
state = ProgramState(stream_executor) state = ProgramState(stream_executor)
...@@ -175,6 +179,7 @@ class StreamExecutor: ...@@ -175,6 +179,7 @@ class StreamExecutor:
default_sampling_para, default_sampling_para,
chat_template, chat_template,
stream, stream,
api_num_spec_tokens=None,
use_thread=True, use_thread=True,
): ):
self.sid = uuid.uuid4().hex self.sid = uuid.uuid4().hex
...@@ -182,6 +187,7 @@ class StreamExecutor: ...@@ -182,6 +187,7 @@ class StreamExecutor:
self.arguments: Dict[str, Any] = arguments self.arguments: Dict[str, Any] = arguments
self.default_sampling_para = default_sampling_para self.default_sampling_para = default_sampling_para
self.stream = stream self.stream = stream
self.api_num_spec_tokens = api_num_spec_tokens
self.variables = {} # Dict[name: str -> value: str] self.variables = {} # Dict[name: str -> value: str]
self.variable_event = {} # Dict[name: str -> event: threading.Event] self.variable_event = {} # Dict[name: str -> event: threading.Event]
...@@ -191,6 +197,9 @@ class StreamExecutor: ...@@ -191,6 +197,9 @@ class StreamExecutor:
# For completion # For completion
self.text_ = "" # The full text self.text_ = "" # The full text
# For speculative execution
self.speculated_text = ""
# For chat # For chat
self.messages_ = [] # The messages in the OpenAI API format self.messages_ = [] # The messages in the OpenAI API format
self.chat_template = chat_template or self.backend.get_chat_template() self.chat_template = chat_template or self.backend.get_chat_template()
...@@ -341,6 +350,10 @@ class StreamExecutor: ...@@ -341,6 +350,10 @@ class StreamExecutor:
def _execute_fill(self, value: str): def _execute_fill(self, value: str):
value = str(value) value = str(value)
if self.speculated_text.startswith(value):
self.speculated_text = self.speculated_text[len(value) :]
else:
self.speculated_text = ""
self.text_ += value self.text_ += value
def _execute_image(self, expr: SglImage): def _execute_image(self, expr: SglImage):
...@@ -360,9 +373,61 @@ class StreamExecutor: ...@@ -360,9 +373,61 @@ class StreamExecutor:
name = expr.name name = expr.name
if not self.stream: if not self.stream:
comp, meta_info = self.backend.generate( if self.api_num_spec_tokens is not None:
self, sampling_params=sampling_params stop = sampling_params.stop
) max_new_tokens = sampling_params.max_new_tokens
meta_info = {}
def regen():
sampling_params.max_new_tokens = max(
sampling_params.max_new_tokens, self.api_num_spec_tokens
)
sampling_params.stop = None
self.speculated_text, meta_info = self.backend.generate(
self, sampling_params=sampling_params
)
def find_stop():
if isinstance(stop, str):
return self.speculated_text.find(stop), len(stop)
elif isinstance(stop, (tuple, list)):
pos = -1
stop_len = 0
for stop_str in stop:
stop_pos = self.speculated_text.find(stop_str)
if stop_pos != -1 and (pos == -1 or stop_pos < pos):
pos = stop_pos
stop_len = len(stop_str)
return pos, stop_len
else:
raise Exception("Wrong type of stop in sampling parameters.")
if stop is None:
if len(self.speculated_text) < max_new_tokens:
regen()
comp = self.speculated_text[:max_new_tokens]
self.speculated_text = self.speculated_text[max_new_tokens:]
elif isinstance(stop, (str, list, tuple)):
if self.speculated_text == "":
regen()
stop_pos, stop_len = find_stop()
if stop_pos == -1:
stop_pos, stop_len = (
min(
sampling_params.max_new_tokens,
len(self.speculated_text),
),
0,
)
comp = self.speculated_text[:stop_pos]
self.speculated_text = self.speculated_text[stop_pos:]
else:
raise ValueError("Wrong type of stop in sampling parameters.")
else:
comp, meta_info = self.backend.generate(
self, sampling_params=sampling_params
)
self.text_ += comp self.text_ += comp
self.variables[name] = comp self.variables[name] = comp
......
...@@ -95,8 +95,9 @@ class SglSamplingParams: ...@@ -95,8 +95,9 @@ class SglSamplingParams:
class SglFunction: class SglFunction:
def __init__(self, func, bind_arguments=None): def __init__(self, func, api_num_spec_tokens=None, bind_arguments=None):
self.func = func self.func = func
self.api_num_spec_tokens = api_num_spec_tokens
self.bind_arguments = bind_arguments or {} self.bind_arguments = bind_arguments or {}
self.pin_prefix_rid = None self.pin_prefix_rid = None
......
...@@ -60,7 +60,9 @@ class DetokenizerManager: ...@@ -60,7 +60,9 @@ class DetokenizerManager:
if first_token.startswith("▁"): if first_token.startswith("▁"):
output_strs[i] = " " + output_strs[i] output_strs[i] = " " + output_strs[i]
output_strs[i] = recv_obj.output_and_fast_forward_strs[i] + output_strs[i] output_strs[i] = (
recv_obj.output_and_fast_forward_strs[i] + output_strs[i]
)
self.send_to_tokenizer.send_pyobj( self.send_to_tokenizer.send_pyobj(
BatchStrOut( BatchStrOut(
......
...@@ -12,6 +12,7 @@ import rpyc ...@@ -12,6 +12,7 @@ import rpyc
import torch import torch
from rpyc.utils.classic import obtain from rpyc.utils.classic import obtain
from rpyc.utils.server import ThreadedServer from rpyc.utils.server import ThreadedServer
from sglang.srt.constrained.fast_forward import FastForwardCache
from sglang.srt.constrained.fsm_cache import FSMCache from sglang.srt.constrained.fsm_cache import FSMCache
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.io_struct import BatchTokenIDOut, TokenizedGenerateReqInput from sglang.srt.managers.io_struct import BatchTokenIDOut, TokenizedGenerateReqInput
...@@ -21,7 +22,6 @@ from sglang.srt.managers.router.radix_cache import RadixCache ...@@ -21,7 +22,6 @@ from sglang.srt.managers.router.radix_cache import RadixCache
from sglang.srt.managers.router.scheduler import Scheduler from sglang.srt.managers.router.scheduler import Scheduler
from sglang.srt.model_config import ModelConfig from sglang.srt.model_config import ModelConfig
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.constrained.fast_forward import FastForwardCache
from sglang.srt.utils import ( from sglang.srt.utils import (
get_exception_traceback, get_exception_traceback,
get_int_token_logit_bias, get_int_token_logit_bias,
......
...@@ -200,6 +200,7 @@ class TokenizerManager: ...@@ -200,6 +200,7 @@ class TokenizerManager:
) )
tokenized_obj = TokenizedGenerateReqInput( tokenized_obj = TokenizedGenerateReqInput(
rid=rid, rid=rid,
input_text=obj.text[i],
input_ids=input_ids, input_ids=input_ids,
pixel_values=pixel_values, pixel_values=pixel_values,
image_hash=image_hash, image_hash=image_hash,
......
from sglang import OpenAI, function, gen, set_default_backend
# Baseline variant: no speculative execution (no api_num_spec_tokens),
# one API call per gen().
@function()
def gen_character_default(s):
    # Few-shot prompt establishing the Name/Birthday/Job format.
    s += "Construct a character within the following format:\n"
    s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\nWelcome.\n"
    s += "\nPlease generate new Name, Birthday and Job.\n"
    # Each gen() captures one field, stopping at newline.
    s += "Name:" + gen("name", stop="\n") + "\nBirthday:" + gen("birthday", stop="\n")
    s += "\nJob:" + gen("job", stop="\n") + "\nWelcome.\n"
# Speculative variant: one large speculative API call (up to 512 tokens)
# whose output is matched locally against the fills and stop strings.
@function(api_num_spec_tokens=512)
def gen_character_spec(s):
    s += "Construct a character within the following format:\n"
    s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\nWelcome.\n"
    s += "\nPlease generate new Name, Birthday and Job.\n"
    # Single-string stop condition per gen().
    s += "Name:" + gen("name", stop="\n") + "\nBirthday:" + gen("birthday", stop="\n")
    s += "\nJob:" + gen("job", stop="\n") + "\nWelcome.\n"
# Speculative variant with no stop strings: each gen() consumes up to
# max_new_tokens from the speculated text instead of cutting at a stop.
@function(api_num_spec_tokens=512)
def gen_character_no_stop(s):
    s += "Construct a character within the following format:\n"
    s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\nWelcome.\n"
    s += "\nPlease generate new Name, Birthday and Job.\n"
    # No stop= argument: exercises the stop-is-None speculative path.
    s += "Name:" + gen("name") + "\nBirthday:" + gen("birthday")
    s += "\nJob:" + gen("job") + "\nWelcome.\n"
# Speculative variant with a LIST of stop strings: exercises the
# earliest-match search over multiple stops in the speculative path.
@function(api_num_spec_tokens=512)
def gen_character_multi_stop(s):
    # Few-shot example uses "###" as the field separator.
    s += "Construct a character within the following format:\n"
    s += (
        "Name: Steve Jobs.###Birthday: February 24, 1955.###Job: Apple CEO.\nWelcome.\n"
    )
    s += "\nPlease generate new Name, Birthday and Job.\n"
    # Each gen() stops at whichever of "\n" or "###" appears first.
    s += "Name:" + gen("name", stop=["\n", "###"])
    s += "###Birthday:" + gen("birthday", stop=["\n", "###"])
    s += "###Job:" + gen("job", stop=["\n", "###"]) + "\nWelcome.\n"
# Exercise the four prompt variants against the OpenAI completion backend
# and print their outputs, separated by a divider line.
set_default_backend(OpenAI("gpt-3.5-turbo-instruct"))

divider = "=" * 60

state = gen_character_default.run()
print(state.text())
print(divider)

state = gen_character_no_stop.run()
# NOTE: label punctuation is intentionally uneven ("birthday###:" has a colon).
print("name###", state["name"])
print("birthday###:", state["birthday"])
print("job###", state["job"])
print(divider)

state = gen_character_multi_stop.run()
print(state.text())
print(divider)

state = gen_character_spec.run()
print(state.text())
for key in ("name", "birthday", "job"):
    print(f"{key}###", state[key])
import argparse import argparse
from enum import Enum from enum import Enum
import sglang as sgl
from pydantic import BaseModel, constr from pydantic import BaseModel, constr
from sglang.srt.constrained.json_schema import build_regex_from_object from sglang.srt.constrained.json_schema import build_regex_from_object
from sglang.test.test_utils import ( from sglang.test.test_utils import (
...@@ -9,6 +8,8 @@ from sglang.test.test_utils import ( ...@@ -9,6 +8,8 @@ from sglang.test.test_utils import (
select_sglang_backend, select_sglang_backend,
) )
import sglang as sgl
IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)" IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
ip_fast_forward = ( ip_fast_forward = (
......
...@@ -2,13 +2,14 @@ import argparse ...@@ -2,13 +2,14 @@ import argparse
import random import random
import string import string
import sglang as sgl
from sglang.test.test_utils import ( from sglang.test.test_utils import (
add_common_sglang_args_and_parse, add_common_sglang_args_and_parse,
select_sglang_backend, select_sglang_backend,
) )
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
import sglang as sgl
TOKENIZER = None TOKENIZER = None
RANDOM_PREFILL_LEN = None RANDOM_PREFILL_LEN = None
RANDOM_DECODE_LEN = None RANDOM_DECODE_LEN = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment