Unverified Commit 47c606d3 authored by Glen Liu's avatar Glen Liu Committed by GitHub
Browse files

[Feature] support regex strings as a stopping condition (#10635)

parent 9fcf7306
......@@ -49,6 +49,7 @@ python -m sglang.launch_server --model-path <MODEL> --sampling-defaults openai
| max_new_tokens | `int = 128` | The maximum output length measured in tokens. |
| stop | `Optional[Union[str, List[str]]] = None` | One or multiple [stop words](https://platform.openai.com/docs/api-reference/chat/create#chat-create-stop). Generation will stop if one of these words is sampled. |
| stop_token_ids | `Optional[List[int]] = None` | Provide stop words in the form of token IDs. Generation will stop if one of these token IDs is sampled. |
| stop_regex | `Optional[Union[str, List[str]]] = None` | One or multiple regex patterns. Generation will stop as soon as the decoded output matches any of them. |
| temperature | `float (model default; fallback 1.0)` | [Temperature](https://platform.openai.com/docs/api-reference/chat/create#chat-create-temperature) when sampling the next token. `temperature = 0` corresponds to greedy sampling, a higher temperature leads to more diversity. |
| top_p | `float (model default; fallback 1.0)` | [Top-p](https://platform.openai.com/docs/api-reference/chat/create#chat-create-top_p) selects tokens from the smallest sorted set whose cumulative probability exceeds `top_p`. When `top_p = 1`, this reduces to unrestricted sampling from all tokens. |
| top_k | `int (model default; fallback -1)` | [Top-k](https://developer.nvidia.com/blog/how-to-get-better-outputs-from-your-large-language-model/#predictability_vs_creativity) randomly selects from the `k` highest-probability tokens. |
......
......@@ -79,6 +79,7 @@ def gen(
n: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
stop_regex: Optional[Union[str, List[str]]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
......@@ -120,6 +121,7 @@ def gen(
n,
stop,
stop_token_ids,
stop_regex,
temperature,
top_p,
top_k,
......@@ -143,6 +145,7 @@ def gen_int(
n: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
stop_regex: Optional[Union[str, List[str]]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
......@@ -162,6 +165,7 @@ def gen_int(
n,
stop,
stop_token_ids,
stop_regex,
temperature,
top_p,
top_k,
......@@ -184,6 +188,7 @@ def gen_string(
n: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
stop_regex: Optional[Union[str, List[str]]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
......@@ -203,6 +208,7 @@ def gen_string(
n,
stop,
stop_token_ids,
stop_regex,
temperature,
top_p,
top_k,
......
......@@ -792,6 +792,7 @@ class StreamExecutor:
"n",
"stop",
"stop_token_ids",
"stop_regex",
"temperature",
"top_p",
"top_k",
......
......@@ -21,6 +21,7 @@ class SglSamplingParams:
n: int = 1
stop: Union[str, List[str]] = ()
stop_token_ids: Optional[List[int]] = ()
stop_regex: Optional[Union[str, List[str]]] = ()
temperature: float = 1.0
top_p: float = 1.0
top_k: int = -1 # -1 means disable
......@@ -45,6 +46,7 @@ class SglSamplingParams:
self.n,
self.stop,
self.stop_token_ids,
self.stop_regex,
self.temperature,
self.top_p,
self.top_k,
......@@ -123,6 +125,7 @@ class SglSamplingParams:
"n": self.n,
"stop": self.stop,
"stop_token_ids": self.stop_token_ids,
"stop_regex": self.stop_regex,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,
......@@ -161,6 +164,7 @@ class SglFunction:
n: int = 1,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
stop_regex: Optional[Union[str, List[str]]] = None,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
......@@ -184,12 +188,15 @@ class SglFunction:
stop = []
if stop_token_ids is None:
stop_token_ids = []
if stop_regex is None:
stop_regex = []
default_sampling_para = SglSamplingParams(
max_new_tokens=max_new_tokens,
n=n,
stop=stop,
stop_token_ids=stop_token_ids,
stop_regex=stop_regex,
temperature=temperature,
top_p=top_p,
top_k=top_k,
......@@ -221,6 +228,7 @@ class SglFunction:
n: int = 1,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
stop_regex: Optional[Union[str, List[str]]] = None,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
......@@ -243,6 +251,8 @@ class SglFunction:
stop = []
if stop_token_ids is None:
stop_token_ids = []
if stop_regex is None:
stop_regex = []
assert isinstance(batch_kwargs, (list, tuple))
if len(batch_kwargs) == 0:
......@@ -267,6 +277,7 @@ class SglFunction:
n=n,
stop=stop,
stop_token_ids=stop_token_ids,
stop_regex=stop_regex,
temperature=temperature,
top_p=top_p,
top_k=top_k,
......@@ -451,6 +462,7 @@ class SglGen(SglExpr):
n: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
stop_regex: Optional[Union[str, List[str]]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
......@@ -474,6 +486,7 @@ class SglGen(SglExpr):
min_new_tokens=min_new_tokens,
n=n,
stop=stop,
stop_regex=stop_regex,
stop_token_ids=stop_token_ids,
temperature=temperature,
top_p=top_p,
......
......@@ -221,6 +221,7 @@ class CompletionRequest(BaseModel):
ebnf: Optional[str] = None
repetition_penalty: float = 1.0
stop_token_ids: Optional[List[int]] = None
stop_regex: Optional[Union[str, List[str]]] = None
no_stop_trim: bool = False
ignore_eos: bool = False
skip_special_tokens: bool = True
......@@ -474,6 +475,7 @@ class ChatCompletionRequest(BaseModel):
ebnf: Optional[str] = None
repetition_penalty: Optional[float] = None
stop_token_ids: Optional[List[int]] = None
stop_regex: Optional[Union[str, List[str]]] = None
no_stop_trim: bool = False
ignore_eos: bool = False
continue_final_message: bool = False
......@@ -602,6 +604,7 @@ class ChatCompletionRequest(BaseModel):
"min_new_tokens": self.min_tokens,
"stop": stop,
"stop_token_ids": self.stop_token_ids,
"stop_regex": self.stop_regex,
"top_p": get_param("top_p"),
"top_k": get_param("top_k"),
"min_p": get_param("min_p"),
......
......@@ -123,6 +123,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
"min_new_tokens": request.min_tokens,
"stop": request.stop,
"stop_token_ids": request.stop_token_ids,
"stop_regex": request.stop_regex,
"top_p": request.top_p,
"top_k": request.top_k,
"min_p": request.min_p,
......
......@@ -36,6 +36,7 @@ TODO(lmzheng): ModelWorkerBatch seems a bit redundant and we consider removing i
import copy
import dataclasses
import logging
import re
import threading
import time
from enum import Enum, auto
......@@ -154,6 +155,18 @@ class FINISH_MATCHED_STR(BaseFinishReason):
}
class FINISHED_MATCHED_REGEX(BaseFinishReason):
    """Finish reason: the decoded output matched a stop regex pattern.

    NOTE(review): the name breaks the sibling convention (``FINISH_MATCHED_STR``,
    ``FINISH_LENGTH`` use the ``FINISH_`` prefix); renaming would touch all
    call sites, so it is only flagged here.
    """

    def __init__(self, matched: str):
        super().__init__()
        # The regex pattern string that triggered the stop (callers pass the
        # pattern itself, not the matched text).
        self.matched = matched

    def to_json(self):
        """Serialize in the OpenAI-compatible finish-reason shape."""
        # "stop" mirrors the OpenAI API's finish_reason value.
        return {"type": "stop", "matched": self.matched}
class FINISH_LENGTH(BaseFinishReason):
def __init__(self, length: int):
super().__init__()
......@@ -735,8 +748,17 @@ class Req:
return self.surr_and_decode_ids, self.read_offset - self.surr_offset
def tail_str(self) -> str:
    """Decode the tail window of output tokens used for stop matching.

    The window size is an upper bound on how many tokens a stop-string or
    stop-regex match could span (worst case: one buffered character per
    token, per the note in SamplingParams).
    """
    params = self.sampling_params
    # Default window: one token past the longest stop string.
    tail_len = min(params.stop_str_max_len + 1, len(self.output_ids))
    # Widen the window when stop strings and/or stop regexes are active,
    # taking whichever family needs the longer lookback.
    if params.stop_strs or params.stop_regex_strs:
        max_len_tail_str = max(
            params.stop_str_max_len + 1,
            params.stop_regex_max_len + 1,
        )
        # NOTE(review): both terms above already carry a +1 slack; the extra
        # +1 here widens the window by one more token. Presumably an
        # intentional safety margin — confirm with the author.
        tail_len = min(max_len_tail_str + 1, len(self.output_ids))
    return self.tokenizer.decode(self.output_ids[-tail_len:])
def check_match_stop_str_prefix(self) -> bool:
......@@ -817,14 +839,27 @@ class Req:
self.finished_reason = FINISH_MATCHED_STR(matched="NaN happened")
return
# Check stop strings
if len(self.sampling_params.stop_strs) > 0:
if (
len(self.sampling_params.stop_strs) > 0
or len(self.sampling_params.stop_regex_strs) > 0
):
tail_str = self.tail_str()
for stop_str in self.sampling_params.stop_strs:
if stop_str in tail_str or stop_str in self.decoded_text:
self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
return
# Check stop strings
if len(self.sampling_params.stop_strs) > 0:
for stop_str in self.sampling_params.stop_strs:
if stop_str in tail_str or stop_str in self.decoded_text:
self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
return
# Check stop regex
if len(self.sampling_params.stop_regex_strs) > 0:
for stop_regex_str in self.sampling_params.stop_regex_strs:
if re.search(stop_regex_str, tail_str):
self.finished_reason = FINISHED_MATCHED_REGEX(
matched=stop_regex_str
)
return
def reset_for_retract(self):
self.prefix_indices = torch.empty((0,), dtype=torch.int64)
......
......@@ -13,6 +13,8 @@
# ==============================================================================
"""Sampling parameters for text generation."""
import logging
import sre_parse
from typing import Any, Dict, List, Optional, Union
from sglang.srt.utils import get_bool_env_var
......@@ -20,6 +22,8 @@ from sglang.srt.utils import get_bool_env_var
_SAMPLING_EPS = 1e-6
TOP_K_ALL = 1 << 30
logger = logging.getLogger(__name__)
class SamplingParams:
"""
......@@ -35,6 +39,7 @@ class SamplingParams:
max_new_tokens: int = 128,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
stop_regex: Optional[Union[str, List[str]]] = None,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
......@@ -63,6 +68,7 @@ class SamplingParams:
self.stop_token_ids = set(stop_token_ids)
else:
self.stop_token_ids = None
self.stop_regex_strs = stop_regex
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
......@@ -170,3 +176,67 @@ class SamplingParams:
else:
stop_str_max_len = max(stop_str_max_len, len(stop_str))
self.stop_str_max_len = stop_str_max_len
# Process stop regex strings
if self.stop_regex_strs is None:
self.stop_regex_strs = []
self.stop_regex_max_len = 0
else:
if isinstance(self.stop_regex_strs, str):
self.stop_regex_strs = [self.stop_regex_strs]
stop_regex_max_len = 0
for stop_regex in self.stop_regex_strs:
stop_regex_max_len = max(
stop_regex_max_len, get_max_seq_length(stop_regex)
)
self.stop_regex_max_len = stop_regex_max_len
# This function gets a strict upperbound on the maximum number of tokens that would need
# to be buffered to match the input regex string
# NOTE: in the worst case, one character that needs to be buffered corresponds to one
# token
def get_max_seq_length(regex_str: str):
    """Return an upper bound on how many characters a match of *regex_str* can span."""
    return _max_length_from_subpattern(sre_parse.parse(regex_str))


# Sentinel bound for effectively-unbounded constructs ("*", "+", "{n,}", unknown tokens).
MAX_LEN = 2**30


def _max_length_from_subpattern(subpattern: sre_parse.SubPattern):
    """Recursively compute the maximum match length of a parsed regex subpattern."""
    total = 0
    for token, value in subpattern:
        if token in (sre_parse.LITERAL, sre_parse.IN, sre_parse.ANY):
            # A single literal, a character class, or "." — exactly one character.
            total += 1
        elif token == sre_parse.SUBPATTERN:
            # value = (group, add_flags, del_flags, inner); e.g. (a\d+) wraps the
            # parsed body of the group in the last element.
            total += _max_length_from_subpattern(value[3])
        elif token == sre_parse.BRANCH:
            # value = (None, branches); the longest alternative dominates.
            total += max(_max_length_from_subpattern(b) for b in value[1])
        elif token in (sre_parse.MAX_REPEAT, sre_parse.MIN_REPEAT):
            _, max_num_repeat, inner = value
            if max_num_repeat == sre_parse.MAXREPEAT:
                # Unbounded repetition ("*", "+", "{n,}"), lazy or greedy —
                # treat as effectively infinite.
                total += MAX_LEN
            else:
                total += max_num_repeat * _max_length_from_subpattern(inner)
        elif token == sre_parse.AT:
            # Zero-width assertions (^, $, \b) consume no characters.
            pass
        else:
            # Unknown construct: log it and assume the worst.
            logger.warning(f"Got unhandled regex token: {token}")
            total += MAX_LEN
    return total
......@@ -3,6 +3,7 @@ import unittest
import requests
from sglang.srt.sampling.sampling_params import MAX_LEN, get_max_seq_length
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
......@@ -40,6 +41,7 @@ class TestMatchedStop(CustomTestCase):
prompt=MANY_NEW_TOKENS_PROMPT,
max_tokens=1,
stop=None,
stop_regex=None,
finish_reason=None,
matched_stop=None,
):
......@@ -54,6 +56,9 @@ class TestMatchedStop(CustomTestCase):
if stop is not None:
payload["stop"] = stop
if stop_regex is not None:
payload["stop_regex"] = stop_regex
response_completions = requests.post(
self.base_url + "/v1/completions",
json=payload,
......@@ -71,6 +76,7 @@ class TestMatchedStop(CustomTestCase):
prompt=MANY_NEW_TOKENS_PROMPT,
max_tokens=1,
stop=None,
stop_regex=None,
finish_reason=None,
matched_stop=None,
):
......@@ -88,6 +94,9 @@ class TestMatchedStop(CustomTestCase):
if stop is not None:
chat_payload["stop"] = stop
if stop_regex is not None:
chat_payload["stop_regex"] = stop_regex
response_chat = requests.post(
self.base_url + "/v1/chat/completions",
json=chat_payload,
......@@ -106,6 +115,30 @@ class TestMatchedStop(CustomTestCase):
max_tokens=1000, stop="\n", finish_reason="stop", matched_stop="\n"
)
def test_finish_stop_regex_str(self):
    """Generation finishes with finish_reason="stop" on a stop-regex match."""
    STOP_REGEX_STR = r"and|or"
    # Exercise both the completions and the chat-completions endpoints with
    # the same alternation pattern.
    for run in (
        self.run_completions_generation,
        self.run_chat_completions_generation,
    ):
        run(
            max_tokens=1000,
            stop_regex=STOP_REGEX_STR,
            finish_reason="stop",
            matched_stop=STOP_REGEX_STR,
        )

    # Match a complete sentence: a terminator at the end of the text.
    STOP_REGEX_STR_SENTENCE = r"[.!?]\s*$"
    self.run_chat_completions_generation(
        max_tokens=1000,
        stop_regex=STOP_REGEX_STR_SENTENCE,
        finish_reason="stop",
        matched_stop=STOP_REGEX_STR_SENTENCE,
    )
)
def test_finish_stop_eos(self):
llama_format_prompt = """
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
......@@ -136,5 +169,53 @@ class TestMatchedStop(CustomTestCase):
)
class TestRegexPatternMaxLength(unittest.TestCase):
    """Unit tests for get_max_seq_length's regex match-length upper bound."""

    @classmethod
    def setUpClass(cls):
        # Pattern -> expected strict upper bound; MAX_LEN marks "unbounded".
        cls.regex_str_to_max_len = {
            # Unbounded '*' repetition: infinitely many tokens may be buffered.
            "((ab|cd(e|f){2}){3,5}g|hij)*k": MAX_LEN,
            # '*?' is still unbounded even though it matches lazily.
            "abc*?k": MAX_LEN,
            # '^'/'$' are zero-width: "spec"(4) + max("foo"=3, "at"=2) = 7.
            "^spec(foo|at)$": 7,
            # fg|hi -> 2; {2,3} -> 3*2 = 6; "de"+6 = 8; max("bca"=3, 8) = 8;
            # "a"(1) + 8 + "j"(1) = 10; {2} -> 20; + "kl"(2) = 22.
            "(a(bca|de(fg|hi){2,3})j){2}kl": 22,
            # Branch 1: "foo"(3) + max("bar"=3, "baz"(3)+"qux"{2}=6 -> 9) = 12.
            # Branch 2: "x"(1) + "yz"{10} = 1 + 20 = 21. max(12, 21) = 21.
            "(foo(bar|baz(qux){1,2}))|(x(yz){5,10})": 21,
            # Branch A: (a|bc){1,3} -> 3*2 = 6;
            #   max(d(e|f){2} = 1+2 = 3, "gh" = 2) = 3; {2,4} -> 4*3 = 12;
            #   total = 6 + 12 = 18.
            # Branch B: max("ijk"=3, "lmp"(3) + (no|p){3} = 3+6 = 9) = 9.
            # max(18, 9) = 18; outer {5} -> 90.
            "(((a|bc){1,3}(d(e|f){2}|gh){2,4})|(ijk|lmp(no|p){3})){5}": 90,
        }

    def test_get_max_length(self):
        for pattern, expected in self.regex_str_to_max_len.items():
            computed = get_max_seq_length(pattern)
            if expected == MAX_LEN:
                # Unbounded patterns must report at least the sentinel value.
                self.assertGreaterEqual(computed, MAX_LEN)
            else:
                self.assertEqual(computed, expected)
# Allow running this test module directly (outside a pytest/CI harness).
if __name__ == "__main__":
    unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment