Unverified commit 6f221d4c, authored by Ying Sheng and committed by GitHub

Fix unit tests for the frontend language part (#872)

parent aba6f51f
-name: lint
+name: Lint
 on: [push, pull_request]
...
@@ -99,7 +99,6 @@ class SglSamplingParams:
             "stop": self.stop or None,
             "temperature": self.temperature,
             "top_p": self.top_p,
-            "top_k": self.top_k,
             "frequency_penalty": self.frequency_penalty,
             "presence_penalty": self.presence_penalty,
         }
...
@@ -479,6 +479,9 @@ class Runtime:
             parent.wait(timeout=5)
             self.pid = None

+    def cache_prefix(self, prefix: str):
+        self.endpoint.cache_prefix(prefix)
+
     def get_tokenizer(self):
         return get_tokenizer(
             self.server_args.tokenizer_path,
...
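The new `Runtime.cache_prefix` is a thin wrapper that forwards to the underlying endpoint, so a shared prompt prefix can be pushed to the server before requests are issued. A minimal sketch of how it might be called, using only the Runtime API visible in this diff; the model path matches the tests below and the prompt text is purely illustrative:

```python
# Sketch only: uses the Runtime API as it appears in this diff
# (model_path constructor argument, cache_prefix, shutdown).
import sglang as sgl

runtime = sgl.Runtime(model_path="meta-llama/Meta-Llama-3-8B-Instruct")
sgl.set_default_backend(runtime)

# Hypothetical shared prefix: warming the server with it lets later
# requests that start with the same text reuse the cached prefix.
runtime.cache_prefix("You are a helpful assistant.\n")

runtime.shutdown()
```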
@@ -113,15 +113,14 @@ def test_decode_json_regex():
         s += ' "population": ' + sgl.gen(regex=REGEX_INT + ",") + "\n"
         s += ' "area": ' + sgl.gen(regex=REGEX_INT + ",") + "\n"
         s += ' "latitude": ' + sgl.gen(regex=REGEX_FLOAT + ",") + "\n"
-        s += ' "country": ' + sgl.gen(regex=REGEX_STRING + ",") + "\n"
-        s += ' "timezone": ' + sgl.gen(regex=REGEX_STRING) + "\n"
+        s += ' "country": ' + sgl.gen(regex=REGEX_STRING) + "\n"
         s += "}"

-    ret = decode_json.run()
+    ret = decode_json.run(temperature=0.0)
     try:
         js_obj = json.loads(ret["json_output"])
     except json.decoder.JSONDecodeError:
-        print(ret["json_output"])
+        print("JSONDecodeError", ret["json_output"])
         raise
     assert isinstance(js_obj["name"], str)
     assert isinstance(js_obj["population"], int)

@@ -141,8 +140,12 @@ def test_decode_json():
         s += ' "timezone": ' + sgl.gen(dtype=str) + "\n"
         s += "}"

-    ret = decode_json.run()
-    js_obj = json.loads(ret["json_output"])
+    ret = decode_json.run(max_new_tokens=64)
+    try:
+        js_obj = json.loads(ret["json_output"])
+    except json.decoder.JSONDecodeError:
+        print("JSONDecodeError", ret["json_output"])
+        raise
     assert isinstance(js_obj["name"], str)
     assert isinstance(js_obj["population"], int)
...
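The two tests above exercise constrained decoding: `sgl.gen(regex=...)` forces each JSON field to match a pattern, and `run(temperature=0.0, ...)` keeps the output deterministic enough to parse. A stripped-down sketch of the same idea, assuming a default backend is already set (for example the `sgl.Runtime` from the sketch above); the regex literals here are illustrative stand-ins, not the `REGEX_*` constants from test_programs.py:

```python
# Sketch only: illustrative regexes, and a default backend must already be set.
import sglang as sgl

REGEX_INT = r"[0-9]+"
REGEX_STRING = r"\"[\w\d\s]*\""  # a simple quoted string


@sgl.function
def city_info(s):
    s += "Describe Paris as a JSON object.\n"
    s += "{\n"
    s += '  "name": ' + sgl.gen("name", regex=REGEX_STRING + ",") + "\n"
    s += '  "population": ' + sgl.gen("population", regex=REGEX_INT) + "\n"
    s += "}"


ret = city_info.run(temperature=0.0, max_new_tokens=64)
# Each field was generated under its regex, so it parses on its own.
assert int(ret["population"]) >= 0
```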
@@ -7,6 +7,10 @@ import unittest
 from sglang.utils import run_with_timeout

+suites = {
+    "minimal": ["test_openai_backend.py", "test_srt_backend.py"],
+}
+

 def run_unittest_files(files, args):
     for filename in files:
@@ -45,9 +49,19 @@ if __name__ == "__main__":
         default=1000,
         help="The time limit for running one file in seconds.",
     )
+    arg_parser.add_argument(
+        "--suite",
+        type=str,
+        default=list(suites.keys())[0],
+        choices=list(suites.keys()) + ["all"],
+        help="The suite to run",
+    )
     args = arg_parser.parse_args()

-    files = glob.glob("**/test_*.py", recursive=True)
+    if args.suite == "all":
+        files = glob.glob("**/test_*.py", recursive=True)
+    else:
+        files = suites[args.suite]

     tic = time.time()
     success = run_unittest_files(files, args)
...
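The suite mechanism added above is just a name-to-file-list mapping plus a `--suite` flag, with "all" falling back to globbing for every `test_*.py`. A self-contained sketch of the same selection logic, where `run_one` is a hypothetical stand-in for the repository's `run_unittest_files` helper:

```python
# Sketch only: mirrors the --suite selection logic; run_one() is a
# hypothetical stand-in for run_unittest_files().
import argparse
import glob
import subprocess
import sys

suites = {
    "minimal": ["test_openai_backend.py", "test_srt_backend.py"],
}


def run_one(filename: str) -> int:
    # Run a single test file in a subprocess and return its exit code.
    return subprocess.run([sys.executable, filename]).returncode


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--suite",
        type=str,
        default=list(suites.keys())[0],
        choices=list(suites.keys()) + ["all"],
        help="The suite to run",
    )
    args = parser.parse_args()

    if args.suite == "all":
        files = glob.glob("**/test_*.py", recursive=True)
    else:
        files = suites[args.suite]

    failed = sum(run_one(f) != 0 for f in files)
    sys.exit(1 if failed else 0)
```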
@@ -7,14 +7,11 @@ from sglang.test.test_programs import test_mt_bench, test_stream
 class TestAnthropicBackend(unittest.TestCase):
     backend = None
-    chat_backend = None

-    def setUp(self):
-        cls = type(self)
-
-        if cls.backend is None:
-            cls.backend = Anthropic("claude-3-haiku-20240307")
-            set_default_backend(cls.backend)
+    @classmethod
+    def setUpClass(cls):
+        cls.backend = Anthropic("claude-3-haiku-20240307")
+        set_default_backend(cls.backend)

     def test_mt_bench(self):
         test_mt_bench()
@@ -30,5 +27,5 @@ if __name__ == "__main__":
     # global_config.verbosity = 2
     # t = TestAnthropicBackend()
-    # t.setUp()
+    # t.setUpClass()
     # t.test_mt_bench()
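The change repeated across the backend tests in this commit is moving backend construction from `setUp`, which unittest runs before every test method (hence the old `if cls.backend is None` guard), to `setUpClass`, which runs once per class. A minimal sketch of that pattern with a hypothetical `FakeBackend`, independent of any SGLang-specific API:

```python
# Sketch only: FakeBackend is a hypothetical stand-in for Anthropic/OpenAI/etc.
import unittest


class FakeBackend:
    instances = 0

    def __init__(self):
        FakeBackend.instances += 1


class TestSetUpClassPattern(unittest.TestCase):
    backend = None

    @classmethod
    def setUpClass(cls):
        # Runs exactly once for the whole class, so no "is None" guard is needed.
        cls.backend = FakeBackend()

    def test_a(self):
        self.assertEqual(FakeBackend.instances, 1)

    def test_b(self):
        self.assertEqual(FakeBackend.instances, 1)


if __name__ == "__main__":
    unittest.main(warnings="ignore")
```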
"""
Usage:
python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000
python3 test_bind_cache.py
"""
import unittest import unittest
import sglang as sgl import sglang as sgl
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
class TestBind(unittest.TestCase): class TestBind(unittest.TestCase):
backend = None backend = None
def setUp(self): @classmethod
cls = type(self) def setUpClass(cls):
cls.backend = sgl.Runtime(model_path="meta-llama/Meta-Llama-3-8B-Instruct")
sgl.set_default_backend(cls.backend)
if cls.backend is None: @classmethod
cls.backend = RuntimeEndpoint(base_url="http://localhost:30000") def tearDownClass(cls):
cls.backend.shutdown()
def test_bind(self): def test_bind(self):
@sgl.function @sgl.function
...@@ -54,5 +50,5 @@ if __name__ == "__main__": ...@@ -54,5 +50,5 @@ if __name__ == "__main__":
unittest.main(warnings="ignore") unittest.main(warnings="ignore")
# t = TestBind() # t = TestBind()
# t.setUp() # t.setUpClass()
# t.test_cache() # t.test_cache()
@@ -6,15 +6,12 @@ from sglang.test.test_programs import test_mt_bench, test_stream
 class TestAnthropicBackend(unittest.TestCase):
-    backend = None
     chat_backend = None

-    def setUp(self):
-        cls = type(self)
-
-        if cls.backend is None:
-            cls.backend = LiteLLM("gpt-3.5-turbo")
-            set_default_backend(cls.backend)
+    @classmethod
+    def setUpClass(cls):
+        cls.chat_backend = LiteLLM("gpt-3.5-turbo")
+        set_default_backend(cls.chat_backend)

     def test_mt_bench(self):
         test_mt_bench()
...
@@ -20,20 +20,18 @@ from sglang.test.test_programs import (
 class TestOpenAIBackend(unittest.TestCase):
-    backend = None
+    instruct_backend = None
     chat_backend = None
     chat_vision_backend = None

-    def setUp(self):
-        cls = type(self)
-
-        if cls.backend is None:
-            cls.backend = OpenAI("gpt-3.5-turbo-instruct")
-            cls.chat_backend = OpenAI("gpt-3.5-turbo")
-            cls.chat_vision_backend = OpenAI("gpt-4-turbo")
+    @classmethod
+    def setUpClass(cls):
+        cls.instruct_backend = OpenAI("gpt-3.5-turbo-instruct")
+        cls.chat_backend = OpenAI("gpt-3.5-turbo")
+        cls.chat_vision_backend = OpenAI("gpt-4-turbo")

     def test_few_shot_qa(self):
-        set_default_backend(self.backend)
+        set_default_backend(self.instruct_backend)
         test_few_shot_qa()

     def test_mt_bench(self):
@@ -41,35 +39,35 @@ class TestOpenAIBackend(unittest.TestCase):
         test_mt_bench()

     def test_select(self):
-        set_default_backend(self.backend)
+        set_default_backend(self.instruct_backend)
         test_select(check_answer=True)

     def test_decode_int(self):
-        set_default_backend(self.backend)
+        set_default_backend(self.instruct_backend)
         test_decode_int()

     def test_decode_json(self):
-        set_default_backend(self.backend)
+        set_default_backend(self.instruct_backend)
         test_decode_json()

     def test_expert_answer(self):
-        set_default_backend(self.backend)
+        set_default_backend(self.instruct_backend)
         test_expert_answer()

     def test_tool_use(self):
-        set_default_backend(self.backend)
+        set_default_backend(self.instruct_backend)
         test_tool_use()

     def test_react(self):
-        set_default_backend(self.backend)
+        set_default_backend(self.instruct_backend)
         test_react()

     def test_parallel_decoding(self):
-        set_default_backend(self.backend)
+        set_default_backend(self.instruct_backend)
         test_parallel_decoding()

     def test_parallel_encoding(self):
-        set_default_backend(self.backend)
+        set_default_backend(self.instruct_backend)
         test_parallel_encoding()

     def test_image_qa(self):
@@ -77,11 +75,11 @@ class TestOpenAIBackend(unittest.TestCase):
         test_image_qa()

     def test_stream(self):
-        set_default_backend(self.backend)
+        set_default_backend(self.instruct_backend)
         test_stream()

     def test_completion_speculative(self):
-        set_default_backend(self.backend)
+        set_default_backend(self.instruct_backend)
         test_completion_speculative()

     def test_chat_completion_speculative(self):
@@ -96,5 +94,5 @@ if __name__ == "__main__":
     # global_config.verbosity = 2
     # t = TestOpenAIBackend()
-    # t.setUp()
-    # t.test_chat_completion_speculative()
+    # t.setUpClass()
+    # t.test_stream()
"""
Usage:
python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000
python3 test_srt_backend.py
"""
import json import json
import unittest import unittest
...@@ -15,8 +9,6 @@ from sglang.test.test_programs import ( ...@@ -15,8 +9,6 @@ from sglang.test.test_programs import (
test_few_shot_qa, test_few_shot_qa,
test_mt_bench, test_mt_bench,
test_parallel_decoding, test_parallel_decoding,
test_parallel_encoding,
test_react,
test_regex, test_regex,
test_select, test_select,
test_stream, test_stream,
...@@ -27,12 +19,14 @@ from sglang.test.test_programs import ( ...@@ -27,12 +19,14 @@ from sglang.test.test_programs import (
class TestSRTBackend(unittest.TestCase): class TestSRTBackend(unittest.TestCase):
backend = None backend = None
def setUp(self): @classmethod
cls = type(self) def setUpClass(cls):
cls.backend = sgl.Runtime(model_path="meta-llama/Meta-Llama-3-8B-Instruct")
sgl.set_default_backend(cls.backend)
if cls.backend is None: @classmethod
cls.backend = sgl.RuntimeEndpoint(base_url="http://localhost:30000") def tearDownClass(cls):
sgl.set_default_backend(cls.backend) cls.backend.shutdown()
def test_few_shot_qa(self): def test_few_shot_qa(self):
test_few_shot_qa() test_few_shot_qa()
...@@ -64,9 +58,6 @@ class TestSRTBackend(unittest.TestCase): ...@@ -64,9 +58,6 @@ class TestSRTBackend(unittest.TestCase):
def test_regex(self): def test_regex(self):
test_regex() test_regex()
# def test_parallel_encoding(self):
# test_parallel_encoding(check_answer=False)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main(warnings="ignore") unittest.main(warnings="ignore")
...@@ -75,5 +66,6 @@ if __name__ == "__main__": ...@@ -75,5 +66,6 @@ if __name__ == "__main__":
# global_config.verbosity = 2 # global_config.verbosity = 2
# t = TestSRTBackend() # t = TestSRTBackend()
# t.setUp() # t.setUpClass()
# t.test_regex() # t.test_few_shot_qa()
# t.tearDownClass()
@@ -17,13 +17,11 @@ class TestVertexAIBackend(unittest.TestCase):
     chat_backend = None
     chat_vision_backend = None

-    def setUp(self):
-        cls = type(self)
-
-        if cls.backend is None:
-            cls.backend = VertexAI("gemini-pro")
-            cls.chat_backend = VertexAI("gemini-pro")
-            cls.chat_vision_backend = VertexAI("gemini-pro-vision")
+    @classmethod
+    def setUpClass(cls):
+        cls.backend = VertexAI("gemini-pro")
+        cls.chat_backend = VertexAI("gemini-pro")
+        cls.chat_vision_backend = VertexAI("gemini-pro-vision")

     def test_few_shot_qa(self):
         set_default_backend(self.backend)
@@ -61,5 +59,5 @@ if __name__ == "__main__":
     # global_config.verbosity = 2
     # t = TestVertexAIBackend()
-    # t.setUp()
+    # t.setUpClass()
     # t.test_stream()