Unverified Commit ced77c66 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Rename api_num_spec_tokens -> num_api_spec_tokens (#458)

parent 8dbdc018
...@@ -19,7 +19,7 @@ import sglang as sgl ...@@ -19,7 +19,7 @@ import sglang as sgl
from sglang import function, set_default_backend, OpenAI from sglang import function, set_default_backend, OpenAI
@function(api_num_spec_tokens=256) @function(num_api_spec_tokens=256)
def gen_character_spec(s): def gen_character_spec(s):
s += sgl.system("You are a helpful assistant.") s += sgl.system("You are a helpful assistant.")
s += sgl.user("Construct a character within the following format:") s += sgl.user("Construct a character within the following format:")
...@@ -28,7 +28,7 @@ def gen_character_spec(s): ...@@ -28,7 +28,7 @@ def gen_character_spec(s):
s += sgl.assistant("Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n")) s += sgl.assistant("Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n"))
@function(api_num_spec_tokens=256) @function(num_api_spec_tokens=256)
def gen_character_spec_no_few_shot(s): def gen_character_spec_no_few_shot(s):
s += sgl.user("Construct a character. For each field stop with a newline\n") s += sgl.user("Construct a character. For each field stop with a newline\n")
s += sgl.assistant("Name:" + sgl.gen("name", stop="\n") + "\nAge:" + sgl.gen("age", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n")) s += sgl.assistant("Name:" + sgl.gen("name", stop="\n") + "\nAge:" + sgl.gen("age", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n"))
...@@ -41,7 +41,7 @@ def gen_character_normal(s): ...@@ -41,7 +41,7 @@ def gen_character_normal(s):
s += sgl.assistant(sgl.gen("answer", max_tokens=64)) s += sgl.assistant(sgl.gen("answer", max_tokens=64))
@function(api_num_spec_tokens=1024) @function(num_api_spec_tokens=1024)
def multi_turn_question(s, question_1, question_2): def multi_turn_question(s, question_1, question_2):
s += sgl.system("You are a helpful assistant.") s += sgl.system("You are a helpful assistant.")
s += sgl.user("Answer questions in the following format:") s += sgl.user("Answer questions in the following format:")
......
...@@ -5,7 +5,7 @@ python3 openai_speculative.py ...@@ -5,7 +5,7 @@ python3 openai_speculative.py
from sglang import function, gen, set_default_backend, OpenAI from sglang import function, gen, set_default_backend, OpenAI
@function(api_num_spec_tokens=64) @function(num_api_spec_tokens=64)
def gen_character_spec(s): def gen_character_spec(s):
s += "Construct a character within the following format:\n" s += "Construct a character within the following format:\n"
s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n" s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
...@@ -23,7 +23,7 @@ def gen_character_no_spec(s): ...@@ -23,7 +23,7 @@ def gen_character_no_spec(s):
s += "\nJob:" + gen("job", stop="\n") + "\n" s += "\nJob:" + gen("job", stop="\n") + "\n"
@function(api_num_spec_tokens=64) @function(num_api_spec_tokens=64)
def gen_character_spec_no_few_shot(s): def gen_character_spec_no_few_shot(s):
# s += "Construct a character with name, birthday, and job:\n" # s += "Construct a character with name, birthday, and job:\n"
s += "Construct a character:\n" s += "Construct a character:\n"
......
...@@ -20,13 +20,13 @@ from sglang.lang.ir import ( ...@@ -20,13 +20,13 @@ from sglang.lang.ir import (
def function(
    func: Optional[Callable] = None, num_api_spec_tokens: Optional[int] = None
):
    """Decorator that wraps a plain Python function into an ``SglFunction``.

    Supports both decoration forms:

    * bare — ``@function`` passes the callable directly as ``func``;
    * parameterized — ``@function(num_api_spec_tokens=...)`` is called with
      ``func=None`` and must return a decorator to receive the callable.

    Args:
        func: The function being decorated, or ``None`` in the
            parameterized form.
        num_api_spec_tokens: Optional token budget for API speculative
            execution, forwarded to ``SglFunction``.

    Returns:
        An ``SglFunction`` (bare form) or a decorator producing one
        (parameterized form).
    """
    if func is not None:
        # Bare form: wrap immediately.
        return SglFunction(func, num_api_spec_tokens=num_api_spec_tokens)

    def wrap(inner: Callable) -> "SglFunction":
        # Parameterized form: capture num_api_spec_tokens in the closure.
        return SglFunction(inner, num_api_spec_tokens=num_api_spec_tokens)

    return wrap
......
...@@ -106,12 +106,12 @@ class OpenAI(BaseBackend): ...@@ -106,12 +106,12 @@ class OpenAI(BaseBackend):
return self.chat_template return self.chat_template
def _prepare_spec_execution(self, sampling_params: SglSamplingParams, def _prepare_spec_execution(self, sampling_params: SglSamplingParams,
api_num_spec_tokens: int, spec_var_name: str): num_api_spec_tokens: int, spec_var_name: str):
if "max_tokens" not in self.spec_kwargs: if "max_tokens" not in self.spec_kwargs:
self.spec_kwargs["max_tokens"] = api_num_spec_tokens self.spec_kwargs["max_tokens"] = num_api_spec_tokens
else: else:
assert ( assert (
self.spec_kwargs["max_tokens"] == api_num_spec_tokens self.spec_kwargs["max_tokens"] == num_api_spec_tokens
) )
params = sampling_params.to_openai_kwargs() params = sampling_params.to_openai_kwargs()
...@@ -142,17 +142,17 @@ class OpenAI(BaseBackend): ...@@ -142,17 +142,17 @@ class OpenAI(BaseBackend):
): ):
if sampling_params.dtype is None: if sampling_params.dtype is None:
if self.is_chat_model: if self.is_chat_model:
if s.api_num_spec_tokens is None: if s.num_api_spec_tokens is None:
if not s.text_.endswith(self.chat_prefix): if not s.text_.endswith(self.chat_prefix):
raise RuntimeError( raise RuntimeError(
"This use case is not supported if api speculative execution is off. " "This use case is not supported if api speculative execution is off. "
"For OpenAI chat models, sgl.gen must be right after sgl.assistant. " "For OpenAI chat models, sgl.gen must be right after sgl.assistant. "
"Example of adding api speculative execution: @function(api_num_spec_tokens=128)." "Example of adding api speculative execution: @function(num_api_spec_tokens=128)."
) )
prompt = s.messages_ prompt = s.messages_
else: else:
return self._prepare_spec_execution(sampling_params, return self._prepare_spec_execution(sampling_params,
s.api_num_spec_tokens, spec_var_name) s.num_api_spec_tokens, spec_var_name)
else: else:
prompt = s.text_ prompt = s.text_
...@@ -230,7 +230,7 @@ class OpenAI(BaseBackend): ...@@ -230,7 +230,7 @@ class OpenAI(BaseBackend):
self, self,
s: StreamExecutor, s: StreamExecutor,
): ):
if s.api_num_spec_tokens is None or not s.text_.endswith(self.chat_prefix): if s.num_api_spec_tokens is None or not s.text_.endswith(self.chat_prefix):
return return
comp = "" comp = ""
......
...@@ -66,7 +66,7 @@ def run_program( ...@@ -66,7 +66,7 @@ def run_program(
default_sampling_para, default_sampling_para,
chat_template=None, chat_template=None,
stream=stream, stream=stream,
api_num_spec_tokens=program.api_num_spec_tokens, num_api_spec_tokens=program.num_api_spec_tokens,
) )
state = ProgramState(stream_executor) state = ProgramState(stream_executor)
...@@ -178,7 +178,7 @@ class StreamExecutor: ...@@ -178,7 +178,7 @@ class StreamExecutor:
default_sampling_para, default_sampling_para,
chat_template, chat_template,
stream, stream,
api_num_spec_tokens=None, num_api_spec_tokens=None,
use_thread=True, use_thread=True,
): ):
self.sid = uuid.uuid4().hex self.sid = uuid.uuid4().hex
...@@ -210,7 +210,7 @@ class StreamExecutor: ...@@ -210,7 +210,7 @@ class StreamExecutor:
self.fork_start_text_pos = None self.fork_start_text_pos = None
# For speculative execution # For speculative execution
self.api_num_spec_tokens = api_num_spec_tokens self.num_api_spec_tokens = num_api_spec_tokens
self.speculated_text = "" self.speculated_text = ""
# Worker thread # Worker thread
...@@ -399,7 +399,7 @@ class StreamExecutor: ...@@ -399,7 +399,7 @@ class StreamExecutor:
if ( if (
self.cur_role == "assistant" self.cur_role == "assistant"
and self.api_num_spec_tokens is not None and self.num_api_spec_tokens is not None
and self.backend.is_chat_model and self.backend.is_chat_model
and not prefix and not prefix
): ):
...@@ -444,7 +444,7 @@ class StreamExecutor: ...@@ -444,7 +444,7 @@ class StreamExecutor:
nonlocal meta_info nonlocal meta_info
sampling_params.max_new_tokens = max( sampling_params.max_new_tokens = max(
sampling_params.max_new_tokens, self.api_num_spec_tokens sampling_params.max_new_tokens, self.num_api_spec_tokens
) )
sampling_params.stop = None sampling_params.stop = None
self.speculated_text, meta_info = self.backend.generate( self.speculated_text, meta_info = self.backend.generate(
...@@ -490,7 +490,7 @@ class StreamExecutor: ...@@ -490,7 +490,7 @@ class StreamExecutor:
name = expr.name name = expr.name
if not self.stream: if not self.stream:
if self.api_num_spec_tokens is None: if self.num_api_spec_tokens is None:
comp, meta_info = self.backend.generate( comp, meta_info = self.backend.generate(
self, self,
sampling_params=sampling_params, sampling_params=sampling_params,
...@@ -517,7 +517,7 @@ class StreamExecutor: ...@@ -517,7 +517,7 @@ class StreamExecutor:
self.variable_event[name].set() self.variable_event[name].set()
else: else:
assert ( assert (
self.api_num_spec_tokens is None self.num_api_spec_tokens is None
), "stream is not supported with api speculative execution" ), "stream is not supported with api speculative execution"
generator = self.backend.generate_stream( generator = self.backend.generate_stream(
self, sampling_params=sampling_params self, sampling_params=sampling_params
...@@ -580,7 +580,7 @@ class StreamExecutor: ...@@ -580,7 +580,7 @@ class StreamExecutor:
def _execute_role_end(self, expr: SglRoleEnd): def _execute_role_end(self, expr: SglRoleEnd):
if ( if (
self.cur_role == "assistant" self.cur_role == "assistant"
and self.api_num_spec_tokens is not None and self.num_api_spec_tokens is not None
and self.backend.is_chat_model and self.backend.is_chat_model
): ):
# Execute the stored lazy generation calls # Execute the stored lazy generation calls
......
...@@ -97,9 +97,9 @@ class SglSamplingParams: ...@@ -97,9 +97,9 @@ class SglSamplingParams:
class SglFunction: class SglFunction:
def __init__(self, func, api_num_spec_tokens=None, bind_arguments=None): def __init__(self, func, num_api_spec_tokens=None, bind_arguments=None):
self.func = func self.func = func
self.api_num_spec_tokens = api_num_spec_tokens self.num_api_spec_tokens = num_api_spec_tokens
self.bind_arguments = bind_arguments or {} self.bind_arguments = bind_arguments or {}
self.pin_prefix_rid = None self.pin_prefix_rid = None
......
...@@ -353,7 +353,7 @@ def test_regex(): ...@@ -353,7 +353,7 @@ def test_regex():
def test_completion_speculative(): def test_completion_speculative():
@sgl.function(api_num_spec_tokens=64) @sgl.function(num_api_spec_tokens=64)
def gen_character_spec(s): def gen_character_spec(s):
s += "Construct a character within the following format:\n" s += "Construct a character within the following format:\n"
s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n" s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
...@@ -384,7 +384,7 @@ def test_completion_speculative(): ...@@ -384,7 +384,7 @@ def test_completion_speculative():
def test_chat_completion_speculative(): def test_chat_completion_speculative():
@sgl.function(api_num_spec_tokens=256) @sgl.function(num_api_spec_tokens=256)
def gen_character_spec(s): def gen_character_spec(s):
s += sgl.system("You are a helpful assistant.") s += sgl.system("You are a helpful assistant.")
s += sgl.user("Construct a character within the following format:") s += sgl.user("Construct a character within the following format:")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment