"src/vscode:/vscode.git/clone" did not exist on "0a0fe69aa6a11b7723e83ca9e049e6096839ad4d"
Unverified Commit 47c606d3 authored by Glen Liu's avatar Glen Liu Committed by GitHub
Browse files

[Feature] support regex strings as a stopping condition (#10635)

parent 9fcf7306
...@@ -49,6 +49,7 @@ python -m sglang.launch_server --model-path <MODEL> --sampling-defaults openai ...@@ -49,6 +49,7 @@ python -m sglang.launch_server --model-path <MODEL> --sampling-defaults openai
| max_new_tokens | `int = 128` | The maximum output length measured in tokens. | | max_new_tokens | `int = 128` | The maximum output length measured in tokens. |
| stop | `Optional[Union[str, List[str]]] = None` | One or multiple [stop words](https://platform.openai.com/docs/api-reference/chat/create#chat-create-stop). Generation will stop if one of these words is sampled. | | stop | `Optional[Union[str, List[str]]] = None` | One or multiple [stop words](https://platform.openai.com/docs/api-reference/chat/create#chat-create-stop). Generation will stop if one of these words is sampled. |
| stop_token_ids | `Optional[List[int]] = None` | Provide stop words in the form of token IDs. Generation will stop if one of these token IDs is sampled. | | stop_token_ids | `Optional[List[int]] = None` | Provide stop words in the form of token IDs. Generation will stop if one of these token IDs is sampled. |
| stop_regex | `Optional[Union[str, List[str]]] = None` | One or multiple regex patterns. Generation will stop if one of these patterns is matched in the decoded output. |
| temperature | `float (model default; fallback 1.0)` | [Temperature](https://platform.openai.com/docs/api-reference/chat/create#chat-create-temperature) when sampling the next token. `temperature = 0` corresponds to greedy sampling, a higher temperature leads to more diversity. | | temperature | `float (model default; fallback 1.0)` | [Temperature](https://platform.openai.com/docs/api-reference/chat/create#chat-create-temperature) when sampling the next token. `temperature = 0` corresponds to greedy sampling, a higher temperature leads to more diversity. |
| top_p | `float (model default; fallback 1.0)` | [Top-p](https://platform.openai.com/docs/api-reference/chat/create#chat-create-top_p) selects tokens from the smallest sorted set whose cumulative probability exceeds `top_p`. When `top_p = 1`, this reduces to unrestricted sampling from all tokens. | | top_p | `float (model default; fallback 1.0)` | [Top-p](https://platform.openai.com/docs/api-reference/chat/create#chat-create-top_p) selects tokens from the smallest sorted set whose cumulative probability exceeds `top_p`. When `top_p = 1`, this reduces to unrestricted sampling from all tokens. |
| top_k | `int (model default; fallback -1)` | [Top-k](https://developer.nvidia.com/blog/how-to-get-better-outputs-from-your-large-language-model/#predictability_vs_creativity) randomly selects from the `k` highest-probability tokens. | | top_k | `int (model default; fallback -1)` | [Top-k](https://developer.nvidia.com/blog/how-to-get-better-outputs-from-your-large-language-model/#predictability_vs_creativity) randomly selects from the `k` highest-probability tokens. |
......
...@@ -79,6 +79,7 @@ def gen( ...@@ -79,6 +79,7 @@ def gen(
n: Optional[int] = None, n: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None, stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None, stop_token_ids: Optional[List[int]] = None,
stop_regex: Optional[Union[str, List[str]]] = None,
temperature: Optional[float] = None, temperature: Optional[float] = None,
top_p: Optional[float] = None, top_p: Optional[float] = None,
top_k: Optional[int] = None, top_k: Optional[int] = None,
...@@ -120,6 +121,7 @@ def gen( ...@@ -120,6 +121,7 @@ def gen(
n, n,
stop, stop,
stop_token_ids, stop_token_ids,
stop_regex,
temperature, temperature,
top_p, top_p,
top_k, top_k,
...@@ -143,6 +145,7 @@ def gen_int( ...@@ -143,6 +145,7 @@ def gen_int(
n: Optional[int] = None, n: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None, stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None, stop_token_ids: Optional[List[int]] = None,
stop_regex: Optional[Union[str, List[str]]] = None,
temperature: Optional[float] = None, temperature: Optional[float] = None,
top_p: Optional[float] = None, top_p: Optional[float] = None,
top_k: Optional[int] = None, top_k: Optional[int] = None,
...@@ -162,6 +165,7 @@ def gen_int( ...@@ -162,6 +165,7 @@ def gen_int(
n, n,
stop, stop,
stop_token_ids, stop_token_ids,
stop_regex,
temperature, temperature,
top_p, top_p,
top_k, top_k,
...@@ -184,6 +188,7 @@ def gen_string( ...@@ -184,6 +188,7 @@ def gen_string(
n: Optional[int] = None, n: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None, stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None, stop_token_ids: Optional[List[int]] = None,
stop_regex: Optional[Union[str, List[str]]] = None,
temperature: Optional[float] = None, temperature: Optional[float] = None,
top_p: Optional[float] = None, top_p: Optional[float] = None,
top_k: Optional[int] = None, top_k: Optional[int] = None,
...@@ -203,6 +208,7 @@ def gen_string( ...@@ -203,6 +208,7 @@ def gen_string(
n, n,
stop, stop,
stop_token_ids, stop_token_ids,
stop_regex,
temperature, temperature,
top_p, top_p,
top_k, top_k,
......
...@@ -792,6 +792,7 @@ class StreamExecutor: ...@@ -792,6 +792,7 @@ class StreamExecutor:
"n", "n",
"stop", "stop",
"stop_token_ids", "stop_token_ids",
"stop_regex",
"temperature", "temperature",
"top_p", "top_p",
"top_k", "top_k",
......
...@@ -21,6 +21,7 @@ class SglSamplingParams: ...@@ -21,6 +21,7 @@ class SglSamplingParams:
n: int = 1 n: int = 1
stop: Union[str, List[str]] = () stop: Union[str, List[str]] = ()
stop_token_ids: Optional[List[int]] = () stop_token_ids: Optional[List[int]] = ()
stop_regex: Optional[Union[str, List[str]]] = ()
temperature: float = 1.0 temperature: float = 1.0
top_p: float = 1.0 top_p: float = 1.0
top_k: int = -1 # -1 means disable top_k: int = -1 # -1 means disable
...@@ -45,6 +46,7 @@ class SglSamplingParams: ...@@ -45,6 +46,7 @@ class SglSamplingParams:
self.n, self.n,
self.stop, self.stop,
self.stop_token_ids, self.stop_token_ids,
self.stop_regex,
self.temperature, self.temperature,
self.top_p, self.top_p,
self.top_k, self.top_k,
...@@ -123,6 +125,7 @@ class SglSamplingParams: ...@@ -123,6 +125,7 @@ class SglSamplingParams:
"n": self.n, "n": self.n,
"stop": self.stop, "stop": self.stop,
"stop_token_ids": self.stop_token_ids, "stop_token_ids": self.stop_token_ids,
"stop_regex": self.stop_regex,
"temperature": self.temperature, "temperature": self.temperature,
"top_p": self.top_p, "top_p": self.top_p,
"top_k": self.top_k, "top_k": self.top_k,
...@@ -161,6 +164,7 @@ class SglFunction: ...@@ -161,6 +164,7 @@ class SglFunction:
n: int = 1, n: int = 1,
stop: Optional[Union[str, List[str]]] = None, stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None, stop_token_ids: Optional[List[int]] = None,
stop_regex: Optional[Union[str, List[str]]] = None,
temperature: float = 1.0, temperature: float = 1.0,
top_p: float = 1.0, top_p: float = 1.0,
top_k: int = -1, top_k: int = -1,
...@@ -184,12 +188,15 @@ class SglFunction: ...@@ -184,12 +188,15 @@ class SglFunction:
stop = [] stop = []
if stop_token_ids is None: if stop_token_ids is None:
stop_token_ids = [] stop_token_ids = []
if stop_regex is None:
stop_regex = []
default_sampling_para = SglSamplingParams( default_sampling_para = SglSamplingParams(
max_new_tokens=max_new_tokens, max_new_tokens=max_new_tokens,
n=n, n=n,
stop=stop, stop=stop,
stop_token_ids=stop_token_ids, stop_token_ids=stop_token_ids,
stop_regex=stop_regex,
temperature=temperature, temperature=temperature,
top_p=top_p, top_p=top_p,
top_k=top_k, top_k=top_k,
...@@ -221,6 +228,7 @@ class SglFunction: ...@@ -221,6 +228,7 @@ class SglFunction:
n: int = 1, n: int = 1,
stop: Optional[Union[str, List[str]]] = None, stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None, stop_token_ids: Optional[List[int]] = None,
stop_regex: Optional[Union[str, List[str]]] = None,
temperature: float = 1.0, temperature: float = 1.0,
top_p: float = 1.0, top_p: float = 1.0,
top_k: int = -1, top_k: int = -1,
...@@ -243,6 +251,8 @@ class SglFunction: ...@@ -243,6 +251,8 @@ class SglFunction:
stop = [] stop = []
if stop_token_ids is None: if stop_token_ids is None:
stop_token_ids = [] stop_token_ids = []
if stop_regex is None:
stop_regex = []
assert isinstance(batch_kwargs, (list, tuple)) assert isinstance(batch_kwargs, (list, tuple))
if len(batch_kwargs) == 0: if len(batch_kwargs) == 0:
...@@ -267,6 +277,7 @@ class SglFunction: ...@@ -267,6 +277,7 @@ class SglFunction:
n=n, n=n,
stop=stop, stop=stop,
stop_token_ids=stop_token_ids, stop_token_ids=stop_token_ids,
stop_regex=stop_regex,
temperature=temperature, temperature=temperature,
top_p=top_p, top_p=top_p,
top_k=top_k, top_k=top_k,
...@@ -451,6 +462,7 @@ class SglGen(SglExpr): ...@@ -451,6 +462,7 @@ class SglGen(SglExpr):
n: Optional[int] = None, n: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None, stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None, stop_token_ids: Optional[List[int]] = None,
stop_regex: Optional[Union[str, List[str]]] = None,
temperature: Optional[float] = None, temperature: Optional[float] = None,
top_p: Optional[float] = None, top_p: Optional[float] = None,
top_k: Optional[int] = None, top_k: Optional[int] = None,
...@@ -474,6 +486,7 @@ class SglGen(SglExpr): ...@@ -474,6 +486,7 @@ class SglGen(SglExpr):
min_new_tokens=min_new_tokens, min_new_tokens=min_new_tokens,
n=n, n=n,
stop=stop, stop=stop,
stop_regex=stop_regex,
stop_token_ids=stop_token_ids, stop_token_ids=stop_token_ids,
temperature=temperature, temperature=temperature,
top_p=top_p, top_p=top_p,
......
...@@ -221,6 +221,7 @@ class CompletionRequest(BaseModel): ...@@ -221,6 +221,7 @@ class CompletionRequest(BaseModel):
ebnf: Optional[str] = None ebnf: Optional[str] = None
repetition_penalty: float = 1.0 repetition_penalty: float = 1.0
stop_token_ids: Optional[List[int]] = None stop_token_ids: Optional[List[int]] = None
stop_regex: Optional[Union[str, List[str]]] = None
no_stop_trim: bool = False no_stop_trim: bool = False
ignore_eos: bool = False ignore_eos: bool = False
skip_special_tokens: bool = True skip_special_tokens: bool = True
...@@ -474,6 +475,7 @@ class ChatCompletionRequest(BaseModel): ...@@ -474,6 +475,7 @@ class ChatCompletionRequest(BaseModel):
ebnf: Optional[str] = None ebnf: Optional[str] = None
repetition_penalty: Optional[float] = None repetition_penalty: Optional[float] = None
stop_token_ids: Optional[List[int]] = None stop_token_ids: Optional[List[int]] = None
stop_regex: Optional[Union[str, List[str]]] = None
no_stop_trim: bool = False no_stop_trim: bool = False
ignore_eos: bool = False ignore_eos: bool = False
continue_final_message: bool = False continue_final_message: bool = False
...@@ -602,6 +604,7 @@ class ChatCompletionRequest(BaseModel): ...@@ -602,6 +604,7 @@ class ChatCompletionRequest(BaseModel):
"min_new_tokens": self.min_tokens, "min_new_tokens": self.min_tokens,
"stop": stop, "stop": stop,
"stop_token_ids": self.stop_token_ids, "stop_token_ids": self.stop_token_ids,
"stop_regex": self.stop_regex,
"top_p": get_param("top_p"), "top_p": get_param("top_p"),
"top_k": get_param("top_k"), "top_k": get_param("top_k"),
"min_p": get_param("min_p"), "min_p": get_param("min_p"),
......
...@@ -123,6 +123,7 @@ class OpenAIServingCompletion(OpenAIServingBase): ...@@ -123,6 +123,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
"min_new_tokens": request.min_tokens, "min_new_tokens": request.min_tokens,
"stop": request.stop, "stop": request.stop,
"stop_token_ids": request.stop_token_ids, "stop_token_ids": request.stop_token_ids,
"stop_regex": request.stop_regex,
"top_p": request.top_p, "top_p": request.top_p,
"top_k": request.top_k, "top_k": request.top_k,
"min_p": request.min_p, "min_p": request.min_p,
......
...@@ -36,6 +36,7 @@ TODO(lmzheng): ModelWorkerBatch seems a bit redundant and we consider removing i ...@@ -36,6 +36,7 @@ TODO(lmzheng): ModelWorkerBatch seems a bit redundant and we consider removing i
import copy import copy
import dataclasses import dataclasses
import logging import logging
import re
import threading import threading
import time import time
from enum import Enum, auto from enum import Enum, auto
...@@ -154,6 +155,18 @@ class FINISH_MATCHED_STR(BaseFinishReason): ...@@ -154,6 +155,18 @@ class FINISH_MATCHED_STR(BaseFinishReason):
} }
class FINISHED_MATCHED_REGEX(BaseFinishReason):
    """Finish reason: generation stopped because the decoded output matched
    one of the request's stop-regex patterns.
    """

    def __init__(self, matched: str):
        super().__init__()
        # The regex pattern string that matched in the output.
        self.matched = matched

    def to_json(self):
        """Serialize this finish reason for API responses."""
        # "type" is fixed to "stop" to match OpenAI API's return value.
        return dict(type="stop", matched=self.matched)
class FINISH_LENGTH(BaseFinishReason): class FINISH_LENGTH(BaseFinishReason):
def __init__(self, length: int): def __init__(self, length: int):
super().__init__() super().__init__()
...@@ -735,8 +748,17 @@ class Req: ...@@ -735,8 +748,17 @@ class Req:
return self.surr_and_decode_ids, self.read_offset - self.surr_offset return self.surr_and_decode_ids, self.read_offset - self.surr_offset
def tail_str(self) -> str: def tail_str(self) -> str:
tail_len = self.sampling_params.stop_str_max_len + 1 # Check stop strings and stop regex patterns together
tail_len = min(tail_len, len(self.output_ids)) if (
len(self.sampling_params.stop_strs) > 0
or len(self.sampling_params.stop_regex_strs) > 0
):
max_len_tail_str = max(
self.sampling_params.stop_str_max_len + 1,
self.sampling_params.stop_regex_max_len + 1,
)
tail_len = min((max_len_tail_str + 1), len(self.output_ids))
return self.tokenizer.decode(self.output_ids[-tail_len:]) return self.tokenizer.decode(self.output_ids[-tail_len:])
def check_match_stop_str_prefix(self) -> bool: def check_match_stop_str_prefix(self) -> bool:
...@@ -817,14 +839,27 @@ class Req: ...@@ -817,14 +839,27 @@ class Req:
self.finished_reason = FINISH_MATCHED_STR(matched="NaN happened") self.finished_reason = FINISH_MATCHED_STR(matched="NaN happened")
return return
# Check stop strings if (
if len(self.sampling_params.stop_strs) > 0: len(self.sampling_params.stop_strs) > 0
or len(self.sampling_params.stop_regex_strs) > 0
):
tail_str = self.tail_str() tail_str = self.tail_str()
for stop_str in self.sampling_params.stop_strs: # Check stop strings
if stop_str in tail_str or stop_str in self.decoded_text: if len(self.sampling_params.stop_strs) > 0:
self.finished_reason = FINISH_MATCHED_STR(matched=stop_str) for stop_str in self.sampling_params.stop_strs:
return if stop_str in tail_str or stop_str in self.decoded_text:
self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
return
# Check stop regex
if len(self.sampling_params.stop_regex_strs) > 0:
for stop_regex_str in self.sampling_params.stop_regex_strs:
if re.search(stop_regex_str, tail_str):
self.finished_reason = FINISHED_MATCHED_REGEX(
matched=stop_regex_str
)
return
def reset_for_retract(self): def reset_for_retract(self):
self.prefix_indices = torch.empty((0,), dtype=torch.int64) self.prefix_indices = torch.empty((0,), dtype=torch.int64)
......
...@@ -13,6 +13,8 @@ ...@@ -13,6 +13,8 @@
# ============================================================================== # ==============================================================================
"""Sampling parameters for text generation.""" """Sampling parameters for text generation."""
import logging
import sre_parse
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
from sglang.srt.utils import get_bool_env_var from sglang.srt.utils import get_bool_env_var
...@@ -20,6 +22,8 @@ from sglang.srt.utils import get_bool_env_var ...@@ -20,6 +22,8 @@ from sglang.srt.utils import get_bool_env_var
_SAMPLING_EPS = 1e-6 _SAMPLING_EPS = 1e-6
TOP_K_ALL = 1 << 30 TOP_K_ALL = 1 << 30
logger = logging.getLogger(__name__)
class SamplingParams: class SamplingParams:
""" """
...@@ -35,6 +39,7 @@ class SamplingParams: ...@@ -35,6 +39,7 @@ class SamplingParams:
max_new_tokens: int = 128, max_new_tokens: int = 128,
stop: Optional[Union[str, List[str]]] = None, stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None, stop_token_ids: Optional[List[int]] = None,
stop_regex: Optional[Union[str, List[str]]] = None,
temperature: float = 1.0, temperature: float = 1.0,
top_p: float = 1.0, top_p: float = 1.0,
top_k: int = -1, top_k: int = -1,
...@@ -63,6 +68,7 @@ class SamplingParams: ...@@ -63,6 +68,7 @@ class SamplingParams:
self.stop_token_ids = set(stop_token_ids) self.stop_token_ids = set(stop_token_ids)
else: else:
self.stop_token_ids = None self.stop_token_ids = None
self.stop_regex_strs = stop_regex
self.temperature = temperature self.temperature = temperature
self.top_p = top_p self.top_p = top_p
self.top_k = top_k self.top_k = top_k
...@@ -170,3 +176,67 @@ class SamplingParams: ...@@ -170,3 +176,67 @@ class SamplingParams:
else: else:
stop_str_max_len = max(stop_str_max_len, len(stop_str)) stop_str_max_len = max(stop_str_max_len, len(stop_str))
self.stop_str_max_len = stop_str_max_len self.stop_str_max_len = stop_str_max_len
# Process stop regex strings
if self.stop_regex_strs is None:
self.stop_regex_strs = []
self.stop_regex_max_len = 0
else:
if isinstance(self.stop_regex_strs, str):
self.stop_regex_strs = [self.stop_regex_strs]
stop_regex_max_len = 0
for stop_regex in self.stop_regex_strs:
stop_regex_max_len = max(
stop_regex_max_len, get_max_seq_length(stop_regex)
)
self.stop_regex_max_len = stop_regex_max_len
# This function gets a strict upper bound on the maximum number of tokens that
# would need to be buffered to match the input regex string.
# NOTE: in the worst case, one character that needs to be buffered corresponds
# to one token.
def get_max_seq_length(regex_str: str) -> int:
    """Return an upper bound on the number of characters a match of
    ``regex_str`` can span.

    Used to size the tail of decoded output that must be kept around to
    detect a stop-regex match. Patterns with unbounded repetition
    (e.g. ``a*``, ``a+``, ``a{2,}``) return at least ``MAX_LEN``.
    """
    return _max_length_from_subpattern(sre_parse.parse(regex_str))


# Sentinel "effectively infinite" length used for unbounded repetition and
# for regex constructs we do not model (conservative over-estimate).
MAX_LEN = 2**30


def _max_length_from_subpattern(subpattern: sre_parse.SubPattern) -> int:
    """Recursively compute the maximum match length of a parsed subpattern.

    Walks the ``sre_parse`` token stream and sums per-token maxima; any
    token kind we do not recognize is counted as ``MAX_LEN`` so the result
    stays a valid upper bound.
    """
    total = 0
    for token, value in subpattern:
        if token in {
            sre_parse.LITERAL,  # `value` is any one character
            sre_parse.NOT_LITERAL,  # any one character except `value`, e.g. [^a]
            sre_parse.IN,  # Any character within `value`
            sre_parse.ANY,  # "."
        }:
            total += 1
        elif token == sre_parse.SUBPATTERN:
            # EG: (a\d+) ->
            # [(SUBPATTERN,
            #   (1, 0, 0, [(LITERAL, 97),
            #     (MAX_REPEAT, (1, MAXREPEAT, [(IN, [(CATEGORY, CATEGORY_DIGIT)])]))]))]
            _, _, _, inner_subpattern = value
            total += _max_length_from_subpattern(inner_subpattern)
        elif token == sre_parse.BRANCH:
            # Alternation: the longest branch dominates.
            _, branches = value
            total += max(_max_length_from_subpattern(branch) for branch in branches)
        elif token in {sre_parse.MAX_REPEAT, sre_parse.MIN_REPEAT}:
            _, max_num_repeat, inner_subpattern = value
            if max_num_repeat == sre_parse.MAXREPEAT:
                # Unbounded repetition ("*", "+", "{n,}") — even lazy
                # quantifiers ("*?") have no finite bound.
                total += MAX_LEN
            else:
                total += max_num_repeat * _max_length_from_subpattern(inner_subpattern)
        elif token == sre_parse.AT:
            # These are zero-width assertions like ^, $, and \b that don't add
            # to the max length.
            total += 0
        else:
            # Unknown construct: warn and assume an unbounded match so the
            # returned value remains a strict upper bound.
            logger.warning(f"Got unhandled regex token: {token}")
            total += MAX_LEN
    return total
...@@ -3,6 +3,7 @@ import unittest ...@@ -3,6 +3,7 @@ import unittest
import requests import requests
from sglang.srt.sampling.sampling_params import MAX_LEN, get_max_seq_length
from sglang.srt.utils import kill_process_tree from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST,
...@@ -40,6 +41,7 @@ class TestMatchedStop(CustomTestCase): ...@@ -40,6 +41,7 @@ class TestMatchedStop(CustomTestCase):
prompt=MANY_NEW_TOKENS_PROMPT, prompt=MANY_NEW_TOKENS_PROMPT,
max_tokens=1, max_tokens=1,
stop=None, stop=None,
stop_regex=None,
finish_reason=None, finish_reason=None,
matched_stop=None, matched_stop=None,
): ):
...@@ -54,6 +56,9 @@ class TestMatchedStop(CustomTestCase): ...@@ -54,6 +56,9 @@ class TestMatchedStop(CustomTestCase):
if stop is not None: if stop is not None:
payload["stop"] = stop payload["stop"] = stop
if stop_regex is not None:
payload["stop_regex"] = stop_regex
response_completions = requests.post( response_completions = requests.post(
self.base_url + "/v1/completions", self.base_url + "/v1/completions",
json=payload, json=payload,
...@@ -71,6 +76,7 @@ class TestMatchedStop(CustomTestCase): ...@@ -71,6 +76,7 @@ class TestMatchedStop(CustomTestCase):
prompt=MANY_NEW_TOKENS_PROMPT, prompt=MANY_NEW_TOKENS_PROMPT,
max_tokens=1, max_tokens=1,
stop=None, stop=None,
stop_regex=None,
finish_reason=None, finish_reason=None,
matched_stop=None, matched_stop=None,
): ):
...@@ -88,6 +94,9 @@ class TestMatchedStop(CustomTestCase): ...@@ -88,6 +94,9 @@ class TestMatchedStop(CustomTestCase):
if stop is not None: if stop is not None:
chat_payload["stop"] = stop chat_payload["stop"] = stop
if stop_regex is not None:
chat_payload["stop_regex"] = stop_regex
response_chat = requests.post( response_chat = requests.post(
self.base_url + "/v1/chat/completions", self.base_url + "/v1/chat/completions",
json=chat_payload, json=chat_payload,
...@@ -106,6 +115,30 @@ class TestMatchedStop(CustomTestCase): ...@@ -106,6 +115,30 @@ class TestMatchedStop(CustomTestCase):
max_tokens=1000, stop="\n", finish_reason="stop", matched_stop="\n" max_tokens=1000, stop="\n", finish_reason="stop", matched_stop="\n"
) )
def test_finish_stop_regex_str(self):
    """Generation finishes with reason "stop" when the output matches a
    stop_regex pattern, on both the completions and chat endpoints."""
    # Alternation pattern: stop on either connective word.
    word_pattern = r"and|or"
    self.run_completions_generation(
        max_tokens=1000,
        stop_regex=word_pattern,
        finish_reason="stop",
        matched_stop=word_pattern,
    )
    self.run_chat_completions_generation(
        max_tokens=1000,
        stop_regex=word_pattern,
        finish_reason="stop",
        matched_stop=word_pattern,
    )

    # Match a complete sentence: stop once sentence-ending punctuation
    # appears at the end of the output.
    sentence_pattern = r"[.!?]\s*$"
    self.run_chat_completions_generation(
        max_tokens=1000,
        stop_regex=sentence_pattern,
        finish_reason="stop",
        matched_stop=sentence_pattern,
    )
def test_finish_stop_eos(self): def test_finish_stop_eos(self):
llama_format_prompt = """ llama_format_prompt = """
<|begin_of_text|><|start_header_id|>system<|end_header_id|> <|begin_of_text|><|start_header_id|>system<|end_header_id|>
...@@ -136,5 +169,53 @@ class TestMatchedStop(CustomTestCase): ...@@ -136,5 +169,53 @@ class TestMatchedStop(CustomTestCase):
) )
class TestRegexPatternMaxLength(unittest.TestCase):
    """Unit tests for get_max_seq_length's upper bound on regex match length."""

    @classmethod
    def setUpClass(cls):
        # Each regex pattern maps to the expected maximum number of characters
        # a match can span; MAX_LEN marks patterns with unbounded repetition.
        cls.regex_str_to_max_len = {
            # '*' → an unbounded number of characters may need buffering.
            "((ab|cd(e|f){2}){3,5}g|hij)*k": MAX_LEN,
            # '*?' → still unbounded even with a lazy quantifier.
            "abc*?k": MAX_LEN,
            # '^'/'$' are zero-width: "spec" (4) + max("foo"=3, "at"=2) = 7.
            "^spec(foo|at)$": 7,
            # (fg|hi) → 2; {2,3} → 6; "de"+6 = 8; max("bca"=3, 8) = 8;
            # "a"+8+"j" = 10; {2} → 20; + "kl" = 22.
            "(a(bca|de(fg|hi){2,3})j){2}kl": 22,
            # Branch 1: "foo"(3) + max("bar"=3, "baz"+"qux"{2} = 3+6 = 9) = 12.
            # Branch 2: "x"(1) + "yz"{10} = 21. Whole regex = max(12, 21) = 21.
            "(foo(bar|baz(qux){1,2}))|(x(yz){5,10})": 21,
            # Branch A: (a|bc){1,3} = 3*2 = 6; max(d(e|f){2} = 3, "gh" = 2) = 3;
            #   {2,4} → 12; total = 18.
            # Branch B: max("ijk" = 3, "lmp"+(no|p){3} = 3+6 = 9) = 9.
            # max(18, 9) = 18; outer {5} → 90.
            "(((a|bc){1,3}(d(e|f){2}|gh){2,4})|(ijk|lmp(no|p){3})){5}": 90,
        }

    def test_get_max_length(self):
        for pattern, expected in self.regex_str_to_max_len.items():
            actual = get_max_seq_length(pattern)
            if expected == MAX_LEN:
                # Unbounded patterns only guarantee "at least MAX_LEN".
                self.assertGreaterEqual(actual, MAX_LEN)
            else:
                self.assertEqual(actual, expected)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment