Unverified Commit ced77c66 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Rename api_num_spec_tokens -> num_api_spec_tokens (#458)

parent 8dbdc018
......@@ -19,7 +19,7 @@ import sglang as sgl
from sglang import function, set_default_backend, OpenAI
@function(api_num_spec_tokens=256)
@function(num_api_spec_tokens=256)
def gen_character_spec(s):
s += sgl.system("You are a helpful assistant.")
s += sgl.user("Construct a character within the following format:")
......@@ -28,7 +28,7 @@ def gen_character_spec(s):
s += sgl.assistant("Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n"))
@function(api_num_spec_tokens=256)
@function(num_api_spec_tokens=256)
def gen_character_spec_no_few_shot(s):
s += sgl.user("Construct a character. For each field stop with a newline\n")
s += sgl.assistant("Name:" + sgl.gen("name", stop="\n") + "\nAge:" + sgl.gen("age", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n"))
......@@ -41,7 +41,7 @@ def gen_character_normal(s):
s += sgl.assistant(sgl.gen("answer", max_tokens=64))
@function(api_num_spec_tokens=1024)
@function(num_api_spec_tokens=1024)
def multi_turn_question(s, question_1, question_2):
s += sgl.system("You are a helpful assistant.")
s += sgl.user("Answer questions in the following format:")
......
......@@ -5,7 +5,7 @@ python3 openai_speculative.py
from sglang import function, gen, set_default_backend, OpenAI
@function(api_num_spec_tokens=64)
@function(num_api_spec_tokens=64)
def gen_character_spec(s):
s += "Construct a character within the following format:\n"
s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
......@@ -23,7 +23,7 @@ def gen_character_no_spec(s):
s += "\nJob:" + gen("job", stop="\n") + "\n"
@function(api_num_spec_tokens=64)
@function(num_api_spec_tokens=64)
def gen_character_spec_no_few_shot(s):
# s += "Construct a character with name, birthday, and job:\n"
s += "Construct a character:\n"
......
......@@ -20,13 +20,13 @@ from sglang.lang.ir import (
def function(
    func: Optional[Callable] = None, num_api_spec_tokens: Optional[int] = None
):
    """Decorator that turns a plain Python function into an ``SglFunction``.

    Supports both usage forms:
      * bare: ``@function`` — *func* is passed directly and wrapped immediately;
      * parameterized: ``@function(num_api_spec_tokens=...)`` — returns a
        decorator that wraps the target function.

    Args:
        func: The function to wrap when used without parentheses.
        num_api_spec_tokens: Number of tokens to speculatively generate for
            API-based speculative execution; forwarded to ``SglFunction``.
            ``None`` disables speculative execution.

    Returns:
        An ``SglFunction`` (bare form) or a decorator producing one.
    """
    # NOTE(review): the rendered diff interleaved the pre-rename
    # (api_num_spec_tokens) and post-rename lines; this is the reconstructed
    # post-commit version using the new keyword name only.
    if func:
        return SglFunction(func, num_api_spec_tokens=num_api_spec_tokens)

    def decorator(func):
        return SglFunction(func, num_api_spec_tokens=num_api_spec_tokens)

    return decorator
......
......@@ -106,12 +106,12 @@ class OpenAI(BaseBackend):
return self.chat_template
def _prepare_spec_execution(self, sampling_params: SglSamplingParams,
api_num_spec_tokens: int, spec_var_name: str):
num_api_spec_tokens: int, spec_var_name: str):
if "max_tokens" not in self.spec_kwargs:
self.spec_kwargs["max_tokens"] = api_num_spec_tokens
self.spec_kwargs["max_tokens"] = num_api_spec_tokens
else:
assert (
self.spec_kwargs["max_tokens"] == api_num_spec_tokens
self.spec_kwargs["max_tokens"] == num_api_spec_tokens
)
params = sampling_params.to_openai_kwargs()
......@@ -142,17 +142,17 @@ class OpenAI(BaseBackend):
):
if sampling_params.dtype is None:
if self.is_chat_model:
if s.api_num_spec_tokens is None:
if s.num_api_spec_tokens is None:
if not s.text_.endswith(self.chat_prefix):
raise RuntimeError(
"This use case is not supported if api speculative execution is off. "
"For OpenAI chat models, sgl.gen must be right after sgl.assistant. "
"Example of adding api speculative execution: @function(api_num_spec_tokens=128)."
"Example of adding api speculative execution: @function(num_api_spec_tokens=128)."
)
prompt = s.messages_
else:
return self._prepare_spec_execution(sampling_params,
s.api_num_spec_tokens, spec_var_name)
s.num_api_spec_tokens, spec_var_name)
else:
prompt = s.text_
......@@ -230,7 +230,7 @@ class OpenAI(BaseBackend):
self,
s: StreamExecutor,
):
if s.api_num_spec_tokens is None or not s.text_.endswith(self.chat_prefix):
if s.num_api_spec_tokens is None or not s.text_.endswith(self.chat_prefix):
return
comp = ""
......
......@@ -66,7 +66,7 @@ def run_program(
default_sampling_para,
chat_template=None,
stream=stream,
api_num_spec_tokens=program.api_num_spec_tokens,
num_api_spec_tokens=program.num_api_spec_tokens,
)
state = ProgramState(stream_executor)
......@@ -178,7 +178,7 @@ class StreamExecutor:
default_sampling_para,
chat_template,
stream,
api_num_spec_tokens=None,
num_api_spec_tokens=None,
use_thread=True,
):
self.sid = uuid.uuid4().hex
......@@ -210,7 +210,7 @@ class StreamExecutor:
self.fork_start_text_pos = None
# For speculative execution
self.api_num_spec_tokens = api_num_spec_tokens
self.num_api_spec_tokens = num_api_spec_tokens
self.speculated_text = ""
# Worker thread
......@@ -399,7 +399,7 @@ class StreamExecutor:
if (
self.cur_role == "assistant"
and self.api_num_spec_tokens is not None
and self.num_api_spec_tokens is not None
and self.backend.is_chat_model
and not prefix
):
......@@ -444,7 +444,7 @@ class StreamExecutor:
nonlocal meta_info
sampling_params.max_new_tokens = max(
sampling_params.max_new_tokens, self.api_num_spec_tokens
sampling_params.max_new_tokens, self.num_api_spec_tokens
)
sampling_params.stop = None
self.speculated_text, meta_info = self.backend.generate(
......@@ -490,7 +490,7 @@ class StreamExecutor:
name = expr.name
if not self.stream:
if self.api_num_spec_tokens is None:
if self.num_api_spec_tokens is None:
comp, meta_info = self.backend.generate(
self,
sampling_params=sampling_params,
......@@ -517,7 +517,7 @@ class StreamExecutor:
self.variable_event[name].set()
else:
assert (
self.api_num_spec_tokens is None
self.num_api_spec_tokens is None
), "stream is not supported with api speculative execution"
generator = self.backend.generate_stream(
self, sampling_params=sampling_params
......@@ -580,7 +580,7 @@ class StreamExecutor:
def _execute_role_end(self, expr: SglRoleEnd):
if (
self.cur_role == "assistant"
and self.api_num_spec_tokens is not None
and self.num_api_spec_tokens is not None
and self.backend.is_chat_model
):
# Execute the stored lazy generation calls
......
......@@ -97,9 +97,9 @@ class SglSamplingParams:
class SglFunction:
def __init__(self, func, api_num_spec_tokens=None, bind_arguments=None):
def __init__(self, func, num_api_spec_tokens=None, bind_arguments=None):
self.func = func
self.api_num_spec_tokens = api_num_spec_tokens
self.num_api_spec_tokens = num_api_spec_tokens
self.bind_arguments = bind_arguments or {}
self.pin_prefix_rid = None
......
......@@ -353,7 +353,7 @@ def test_regex():
def test_completion_speculative():
@sgl.function(api_num_spec_tokens=64)
@sgl.function(num_api_spec_tokens=64)
def gen_character_spec(s):
s += "Construct a character within the following format:\n"
s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
......@@ -384,7 +384,7 @@ def test_completion_speculative():
def test_chat_completion_speculative():
@sgl.function(api_num_spec_tokens=256)
@sgl.function(num_api_spec_tokens=256)
def gen_character_spec(s):
s += sgl.system("You are a helpful assistant.")
s += sgl.user("Construct a character within the following format:")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment