Unverified Commit aa797d01 authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

[Test] Merge all constrained decoding tests. (#12633)

parent 7cee07a0
"""
python3 -m unittest test_ebnf_constrained.TestEBNFConstrained.test_ebnf_generate_email
python3 -m unittest test_ebnf_constrained.TestEBNFConstrained.test_ebnf_generate_greeting
python3 -m unittest test_ebnf_constrained.TestEBNFConstrained.test_ebnf_generate_all_optional_function_params
python3 -m unittest test_ebnf_constrained.TestEBNFConstrainedLLGuidance.test_ebnf_generate_email
python3 -m unittest test_ebnf_constrained.TestEBNFConstrainedLLGuidance.test_ebnf_generate_greeting
python3 -m unittest test_ebnf_constrained.TestEBNFConstrainedLLGuidance.test_ebnf_generate_all_optional_function_params
"""
import json
import unittest
import requests
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
def setup_class(cls, backend: str, disable_overlap: bool):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.ebnf_grammar = 'root ::= "test"' # Default grammar
other_args = [
"--max-running-requests",
"10",
"--grammar-backend",
backend,
]
if disable_overlap:
other_args += ["--disable-overlap-schedule"]
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=other_args,
)
class TestEBNFConstrainedMinxin:
ebnf_grammar = 'root ::= "test"' # Default grammar
class TestEBNFConstrained(CustomTestCase):
@classmethod
def setUpClass(cls):
setup_class(cls, "xgrammar", disable_overlap=False)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def run_decode(
def _run_decode_ebnf(
self,
ebnf,
expected_patterns,
......@@ -110,7 +62,7 @@ class TestEBNFConstrained(CustomTestCase):
allowed_patterns = [r"^user@example\.com$"]
prompt = "Generate an email address:"
self.run_decode(
self._run_decode_ebnf(
ebnf=self.__class__.ebnf_grammar,
expected_patterns=allowed_patterns,
prompt=prompt,
......@@ -122,7 +74,7 @@ class TestEBNFConstrained(CustomTestCase):
allowed_patterns = [r"^(Hello|Hi|Hey)$"]
prompt = "Generate a greeting:"
self.run_decode(
self._run_decode_ebnf(
ebnf=self.__class__.ebnf_grammar,
expected_patterns=allowed_patterns,
prompt=prompt,
......@@ -137,7 +89,7 @@ class TestEBNFConstrained(CustomTestCase):
allowed_patterns = [r"^\d{3}$"]
prompt = "Generate a three-digit number:"
self.run_decode(
self._run_decode_ebnf(
ebnf=self.__class__.ebnf_grammar,
expected_patterns=allowed_patterns,
prompt=prompt,
......@@ -154,7 +106,7 @@ class TestEBNFConstrained(CustomTestCase):
allowed_patterns = [r"^\(\d{3}\) \d{3}-\d{4}$"]
prompt = "Generate a phone number:"
self.run_decode(
self._run_decode_ebnf(
ebnf=self.__class__.ebnf_grammar,
expected_patterns=allowed_patterns,
prompt=prompt,
......@@ -173,7 +125,7 @@ class TestEBNFConstrained(CustomTestCase):
allowed_patterns = [r"^2024-(0[1-9]|1[0-2])-(0[1-9]|[12]\d|3[01])$"]
prompt = "Generate a date in YYYY-MM-DD format:"
self.run_decode(
self._run_decode_ebnf(
ebnf=self.__class__.ebnf_grammar,
expected_patterns=allowed_patterns,
prompt=prompt,
......@@ -188,7 +140,7 @@ class TestEBNFConstrained(CustomTestCase):
allowed_patterns = [r"^#[0-9A-F]{6}$"]
prompt = "Generate a hex color code:"
self.run_decode(
self._run_decode_ebnf(
ebnf=self.__class__.ebnf_grammar,
expected_patterns=allowed_patterns,
prompt=prompt,
......@@ -212,7 +164,7 @@ class TestEBNFConstrained(CustomTestCase):
]
prompt = "Generate a simple JSON with name, age, and city:"
self.run_decode(
self._run_decode_ebnf(
ebnf=self.__class__.ebnf_grammar,
expected_patterns=allowed_patterns,
prompt=prompt,
......@@ -232,7 +184,7 @@ class TestEBNFConstrained(CustomTestCase):
]
prompt = "Generate a log entry:"
self.run_decode(
self._run_decode_ebnf(
ebnf=self.__class__.ebnf_grammar,
expected_patterns=allowed_patterns,
prompt=prompt,
......@@ -264,19 +216,9 @@ class TestEBNFConstrained(CustomTestCase):
]
prompt = "Configure the service with optional settings:"
self.run_decode(
self._run_decode_ebnf(
ebnf=self.__class__.ebnf_grammar,
expected_patterns=allowed_patterns,
prompt=prompt,
n=5,
)
class TestEBNFConstrainedLLGuidance(TestEBNFConstrained):
@classmethod
def setUpClass(cls):
setup_class(cls, "llguidance", disable_overlap=False)
if __name__ == "__main__":
unittest.main()
"""
python3 -m unittest openai_server.features.test_json_constrained.TestJSONConstrainedOutlinesBackend.test_json_generate
python3 -m unittest openai_server.features.test_json_constrained.TestJSONConstrainedXGrammarBackend.test_json_generate
python3 -m unittest openai_server.features.test_json_constrained.TestJSONConstrainedLLGuidanceBackend.test_json_generate
"""
import json
import unittest
from concurrent.futures import ThreadPoolExecutor
import openai
import requests
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
def setup_class(cls, backend: str):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.json_schema = json.dumps(
class TestJSONConstrainedMixin:
json_schema = json.dumps(
{
"type": "object",
"properties": {
......@@ -36,31 +18,9 @@ def setup_class(cls, backend: str):
}
)
other_args = [
"--max-running-requests",
"10",
"--grammar-backend",
backend,
]
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=other_args,
)
class TestJSONConstrained(CustomTestCase):
@classmethod
def setUpClass(cls):
setup_class(cls, backend="xgrammar")
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def run_decode(self, json_schema, return_logprob=False, top_logprobs_num=0, n=1):
def _run_decode_json(
self, json_schema, return_logprob=False, top_logprobs_num=0, n=1
):
response = requests.post(
self.base_url + "/generate",
json={
......@@ -95,10 +55,10 @@ class TestJSONConstrained(CustomTestCase):
self.assertIsInstance(js_obj["population"], int)
def test_json_generate(self):
self.run_decode(json_schema=self.json_schema)
self._run_decode_json(json_schema=self.json_schema)
def test_json_invalid(self):
self.run_decode(json_schema="INVALID")
self._run_decode_json(json_schema="INVALID")
def test_json_openai(self):
client = openai.Client(api_key="EMPTY", base_url=f"{self.base_url}/v1")
......@@ -134,20 +94,4 @@ class TestJSONConstrained(CustomTestCase):
json_schemas = [None, None, self.json_schema, self.json_schema] * 10
with ThreadPoolExecutor(len(json_schemas)) as executor:
list(executor.map(self.run_decode, json_schemas))
class TestJSONConstrainedOutlinesBackend(TestJSONConstrained):
@classmethod
def setUpClass(cls):
setup_class(cls, backend="outlines")
class TestJSONConstrainedLLGuidanceBackend(TestJSONConstrained):
@classmethod
def setUpClass(cls):
setup_class(cls, backend="llguidance")
if __name__ == "__main__":
unittest.main()
list(executor.map(self._run_decode_json, json_schemas))
"""
python3 -m unittest test_regex_constrained.TestRegexConstrained.test_regex_generate_email
python3 -m unittest test_regex_constrained.TestRegexConstrained.test_regex_generate_greeting
python3 -m unittest test_regex_constrained.TestRegexConstrainedLLGuidance.test_regex_generate_email
python3 -m unittest test_regex_constrained.TestRegexConstrainedLLGuidance.test_regex_generate_greeting
"""
import json
import unittest
import requests
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
def setup_class(cls, backend: str, disable_overlap: bool):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
other_args = [
"--max-running-requests",
"10",
"--grammar-backend",
backend,
]
if disable_overlap:
other_args += ["--disable-overlap-schedule"]
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=other_args,
)
class TestRegexConstrained(CustomTestCase):
@classmethod
def setUpClass(cls):
setup_class(cls, "xgrammar", disable_overlap=False)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def run_decode(
class TestRegexConstrainedMixin:
def _run_decode_regex(
self,
regex,
prompt,
......@@ -100,7 +53,7 @@ class TestRegexConstrained(CustomTestCase):
pattern = r"^user@example\.com$"
prompt = "Generate an email address:"
self.run_decode(
self._run_decode_regex(
regex=pattern,
prompt=prompt,
n=3,
......@@ -110,7 +63,7 @@ class TestRegexConstrained(CustomTestCase):
pattern = r"^(Hello|Hi|Hey)$"
prompt = "Generate a greeting:"
self.run_decode(
self._run_decode_regex(
regex=pattern,
prompt=prompt,
n=3,
......@@ -120,7 +73,7 @@ class TestRegexConstrained(CustomTestCase):
pattern = r"^\d{3}$"
prompt = "Generate a three-digit number:"
self.run_decode(
self._run_decode_regex(
regex=pattern,
prompt=prompt,
n=3,
......@@ -130,7 +83,7 @@ class TestRegexConstrained(CustomTestCase):
pattern = r"^\(\d{3}\) \d{3}-\d{4}$"
prompt = "Generate a phone number:"
self.run_decode(
self._run_decode_regex(
regex=pattern,
prompt=prompt,
n=3,
......@@ -140,7 +93,7 @@ class TestRegexConstrained(CustomTestCase):
pattern = r"^2024-(0[1-9]|1[0-2])-(0[1-9]|[12]\d|3[01])$"
prompt = "Generate a date in YYYY-MM-DD format:"
self.run_decode(
self._run_decode_regex(
regex=pattern,
prompt=prompt,
n=3,
......@@ -150,7 +103,7 @@ class TestRegexConstrained(CustomTestCase):
pattern = r"^#[0-9A-F]{6}$"
prompt = "Generate a hex color code:"
self.run_decode(
self._run_decode_regex(
regex=pattern,
prompt=prompt,
n=3,
......@@ -160,7 +113,7 @@ class TestRegexConstrained(CustomTestCase):
pattern = r'^\{\s*"name"\s*:\s*"[a-zA-Z0-9 ]+"\s*,\s*"age"\s*:\s*[1-9][0-9]*\s*,\s*"city"\s*:\s*"[a-zA-Z0-9 ]+"\s*\}$'
prompt = "Generate a simple JSON with name, age, and city:"
self.run_decode(
self._run_decode_regex(
regex=pattern,
prompt=prompt,
n=3,
......@@ -170,18 +123,8 @@ class TestRegexConstrained(CustomTestCase):
pattern = r"^\[2024-01-01T12:00:00Z\] INFO: System\.process - Operation [a-z]+ successfully$"
prompt = "Generate a log entry:"
self.run_decode(
self._run_decode_regex(
regex=pattern,
prompt=prompt,
n=3,
)
class TestRegexConstrainedLLGuidance(TestRegexConstrained):
@classmethod
def setUpClass(cls):
setup_class(cls, "llguidance", disable_overlap=True)
if __name__ == "__main__":
unittest.main()
......@@ -2,7 +2,6 @@
Tests for JSON schema constraint functionality used by JsonArrayParser
"""
import json
import unittest
import jsonschema
......
......@@ -16,7 +16,7 @@ class TestFile:
suites = {
"per-commit-1-gpu": [
TestFile("debug_utils/test_tensor_dump_forward_hook.py", 15),
TestFile("function_call/test_json_schema_constraint.py", 30),
TestFile("function_call/test_json_schema_constraint.py", 1),
TestFile("hicache/test_hicache.py", 116),
TestFile("hicache/test_hicache_eagle.py", 150),
TestFile("hicache/test_hicache_mla.py", 127),
......@@ -46,7 +46,6 @@ suites = {
TestFile("openai_server/basic/test_serving_completions.py", 10),
TestFile("openai_server/basic/test_serving_embedding.py", 10),
TestFile("openai_server/features/test_enable_thinking.py", 70),
TestFile("openai_server/features/test_json_constrained.py", 120),
TestFile("openai_server/features/test_json_mode.py", 120),
TestFile("openai_server/features/test_openai_server_ebnf.py", 20),
TestFile("openai_server/features/test_openai_server_hidden_states.py", 240),
......@@ -74,7 +73,7 @@ suites = {
TestFile("test_eagle_infer_a.py", 370),
TestFile("test_eagle_infer_b.py", 500),
TestFile("test_eagle_infer_beta.py", 90),
TestFile("test_ebnf_constrained.py", 80),
TestFile("test_constrained_decoding.py", 120),
TestFile("test_eval_fp8_accuracy.py", 303),
TestFile("test_fa3.py", 420),
TestFile("test_flashmla.py", 230),
......@@ -305,7 +304,7 @@ suites = {
TestFile("test_int4_kernel.py"),
TestFile("test_int8_kernel.py"),
TestFile("test_intel_amx_attention_backend.py"),
TestFile("test_json_constrained.py"),
TestFile("test_constrained_decoding.py"),
TestFile("test_json_mode.py"),
TestFile("test_kv_events.py"),
TestFile("test_large_max_new_tokens.py"),
......@@ -369,7 +368,7 @@ suites = {
# NOTE: please sort the test cases alphabetically by the test file name
suite_amd = {
"per-commit-amd": [
TestFile("function_call/test_json_schema_constraint.py", 30),
TestFile("function_call/test_json_schema_constraint.py", 1),
TestFile("hicache/test_hicache.py", 116),
TestFile("hicache/test_hicache_mla.py", 127),
TestFile("hicache/test_hicache_storage.py", 127),
......@@ -390,7 +389,6 @@ suite_amd = {
TestFile("openai_server/basic/test_serving_completions.py", 10),
TestFile("openai_server/basic/test_serving_embedding.py", 10),
TestFile("openai_server/features/test_enable_thinking.py", 70),
TestFile("openai_server/features/test_json_constrained.py", 120),
TestFile("openai_server/features/test_json_mode.py", 120),
TestFile("openai_server/features/test_openai_server_ebnf.py", 20),
TestFile("openai_server/features/test_reasoning_content.py", 89),
......@@ -406,7 +404,6 @@ suite_amd = {
TestFile("test_abort.py", 51),
TestFile("test_chunked_prefill.py", 410),
TestFile("test_create_kvindices.py", 2),
TestFile("test_ebnf_constrained.py", 80),
TestFile("test_eval_fp8_accuracy.py", 303),
TestFile("test_function_call_parser.py", 10),
TestFile("test_fused_moe.py", 80),
......@@ -423,7 +420,7 @@ suite_amd = {
TestFile("test_pytorch_sampling_backend.py", 66),
TestFile("test_radix_attention.py", 105),
TestFile("test_reasoning_parser.py", 5),
TestFile("test_regex_constrained.py", 64),
TestFile("test_constrained_decoding.py", 120),
TestFile("test_retract_decode.py", 450),
TestFile("test_rope_rocm.py", 3),
TestFile("test_server_args.py", 1),
......
import unittest
from sglang.srt.utils import kill_process_tree
from sglang.test.kits.ebnf_constrained_kit import TestEBNFConstrainedMinxin
from sglang.test.kits.json_constrained_kit import TestJSONConstrainedMixin
from sglang.test.kits.regex_constrained_kit import TestRegexConstrainedMixin
from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class ServerWithGrammar(CustomTestCase):
backend = "xgrammar"
disable_overlap = False
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
launch_args = [
"--max-running-requests",
"10",
"--grammar-backend",
cls.backend,
]
if cls.disable_overlap:
launch_args += ["--disable-overlap-schedule"]
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=launch_args,
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
class TestXGrammarBackend(
ServerWithGrammar,
TestJSONConstrainedMixin,
TestEBNFConstrainedMinxin,
TestRegexConstrainedMixin,
):
backend = "xgrammar"
class TestOutlinesBackend(ServerWithGrammar, TestJSONConstrainedMixin):
backend = "outlines"
class TestLLGuidanceBackend(
ServerWithGrammar,
TestJSONConstrainedMixin,
TestEBNFConstrainedMinxin,
TestRegexConstrainedMixin,
):
backend = "llguidance"
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment