Unverified Commit 73cf6834 authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Support `stop_token_ids` in sglang API (#1092)

parent 1c2b5f52
...@@ -62,6 +62,7 @@ def gen( ...@@ -62,6 +62,7 @@ def gen(
name: Optional[str] = None, name: Optional[str] = None,
max_tokens: Optional[int] = None, max_tokens: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None, stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
temperature: Optional[float] = None, temperature: Optional[float] = None,
top_p: Optional[float] = None, top_p: Optional[float] = None,
top_k: Optional[int] = None, top_k: Optional[int] = None,
...@@ -98,6 +99,7 @@ def gen( ...@@ -98,6 +99,7 @@ def gen(
name, name,
max_tokens, max_tokens,
stop, stop,
stop_token_ids,
temperature, temperature,
top_p, top_p,
top_k, top_k,
...@@ -117,6 +119,7 @@ def gen_int( ...@@ -117,6 +119,7 @@ def gen_int(
name: Optional[str] = None, name: Optional[str] = None,
max_tokens: Optional[int] = None, max_tokens: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None, stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
temperature: Optional[float] = None, temperature: Optional[float] = None,
top_p: Optional[float] = None, top_p: Optional[float] = None,
top_k: Optional[int] = None, top_k: Optional[int] = None,
...@@ -132,6 +135,7 @@ def gen_int( ...@@ -132,6 +135,7 @@ def gen_int(
name, name,
max_tokens, max_tokens,
stop, stop,
stop_token_ids,
temperature, temperature,
top_p, top_p,
top_k, top_k,
...@@ -151,6 +155,7 @@ def gen_string( ...@@ -151,6 +155,7 @@ def gen_string(
name: Optional[str] = None, name: Optional[str] = None,
max_tokens: Optional[int] = None, max_tokens: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None, stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
temperature: Optional[float] = None, temperature: Optional[float] = None,
top_p: Optional[float] = None, top_p: Optional[float] = None,
top_k: Optional[int] = None, top_k: Optional[int] = None,
...@@ -166,6 +171,7 @@ def gen_string( ...@@ -166,6 +171,7 @@ def gen_string(
name, name,
max_tokens, max_tokens,
stop, stop,
stop_token_ids,
temperature, temperature,
top_p, top_p,
top_k, top_k,
......
...@@ -20,7 +20,6 @@ from sglang.lang.ir import ( ...@@ -20,7 +20,6 @@ from sglang.lang.ir import (
SglConstantText, SglConstantText,
SglExpr, SglExpr,
SglExprList, SglExprList,
SglFunction,
SglGen, SglGen,
SglImage, SglImage,
SglRoleBegin, SglRoleBegin,
...@@ -181,8 +180,10 @@ class StreamExecutor: ...@@ -181,8 +180,10 @@ class StreamExecutor:
num_api_spec_tokens=None, num_api_spec_tokens=None,
use_thread=True, use_thread=True,
): ):
from sglang.lang.backend.base_backend import BaseBackend
self.sid = uuid.uuid4().hex self.sid = uuid.uuid4().hex
self.backend = backend self.backend: BaseBackend = backend
self.arguments: Dict[str, Any] = arguments self.arguments: Dict[str, Any] = arguments
self.default_sampling_para = default_sampling_para self.default_sampling_para = default_sampling_para
self.stream = stream self.stream = stream
...@@ -658,6 +659,7 @@ class StreamExecutor: ...@@ -658,6 +659,7 @@ class StreamExecutor:
for item in [ for item in [
"max_new_tokens", "max_new_tokens",
"stop", "stop",
"stop_token_ids",
"temperature", "temperature",
"top_p", "top_p",
"top_k", "top_k",
......
...@@ -18,6 +18,7 @@ REGEX_STR = r"\"[\w\d\s]*\"" # bugs with regex r"\".*\"" in interegular pkg ...@@ -18,6 +18,7 @@ REGEX_STR = r"\"[\w\d\s]*\"" # bugs with regex r"\".*\"" in interegular pkg
class SglSamplingParams: class SglSamplingParams:
max_new_tokens: int = 128 max_new_tokens: int = 128
stop: Union[str, List[str]] = () stop: Union[str, List[str]] = ()
stop_token_ids: Optional[List[int]] = ()
temperature: float = 1.0 temperature: float = 1.0
top_p: float = 1.0 top_p: float = 1.0
top_k: int = -1 # -1 means disable top_k: int = -1 # -1 means disable
...@@ -37,6 +38,7 @@ class SglSamplingParams: ...@@ -37,6 +38,7 @@ class SglSamplingParams:
return SglSamplingParams( return SglSamplingParams(
self.max_new_tokens, self.max_new_tokens,
self.stop, self.stop,
self.stop_token_ids,
self.temperature, self.temperature,
self.top_p, self.top_p,
self.top_k, self.top_k,
...@@ -108,6 +110,7 @@ class SglSamplingParams: ...@@ -108,6 +110,7 @@ class SglSamplingParams:
return { return {
"max_new_tokens": self.max_new_tokens, "max_new_tokens": self.max_new_tokens,
"stop": self.stop, "stop": self.stop,
"stop_token_ids": self.stop_token_ids,
"temperature": self.temperature, "temperature": self.temperature,
"top_p": self.top_p, "top_p": self.top_p,
"top_k": self.top_k, "top_k": self.top_k,
...@@ -141,7 +144,8 @@ class SglFunction: ...@@ -141,7 +144,8 @@ class SglFunction:
self, self,
*args, *args,
max_new_tokens: int = 128, max_new_tokens: int = 128,
stop: Union[str, List[str]] = (), stop: Union[str, List[str]] = [],
stop_token_ids: Optional[List[int]] = [],
temperature: float = 1.0, temperature: float = 1.0,
top_p: float = 1.0, top_p: float = 1.0,
top_k: int = -1, top_k: int = -1,
...@@ -161,6 +165,7 @@ class SglFunction: ...@@ -161,6 +165,7 @@ class SglFunction:
default_sampling_para = SglSamplingParams( default_sampling_para = SglSamplingParams(
max_new_tokens=max_new_tokens, max_new_tokens=max_new_tokens,
stop=stop, stop=stop,
stop_token_ids=stop_token_ids,
temperature=temperature, temperature=temperature,
top_p=top_p, top_p=top_p,
top_k=top_k, top_k=top_k,
...@@ -181,6 +186,7 @@ class SglFunction: ...@@ -181,6 +186,7 @@ class SglFunction:
*, *,
max_new_tokens: int = 128, max_new_tokens: int = 128,
stop: Union[str, List[str]] = (), stop: Union[str, List[str]] = (),
stop_token_ids: Optional[List[int]] = [],
temperature: float = 1.0, temperature: float = 1.0,
top_p: float = 1.0, top_p: float = 1.0,
top_k: int = -1, top_k: int = -1,
...@@ -218,6 +224,7 @@ class SglFunction: ...@@ -218,6 +224,7 @@ class SglFunction:
default_sampling_para = SglSamplingParams( default_sampling_para = SglSamplingParams(
max_new_tokens=max_new_tokens, max_new_tokens=max_new_tokens,
stop=stop, stop=stop,
stop_token_ids=stop_token_ids,
temperature=temperature, temperature=temperature,
top_p=top_p, top_p=top_p,
top_k=top_k, top_k=top_k,
...@@ -397,6 +404,7 @@ class SglGen(SglExpr): ...@@ -397,6 +404,7 @@ class SglGen(SglExpr):
name: Optional[str] = None, name: Optional[str] = None,
max_new_tokens: Optional[int] = None, max_new_tokens: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None, stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
temperature: Optional[float] = None, temperature: Optional[float] = None,
top_p: Optional[float] = None, top_p: Optional[float] = None,
top_k: Optional[int] = None, top_k: Optional[int] = None,
...@@ -416,6 +424,7 @@ class SglGen(SglExpr): ...@@ -416,6 +424,7 @@ class SglGen(SglExpr):
self.sampling_params = SglSamplingParams( self.sampling_params = SglSamplingParams(
max_new_tokens=max_new_tokens, max_new_tokens=max_new_tokens,
stop=stop, stop=stop,
stop_token_ids=stop_token_ids,
temperature=temperature, temperature=temperature,
top_p=top_p, top_p=top_p,
top_k=top_k, top_k=top_k,
......
...@@ -235,10 +235,12 @@ class Req: ...@@ -235,10 +235,12 @@ class Req:
return return
last_token_id = self.output_ids[-1] last_token_id = self.output_ids[-1]
if self.tokenizer is None:
matched_eos = last_token_id in self.sampling_params.stop_token_ids matched_eos = last_token_id in self.sampling_params.stop_token_ids
else:
matched_eos = last_token_id == self.tokenizer.eos_token_id if self.tokenizer is not None:
matched_eos |= last_token_id == self.tokenizer.eos_token_id
if matched_eos and not self.sampling_params.ignore_eos: if matched_eos and not self.sampling_params.ignore_eos:
self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id) self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
return return
......
...@@ -106,13 +106,16 @@ def test_decode_json_regex(): ...@@ -106,13 +106,16 @@ def test_decode_json_regex():
from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STR from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STR
s += "Generate a JSON object to describe the basic city information of Paris.\n" s += "Generate a JSON object to describe the basic city information of Paris.\n"
s += "Here are the JSON object:\n"
# NOTE: we recommend using dtype gen or whole regex string to control the output
with s.var_scope("json_output"): with s.var_scope("json_output"):
s += "{\n" s += "{\n"
s += ' "name": ' + sgl.gen(regex=REGEX_STR + ",") + "\n" s += ' "name": ' + sgl.gen(regex=REGEX_STR) + ",\n"
s += ' "population": ' + sgl.gen(regex=REGEX_INT + ",") + "\n" s += ' "population": ' + sgl.gen(regex=REGEX_INT, stop=[" ", "\n"]) + ",\n"
s += ' "area": ' + sgl.gen(regex=REGEX_INT + ",") + "\n" s += ' "area": ' + sgl.gen(regex=REGEX_INT, stop=[" ", "\n"]) + ",\n"
s += ' "latitude": ' + sgl.gen(regex=REGEX_FLOAT) + "\n" s += ' "latitude": ' + sgl.gen(regex=REGEX_FLOAT, stop=[" ", "\n"]) + "\n"
s += "}" s += "}"
ret = decode_json.run(temperature=0.0) ret = decode_json.run(temperature=0.0)
......
...@@ -84,7 +84,7 @@ class TestServingThroughput(unittest.TestCase): ...@@ -84,7 +84,7 @@ class TestServingThroughput(unittest.TestCase):
if os.getenv("SGLANG_IS_IN_CI", "false") == "true": if os.getenv("SGLANG_IS_IN_CI", "false") == "true":
# A100 (PCIE) performance # A100 (PCIE) performance
assert res["output_throughput"] > 940 assert res["output_throughput"] > 930
def test_default_with_chunked_prefill(self): def test_default_with_chunked_prefill(self):
res = self.run_test( res = self.run_test(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment