Unverified Commit 54fb1c80 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Clean up unit tests (#1020)

parent b68c4c07
import json
import os
import sys
import unittest
import requests
from sglang.srt.utils import kill_child_process
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, popen_launch_server
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
class TestSRTEndpoint(unittest.TestCase):
class TestSkipTokenizerInit(unittest.TestCase):
@classmethod
def setUpClass(cls):
......@@ -26,9 +21,7 @@ class TestSRTEndpoint(unittest.TestCase):
def tearDownClass(cls):
kill_child_process(cls.process.pid)
def run_decode(
self, return_logprob=False, top_logprobs_num=0, return_text=False, n=1
):
def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1):
response = requests.post(
self.base_url + "/generate",
json={
......@@ -50,7 +43,6 @@ class TestSRTEndpoint(unittest.TestCase):
"stream": False,
"return_logprob": return_logprob,
"top_logprobs_num": top_logprobs_num,
"return_text_in_logprobs": return_text,
"logprob_start_len": 0,
},
)
......@@ -65,13 +57,11 @@ class TestSRTEndpoint(unittest.TestCase):
def test_logprob(self):
for top_logprobs_num in [0, 3]:
for return_text in [False, False]:
self.run_decode(
return_logprob=True,
top_logprobs_num=top_logprobs_num,
return_text=return_text,
)
self.run_decode(
return_logprob=True,
top_logprobs_num=top_logprobs_num,
)
if __name__ == "__main__":
unittest.main(warnings="ignore")
unittest.main()
......@@ -4,7 +4,6 @@ import unittest
import requests
from sglang.srt.utils import kill_child_process
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, popen_launch_server
......@@ -59,4 +58,4 @@ class TestSRTEndpoint(unittest.TestCase):
if __name__ == "__main__":
unittest.main(warnings="ignore")
unittest.main()
......@@ -34,9 +34,4 @@ class TestAccuracy(unittest.TestCase):
if __name__ == "__main__":
unittest.main(warnings="ignore")
# t = TestAccuracy()
# t.setUpClass()
# t.test_mmlu()
# t.tearDownClass()
unittest.main()
......@@ -113,9 +113,4 @@ class TestOpenAIVisionServer(unittest.TestCase):
if __name__ == "__main__":
unittest.main(warnings="ignore")
# t = TestOpenAIVisionServer()
# t.setUpClass()
# t.test_chat_completion()
# t.tearDownClass()
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment