Unverified Commit aa797d01 authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

[Test] Merge all constrained decoding tests. (#12633)

parent 7cee07a0
"""
python3 -m unittest test_ebnf_constrained.TestEBNFConstrained.test_ebnf_generate_email
python3 -m unittest test_ebnf_constrained.TestEBNFConstrained.test_ebnf_generate_greeting
python3 -m unittest test_ebnf_constrained.TestEBNFConstrained.test_ebnf_generate_all_optional_function_params
python3 -m unittest test_ebnf_constrained.TestEBNFConstrainedLLGuidance.test_ebnf_generate_email
python3 -m unittest test_ebnf_constrained.TestEBNFConstrainedLLGuidance.test_ebnf_generate_greeting
python3 -m unittest test_ebnf_constrained.TestEBNFConstrainedLLGuidance.test_ebnf_generate_all_optional_function_params
"""
import json import json
import unittest
import requests import requests
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
def setup_class(cls, backend: str, disable_overlap: bool):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.ebnf_grammar = 'root ::= "test"' # Default grammar
other_args = [
"--max-running-requests",
"10",
"--grammar-backend",
backend,
]
if disable_overlap:
other_args += ["--disable-overlap-schedule"]
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=other_args,
)
class TestEBNFConstrained(CustomTestCase): class TestEBNFConstrainedMinxin:
@classmethod ebnf_grammar = 'root ::= "test"' # Default grammar
def setUpClass(cls):
setup_class(cls, "xgrammar", disable_overlap=False)
@classmethod def _run_decode_ebnf(
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def run_decode(
self, self,
ebnf, ebnf,
expected_patterns, expected_patterns,
...@@ -110,7 +62,7 @@ class TestEBNFConstrained(CustomTestCase): ...@@ -110,7 +62,7 @@ class TestEBNFConstrained(CustomTestCase):
allowed_patterns = [r"^user@example\.com$"] allowed_patterns = [r"^user@example\.com$"]
prompt = "Generate an email address:" prompt = "Generate an email address:"
self.run_decode( self._run_decode_ebnf(
ebnf=self.__class__.ebnf_grammar, ebnf=self.__class__.ebnf_grammar,
expected_patterns=allowed_patterns, expected_patterns=allowed_patterns,
prompt=prompt, prompt=prompt,
...@@ -122,7 +74,7 @@ class TestEBNFConstrained(CustomTestCase): ...@@ -122,7 +74,7 @@ class TestEBNFConstrained(CustomTestCase):
allowed_patterns = [r"^(Hello|Hi|Hey)$"] allowed_patterns = [r"^(Hello|Hi|Hey)$"]
prompt = "Generate a greeting:" prompt = "Generate a greeting:"
self.run_decode( self._run_decode_ebnf(
ebnf=self.__class__.ebnf_grammar, ebnf=self.__class__.ebnf_grammar,
expected_patterns=allowed_patterns, expected_patterns=allowed_patterns,
prompt=prompt, prompt=prompt,
...@@ -137,7 +89,7 @@ class TestEBNFConstrained(CustomTestCase): ...@@ -137,7 +89,7 @@ class TestEBNFConstrained(CustomTestCase):
allowed_patterns = [r"^\d{3}$"] allowed_patterns = [r"^\d{3}$"]
prompt = "Generate a three-digit number:" prompt = "Generate a three-digit number:"
self.run_decode( self._run_decode_ebnf(
ebnf=self.__class__.ebnf_grammar, ebnf=self.__class__.ebnf_grammar,
expected_patterns=allowed_patterns, expected_patterns=allowed_patterns,
prompt=prompt, prompt=prompt,
...@@ -154,7 +106,7 @@ class TestEBNFConstrained(CustomTestCase): ...@@ -154,7 +106,7 @@ class TestEBNFConstrained(CustomTestCase):
allowed_patterns = [r"^\(\d{3}\) \d{3}-\d{4}$"] allowed_patterns = [r"^\(\d{3}\) \d{3}-\d{4}$"]
prompt = "Generate a phone number:" prompt = "Generate a phone number:"
self.run_decode( self._run_decode_ebnf(
ebnf=self.__class__.ebnf_grammar, ebnf=self.__class__.ebnf_grammar,
expected_patterns=allowed_patterns, expected_patterns=allowed_patterns,
prompt=prompt, prompt=prompt,
...@@ -173,7 +125,7 @@ class TestEBNFConstrained(CustomTestCase): ...@@ -173,7 +125,7 @@ class TestEBNFConstrained(CustomTestCase):
allowed_patterns = [r"^2024-(0[1-9]|1[0-2])-(0[1-9]|[12]\d|3[01])$"] allowed_patterns = [r"^2024-(0[1-9]|1[0-2])-(0[1-9]|[12]\d|3[01])$"]
prompt = "Generate a date in YYYY-MM-DD format:" prompt = "Generate a date in YYYY-MM-DD format:"
self.run_decode( self._run_decode_ebnf(
ebnf=self.__class__.ebnf_grammar, ebnf=self.__class__.ebnf_grammar,
expected_patterns=allowed_patterns, expected_patterns=allowed_patterns,
prompt=prompt, prompt=prompt,
...@@ -188,7 +140,7 @@ class TestEBNFConstrained(CustomTestCase): ...@@ -188,7 +140,7 @@ class TestEBNFConstrained(CustomTestCase):
allowed_patterns = [r"^#[0-9A-F]{6}$"] allowed_patterns = [r"^#[0-9A-F]{6}$"]
prompt = "Generate a hex color code:" prompt = "Generate a hex color code:"
self.run_decode( self._run_decode_ebnf(
ebnf=self.__class__.ebnf_grammar, ebnf=self.__class__.ebnf_grammar,
expected_patterns=allowed_patterns, expected_patterns=allowed_patterns,
prompt=prompt, prompt=prompt,
...@@ -212,7 +164,7 @@ class TestEBNFConstrained(CustomTestCase): ...@@ -212,7 +164,7 @@ class TestEBNFConstrained(CustomTestCase):
] ]
prompt = "Generate a simple JSON with name, age, and city:" prompt = "Generate a simple JSON with name, age, and city:"
self.run_decode( self._run_decode_ebnf(
ebnf=self.__class__.ebnf_grammar, ebnf=self.__class__.ebnf_grammar,
expected_patterns=allowed_patterns, expected_patterns=allowed_patterns,
prompt=prompt, prompt=prompt,
...@@ -232,7 +184,7 @@ class TestEBNFConstrained(CustomTestCase): ...@@ -232,7 +184,7 @@ class TestEBNFConstrained(CustomTestCase):
] ]
prompt = "Generate a log entry:" prompt = "Generate a log entry:"
self.run_decode( self._run_decode_ebnf(
ebnf=self.__class__.ebnf_grammar, ebnf=self.__class__.ebnf_grammar,
expected_patterns=allowed_patterns, expected_patterns=allowed_patterns,
prompt=prompt, prompt=prompt,
...@@ -264,19 +216,9 @@ class TestEBNFConstrained(CustomTestCase): ...@@ -264,19 +216,9 @@ class TestEBNFConstrained(CustomTestCase):
] ]
prompt = "Configure the service with optional settings:" prompt = "Configure the service with optional settings:"
self.run_decode( self._run_decode_ebnf(
ebnf=self.__class__.ebnf_grammar, ebnf=self.__class__.ebnf_grammar,
expected_patterns=allowed_patterns, expected_patterns=allowed_patterns,
prompt=prompt, prompt=prompt,
n=5, n=5,
) )
class TestEBNFConstrainedLLGuidance(TestEBNFConstrained):
@classmethod
def setUpClass(cls):
setup_class(cls, "llguidance", disable_overlap=False)
if __name__ == "__main__":
unittest.main()
"""
python3 -m unittest openai_server.features.test_json_constrained.TestJSONConstrainedOutlinesBackend.test_json_generate
python3 -m unittest openai_server.features.test_json_constrained.TestJSONConstrainedXGrammarBackend.test_json_generate
python3 -m unittest openai_server.features.test_json_constrained.TestJSONConstrainedLLGuidanceBackend.test_json_generate
"""
import json import json
import unittest
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
import openai import openai
import requests import requests
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
def setup_class(cls, backend: str): class TestJSONConstrainedMixin:
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST json_schema = json.dumps(
cls.base_url = DEFAULT_URL_FOR_TEST
cls.json_schema = json.dumps(
{ {
"type": "object", "type": "object",
"properties": { "properties": {
...@@ -36,31 +18,9 @@ def setup_class(cls, backend: str): ...@@ -36,31 +18,9 @@ def setup_class(cls, backend: str):
} }
) )
other_args = [ def _run_decode_json(
"--max-running-requests", self, json_schema, return_logprob=False, top_logprobs_num=0, n=1
"10", ):
"--grammar-backend",
backend,
]
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=other_args,
)
class TestJSONConstrained(CustomTestCase):
@classmethod
def setUpClass(cls):
setup_class(cls, backend="xgrammar")
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def run_decode(self, json_schema, return_logprob=False, top_logprobs_num=0, n=1):
response = requests.post( response = requests.post(
self.base_url + "/generate", self.base_url + "/generate",
json={ json={
...@@ -95,10 +55,10 @@ class TestJSONConstrained(CustomTestCase): ...@@ -95,10 +55,10 @@ class TestJSONConstrained(CustomTestCase):
self.assertIsInstance(js_obj["population"], int) self.assertIsInstance(js_obj["population"], int)
def test_json_generate(self): def test_json_generate(self):
self.run_decode(json_schema=self.json_schema) self._run_decode_json(json_schema=self.json_schema)
def test_json_invalid(self): def test_json_invalid(self):
self.run_decode(json_schema="INVALID") self._run_decode_json(json_schema="INVALID")
def test_json_openai(self): def test_json_openai(self):
client = openai.Client(api_key="EMPTY", base_url=f"{self.base_url}/v1") client = openai.Client(api_key="EMPTY", base_url=f"{self.base_url}/v1")
...@@ -134,20 +94,4 @@ class TestJSONConstrained(CustomTestCase): ...@@ -134,20 +94,4 @@ class TestJSONConstrained(CustomTestCase):
json_schemas = [None, None, self.json_schema, self.json_schema] * 10 json_schemas = [None, None, self.json_schema, self.json_schema] * 10
with ThreadPoolExecutor(len(json_schemas)) as executor: with ThreadPoolExecutor(len(json_schemas)) as executor:
list(executor.map(self.run_decode, json_schemas)) list(executor.map(self._run_decode_json, json_schemas))
class TestJSONConstrainedOutlinesBackend(TestJSONConstrained):
@classmethod
def setUpClass(cls):
setup_class(cls, backend="outlines")
class TestJSONConstrainedLLGuidanceBackend(TestJSONConstrained):
@classmethod
def setUpClass(cls):
setup_class(cls, backend="llguidance")
if __name__ == "__main__":
unittest.main()
"""
python3 -m unittest test_regex_constrained.TestRegexConstrained.test_regex_generate_email
python3 -m unittest test_regex_constrained.TestRegexConstrained.test_regex_generate_greeting
python3 -m unittest test_regex_constrained.TestRegexConstrainedLLGuidance.test_regex_generate_email
python3 -m unittest test_regex_constrained.TestRegexConstrainedLLGuidance.test_regex_generate_greeting
"""
import json import json
import unittest
import requests import requests
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
def setup_class(cls, backend: str, disable_overlap: bool):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
other_args = [
"--max-running-requests",
"10",
"--grammar-backend",
backend,
]
if disable_overlap:
other_args += ["--disable-overlap-schedule"]
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=other_args,
)
class TestRegexConstrained(CustomTestCase):
@classmethod
def setUpClass(cls):
setup_class(cls, "xgrammar", disable_overlap=False)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def run_decode( class TestRegexConstrainedMixin:
def _run_decode_regex(
self, self,
regex, regex,
prompt, prompt,
...@@ -100,7 +53,7 @@ class TestRegexConstrained(CustomTestCase): ...@@ -100,7 +53,7 @@ class TestRegexConstrained(CustomTestCase):
pattern = r"^user@example\.com$" pattern = r"^user@example\.com$"
prompt = "Generate an email address:" prompt = "Generate an email address:"
self.run_decode( self._run_decode_regex(
regex=pattern, regex=pattern,
prompt=prompt, prompt=prompt,
n=3, n=3,
...@@ -110,7 +63,7 @@ class TestRegexConstrained(CustomTestCase): ...@@ -110,7 +63,7 @@ class TestRegexConstrained(CustomTestCase):
pattern = r"^(Hello|Hi|Hey)$" pattern = r"^(Hello|Hi|Hey)$"
prompt = "Generate a greeting:" prompt = "Generate a greeting:"
self.run_decode( self._run_decode_regex(
regex=pattern, regex=pattern,
prompt=prompt, prompt=prompt,
n=3, n=3,
...@@ -120,7 +73,7 @@ class TestRegexConstrained(CustomTestCase): ...@@ -120,7 +73,7 @@ class TestRegexConstrained(CustomTestCase):
pattern = r"^\d{3}$" pattern = r"^\d{3}$"
prompt = "Generate a three-digit number:" prompt = "Generate a three-digit number:"
self.run_decode( self._run_decode_regex(
regex=pattern, regex=pattern,
prompt=prompt, prompt=prompt,
n=3, n=3,
...@@ -130,7 +83,7 @@ class TestRegexConstrained(CustomTestCase): ...@@ -130,7 +83,7 @@ class TestRegexConstrained(CustomTestCase):
pattern = r"^\(\d{3}\) \d{3}-\d{4}$" pattern = r"^\(\d{3}\) \d{3}-\d{4}$"
prompt = "Generate a phone number:" prompt = "Generate a phone number:"
self.run_decode( self._run_decode_regex(
regex=pattern, regex=pattern,
prompt=prompt, prompt=prompt,
n=3, n=3,
...@@ -140,7 +93,7 @@ class TestRegexConstrained(CustomTestCase): ...@@ -140,7 +93,7 @@ class TestRegexConstrained(CustomTestCase):
pattern = r"^2024-(0[1-9]|1[0-2])-(0[1-9]|[12]\d|3[01])$" pattern = r"^2024-(0[1-9]|1[0-2])-(0[1-9]|[12]\d|3[01])$"
prompt = "Generate a date in YYYY-MM-DD format:" prompt = "Generate a date in YYYY-MM-DD format:"
self.run_decode( self._run_decode_regex(
regex=pattern, regex=pattern,
prompt=prompt, prompt=prompt,
n=3, n=3,
...@@ -150,7 +103,7 @@ class TestRegexConstrained(CustomTestCase): ...@@ -150,7 +103,7 @@ class TestRegexConstrained(CustomTestCase):
pattern = r"^#[0-9A-F]{6}$" pattern = r"^#[0-9A-F]{6}$"
prompt = "Generate a hex color code:" prompt = "Generate a hex color code:"
self.run_decode( self._run_decode_regex(
regex=pattern, regex=pattern,
prompt=prompt, prompt=prompt,
n=3, n=3,
...@@ -160,7 +113,7 @@ class TestRegexConstrained(CustomTestCase): ...@@ -160,7 +113,7 @@ class TestRegexConstrained(CustomTestCase):
pattern = r'^\{\s*"name"\s*:\s*"[a-zA-Z0-9 ]+"\s*,\s*"age"\s*:\s*[1-9][0-9]*\s*,\s*"city"\s*:\s*"[a-zA-Z0-9 ]+"\s*\}$' pattern = r'^\{\s*"name"\s*:\s*"[a-zA-Z0-9 ]+"\s*,\s*"age"\s*:\s*[1-9][0-9]*\s*,\s*"city"\s*:\s*"[a-zA-Z0-9 ]+"\s*\}$'
prompt = "Generate a simple JSON with name, age, and city:" prompt = "Generate a simple JSON with name, age, and city:"
self.run_decode( self._run_decode_regex(
regex=pattern, regex=pattern,
prompt=prompt, prompt=prompt,
n=3, n=3,
...@@ -170,18 +123,8 @@ class TestRegexConstrained(CustomTestCase): ...@@ -170,18 +123,8 @@ class TestRegexConstrained(CustomTestCase):
pattern = r"^\[2024-01-01T12:00:00Z\] INFO: System\.process - Operation [a-z]+ successfully$" pattern = r"^\[2024-01-01T12:00:00Z\] INFO: System\.process - Operation [a-z]+ successfully$"
prompt = "Generate a log entry:" prompt = "Generate a log entry:"
self.run_decode( self._run_decode_regex(
regex=pattern, regex=pattern,
prompt=prompt, prompt=prompt,
n=3, n=3,
) )
class TestRegexConstrainedLLGuidance(TestRegexConstrained):
@classmethod
def setUpClass(cls):
setup_class(cls, "llguidance", disable_overlap=True)
if __name__ == "__main__":
unittest.main()
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
Tests for JSON schema constraint functionality used by JsonArrayParser Tests for JSON schema constraint functionality used by JsonArrayParser
""" """
import json
import unittest import unittest
import jsonschema import jsonschema
......
...@@ -16,7 +16,7 @@ class TestFile: ...@@ -16,7 +16,7 @@ class TestFile:
suites = { suites = {
"per-commit-1-gpu": [ "per-commit-1-gpu": [
TestFile("debug_utils/test_tensor_dump_forward_hook.py", 15), TestFile("debug_utils/test_tensor_dump_forward_hook.py", 15),
TestFile("function_call/test_json_schema_constraint.py", 30), TestFile("function_call/test_json_schema_constraint.py", 1),
TestFile("hicache/test_hicache.py", 116), TestFile("hicache/test_hicache.py", 116),
TestFile("hicache/test_hicache_eagle.py", 150), TestFile("hicache/test_hicache_eagle.py", 150),
TestFile("hicache/test_hicache_mla.py", 127), TestFile("hicache/test_hicache_mla.py", 127),
...@@ -46,7 +46,6 @@ suites = { ...@@ -46,7 +46,6 @@ suites = {
TestFile("openai_server/basic/test_serving_completions.py", 10), TestFile("openai_server/basic/test_serving_completions.py", 10),
TestFile("openai_server/basic/test_serving_embedding.py", 10), TestFile("openai_server/basic/test_serving_embedding.py", 10),
TestFile("openai_server/features/test_enable_thinking.py", 70), TestFile("openai_server/features/test_enable_thinking.py", 70),
TestFile("openai_server/features/test_json_constrained.py", 120),
TestFile("openai_server/features/test_json_mode.py", 120), TestFile("openai_server/features/test_json_mode.py", 120),
TestFile("openai_server/features/test_openai_server_ebnf.py", 20), TestFile("openai_server/features/test_openai_server_ebnf.py", 20),
TestFile("openai_server/features/test_openai_server_hidden_states.py", 240), TestFile("openai_server/features/test_openai_server_hidden_states.py", 240),
...@@ -74,7 +73,7 @@ suites = { ...@@ -74,7 +73,7 @@ suites = {
TestFile("test_eagle_infer_a.py", 370), TestFile("test_eagle_infer_a.py", 370),
TestFile("test_eagle_infer_b.py", 500), TestFile("test_eagle_infer_b.py", 500),
TestFile("test_eagle_infer_beta.py", 90), TestFile("test_eagle_infer_beta.py", 90),
TestFile("test_ebnf_constrained.py", 80), TestFile("test_constrained_decoding.py", 120),
TestFile("test_eval_fp8_accuracy.py", 303), TestFile("test_eval_fp8_accuracy.py", 303),
TestFile("test_fa3.py", 420), TestFile("test_fa3.py", 420),
TestFile("test_flashmla.py", 230), TestFile("test_flashmla.py", 230),
...@@ -305,7 +304,7 @@ suites = { ...@@ -305,7 +304,7 @@ suites = {
TestFile("test_int4_kernel.py"), TestFile("test_int4_kernel.py"),
TestFile("test_int8_kernel.py"), TestFile("test_int8_kernel.py"),
TestFile("test_intel_amx_attention_backend.py"), TestFile("test_intel_amx_attention_backend.py"),
TestFile("test_json_constrained.py"), TestFile("test_constrained_decoding.py"),
TestFile("test_json_mode.py"), TestFile("test_json_mode.py"),
TestFile("test_kv_events.py"), TestFile("test_kv_events.py"),
TestFile("test_large_max_new_tokens.py"), TestFile("test_large_max_new_tokens.py"),
...@@ -369,7 +368,7 @@ suites = { ...@@ -369,7 +368,7 @@ suites = {
# NOTE: please sort the test cases alphabetically by the test file name # NOTE: please sort the test cases alphabetically by the test file name
suite_amd = { suite_amd = {
"per-commit-amd": [ "per-commit-amd": [
TestFile("function_call/test_json_schema_constraint.py", 30), TestFile("function_call/test_json_schema_constraint.py", 1),
TestFile("hicache/test_hicache.py", 116), TestFile("hicache/test_hicache.py", 116),
TestFile("hicache/test_hicache_mla.py", 127), TestFile("hicache/test_hicache_mla.py", 127),
TestFile("hicache/test_hicache_storage.py", 127), TestFile("hicache/test_hicache_storage.py", 127),
...@@ -390,7 +389,6 @@ suite_amd = { ...@@ -390,7 +389,6 @@ suite_amd = {
TestFile("openai_server/basic/test_serving_completions.py", 10), TestFile("openai_server/basic/test_serving_completions.py", 10),
TestFile("openai_server/basic/test_serving_embedding.py", 10), TestFile("openai_server/basic/test_serving_embedding.py", 10),
TestFile("openai_server/features/test_enable_thinking.py", 70), TestFile("openai_server/features/test_enable_thinking.py", 70),
TestFile("openai_server/features/test_json_constrained.py", 120),
TestFile("openai_server/features/test_json_mode.py", 120), TestFile("openai_server/features/test_json_mode.py", 120),
TestFile("openai_server/features/test_openai_server_ebnf.py", 20), TestFile("openai_server/features/test_openai_server_ebnf.py", 20),
TestFile("openai_server/features/test_reasoning_content.py", 89), TestFile("openai_server/features/test_reasoning_content.py", 89),
...@@ -406,7 +404,6 @@ suite_amd = { ...@@ -406,7 +404,6 @@ suite_amd = {
TestFile("test_abort.py", 51), TestFile("test_abort.py", 51),
TestFile("test_chunked_prefill.py", 410), TestFile("test_chunked_prefill.py", 410),
TestFile("test_create_kvindices.py", 2), TestFile("test_create_kvindices.py", 2),
TestFile("test_ebnf_constrained.py", 80),
TestFile("test_eval_fp8_accuracy.py", 303), TestFile("test_eval_fp8_accuracy.py", 303),
TestFile("test_function_call_parser.py", 10), TestFile("test_function_call_parser.py", 10),
TestFile("test_fused_moe.py", 80), TestFile("test_fused_moe.py", 80),
...@@ -423,7 +420,7 @@ suite_amd = { ...@@ -423,7 +420,7 @@ suite_amd = {
TestFile("test_pytorch_sampling_backend.py", 66), TestFile("test_pytorch_sampling_backend.py", 66),
TestFile("test_radix_attention.py", 105), TestFile("test_radix_attention.py", 105),
TestFile("test_reasoning_parser.py", 5), TestFile("test_reasoning_parser.py", 5),
TestFile("test_regex_constrained.py", 64), TestFile("test_constrained_decoding.py", 120),
TestFile("test_retract_decode.py", 450), TestFile("test_retract_decode.py", 450),
TestFile("test_rope_rocm.py", 3), TestFile("test_rope_rocm.py", 3),
TestFile("test_server_args.py", 1), TestFile("test_server_args.py", 1),
......
import unittest
from sglang.srt.utils import kill_process_tree
from sglang.test.kits.ebnf_constrained_kit import TestEBNFConstrainedMinxin
from sglang.test.kits.json_constrained_kit import TestJSONConstrainedMixin
from sglang.test.kits.regex_constrained_kit import TestRegexConstrainedMixin
from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class ServerWithGrammar(CustomTestCase):
backend = "xgrammar"
disable_overlap = False
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
launch_args = [
"--max-running-requests",
"10",
"--grammar-backend",
cls.backend,
]
if cls.disable_overlap:
launch_args += ["--disable-overlap-schedule"]
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=launch_args,
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
class TestXGrammarBackend(
ServerWithGrammar,
TestJSONConstrainedMixin,
TestEBNFConstrainedMinxin,
TestRegexConstrainedMixin,
):
backend = "xgrammar"
class TestOutlinesBackend(ServerWithGrammar, TestJSONConstrainedMixin):
backend = "outlines"
class TestLLGuidanceBackend(
ServerWithGrammar,
TestJSONConstrainedMixin,
TestEBNFConstrainedMinxin,
TestRegexConstrainedMixin,
):
backend = "llguidance"
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment