Unverified Commit 72b6ea88 authored by Ying Sheng's avatar Ying Sheng Committed by GitHub
Browse files

Make scripts under `/test/srt` as unit tests (#875)

parent e4d3333c
...@@ -20,8 +20,6 @@ concurrency: ...@@ -20,8 +20,6 @@ concurrency:
jobs: jobs:
unit-test: unit-test:
runs-on: self-hosted runs-on: self-hosted
env:
CUDA_VISIBLE_DEVICES: 6
steps: steps:
- name: Checkout code - name: Checkout code
...@@ -30,6 +28,7 @@ jobs: ...@@ -30,6 +28,7 @@ jobs:
- name: Install dependencies - name: Install dependencies
run: | run: |
cd /data/zhyncs/venv && source ./bin/activate && cd - cd /data/zhyncs/venv && source ./bin/activate && cd -
pip cache purge pip cache purge
pip install --upgrade pip pip install --upgrade pip
pip install -e "python[all]" pip install -e "python[all]"
...@@ -39,6 +38,14 @@ jobs: ...@@ -39,6 +38,14 @@ jobs:
- name: Test OpenAI Backend - name: Test OpenAI Backend
run: | run: |
cd /data/zhyncs/venv && source ./bin/activate && cd - cd /data/zhyncs/venv && source ./bin/activate && cd -
cd test/lang
export OPENAI_API_KEY=${{ secrets.OPENAI_API_KEY }} export OPENAI_API_KEY=${{ secrets.OPENAI_API_KEY }}
cd test/lang
python3 test_openai_backend.py python3 test_openai_backend.py
- name: Test SRT Backend
run: |
cd /data/zhyncs/venv && source ./bin/activate && cd -
cd test/lang
python3 test_srt_backend.py
...@@ -73,6 +73,7 @@ from sglang.srt.utils import ( ...@@ -73,6 +73,7 @@ from sglang.srt.utils import (
assert_pkg_version, assert_pkg_version,
enable_show_time_cost, enable_show_time_cost,
maybe_set_triton_cache_manager, maybe_set_triton_cache_manager,
kill_child_process,
set_ulimit, set_ulimit,
) )
from sglang.utils import get_exception_traceback from sglang.utils import get_exception_traceback
...@@ -467,16 +468,7 @@ class Runtime: ...@@ -467,16 +468,7 @@ class Runtime:
def shutdown(self): def shutdown(self):
if self.pid is not None: if self.pid is not None:
try: kill_child_process(self.pid)
parent = psutil.Process(self.pid)
except psutil.NoSuchProcess:
return
children = parent.children(recursive=True)
for child in children:
child.kill()
psutil.wait_procs(children, timeout=5)
parent.kill()
parent.wait(timeout=5)
self.pid = None self.pid = None
def cache_prefix(self, prefix: str): def cache_prefix(self, prefix: str):
......
...@@ -366,6 +366,26 @@ def kill_parent_process(): ...@@ -366,6 +366,26 @@ def kill_parent_process():
os.kill(parent_process.pid, 9) os.kill(parent_process.pid, 9)
def kill_child_process(pid, including_parent=True):
    """Kill all child processes of *pid* (and optionally *pid* itself).

    Sends SIGKILL via psutil and reaps the killed children so no zombie
    processes are left behind. Processes that have already exited are
    silently ignored.

    Args:
        pid: PID of the parent process whose children should be killed.
        including_parent: When True (default), also kill the parent process.
    """
    try:
        parent = psutil.Process(pid)
    except psutil.NoSuchProcess:
        return

    children = parent.children(recursive=True)
    for child in children:
        try:
            child.kill()
        except psutil.NoSuchProcess:
            pass
    # Wait for the killed children so they are reaped instead of lingering
    # as zombies (matches the previous inline shutdown logic).
    psutil.wait_procs(children, timeout=5)

    if including_parent:
        try:
            parent.kill()
            parent.wait(timeout=5)
        except psutil.NoSuchProcess:
            pass
def monkey_patch_vllm_p2p_access_check(gpu_id: int): def monkey_patch_vllm_p2p_access_check(gpu_id: int):
""" """
Monkey patch the slow p2p access check in vllm. Monkey patch the slow p2p access check in vllm.
......
...@@ -105,15 +105,14 @@ def test_decode_json_regex(): ...@@ -105,15 +105,14 @@ def test_decode_json_regex():
def decode_json(s): def decode_json(s):
from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STRING from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STRING
s += "Generate a JSON object to describe the basic information of a city.\n" s += "Generate a JSON object to describe the basic city information of Paris.\n"
with s.var_scope("json_output"): with s.var_scope("json_output"):
s += "{\n" s += "{\n"
s += ' "name": ' + sgl.gen(regex=REGEX_STRING + ",") + "\n" s += ' "name": ' + sgl.gen(regex=REGEX_STRING + ",") + "\n"
s += ' "population": ' + sgl.gen(regex=REGEX_INT + ",") + "\n" s += ' "population": ' + sgl.gen(regex=REGEX_INT + ",") + "\n"
s += ' "area": ' + sgl.gen(regex=REGEX_INT + ",") + "\n" s += ' "area": ' + sgl.gen(regex=REGEX_INT + ",") + "\n"
s += ' "latitude": ' + sgl.gen(regex=REGEX_FLOAT + ",") + "\n" s += ' "latitude": ' + sgl.gen(regex=REGEX_FLOAT) + "\n"
s += ' "country": ' + sgl.gen(regex=REGEX_STRING) + "\n"
s += "}" s += "}"
ret = decode_json.run(temperature=0.0) ret = decode_json.run(temperature=0.0)
...@@ -129,7 +128,7 @@ def test_decode_json_regex(): ...@@ -129,7 +128,7 @@ def test_decode_json_regex():
def test_decode_json(): def test_decode_json():
@sgl.function @sgl.function
def decode_json(s): def decode_json(s):
s += "Generate a JSON object to describe the basic information of a city.\n" s += "Generate a JSON object to describe the basic city information of Paris.\n"
with s.var_scope("json_output"): with s.var_scope("json_output"):
s += "{\n" s += "{\n"
...@@ -264,6 +263,7 @@ def test_parallel_decoding(): ...@@ -264,6 +263,7 @@ def test_parallel_decoding():
s += "\nIn summary," + sgl.gen("summary", max_tokens=512) s += "\nIn summary," + sgl.gen("summary", max_tokens=512)
ret = parallel_decoding.run(topic="writing a good blog post", temperature=0.3) ret = parallel_decoding.run(topic="writing a good blog post", temperature=0.3)
assert isinstance(ret["summary"], str)
def test_parallel_encoding(check_answer=True): def test_parallel_encoding(check_answer=True):
......
...@@ -21,7 +21,7 @@ class TestSRTBackend(unittest.TestCase): ...@@ -21,7 +21,7 @@ class TestSRTBackend(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls.backend = sgl.Runtime(model_path="meta-llama/Meta-Llama-3-8B-Instruct") cls.backend = sgl.Runtime(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct")
sgl.set_default_backend(cls.backend) sgl.set_default_backend(cls.backend)
@classmethod @classmethod
......
../lang/example_image.png
\ No newline at end of file
"""
First run the following command to launch the server.
Note that TinyLlama adopts different chat templates in different versions.
For v0.4, the chat template is chatml.
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 \
--port 30000 --chat-template chatml
Output example:
The capital of France is Paris.
The capital of the United States is Washington, D.C.
The capital of Canada is Ottawa.
The capital of Japan is Tokyo
"""
import argparse
import json
import openai
def test_completion(args, echo, logprobs):
    """Exercise the /completions endpoint and sanity-check the response.

    Args:
        args: Parsed CLI args; only ``args.base_url`` is used.
        echo: Whether the server should echo the prompt back in the output.
        logprobs: Falsy to disable logprobs, or an int requesting top-k
            logprobs per token.
    """
    client = openai.Client(api_key="EMPTY", base_url=args.base_url)
    response = client.completions.create(
        model="default",
        prompt="The capital of France is",
        temperature=0,
        max_tokens=32,
        echo=echo,
        logprobs=logprobs,
    )
    text = response.choices[0].text
    print(response.choices[0].text)
    if echo:
        assert text.startswith("The capital of France is")
    if logprobs:
        print(response.choices[0].logprobs.top_logprobs)
        assert response.choices[0].logprobs
        if echo:
            # The first echoed token is the prompt's first token, which has
            # no conditional logprob. Use identity comparison with None
            # (PEP 8) rather than ==/!=.
            assert response.choices[0].logprobs.token_logprobs[0] is None
        else:
            assert response.choices[0].logprobs.token_logprobs[0] is not None
    # Basic response-envelope sanity checks.
    assert response.id
    assert response.created
    assert response.usage.prompt_tokens > 0
    assert response.usage.completion_tokens > 0
    assert response.usage.total_tokens > 0
    print("=" * 100)
def test_completion_stream(args, echo, logprobs):
    """Stream completions and validate every chunk's metadata.

    Args:
        args: Parsed CLI args; only ``args.base_url`` is used.
        echo: Whether the server should echo the prompt in the first chunk.
        logprobs: Falsy to disable logprobs, or an int for top-k logprobs.
    """
    client = openai.Client(api_key="EMPTY", base_url=args.base_url)
    stream = client.completions.create(
        model="default",
        prompt="The capital of France is",
        temperature=0,
        max_tokens=32,
        stream=True,
        echo=echo,
        logprobs=logprobs,
    )
    for index, chunk in enumerate(stream):
        choice = chunk.choices[0]
        # Only the first chunk can carry the echoed prompt prefix.
        if index == 0 and echo:
            assert choice.text.startswith("The capital of France is")
        if logprobs:
            print(
                f"{choice.text:12s}\t" f"{choice.logprobs.token_logprobs}",
                flush=True,
            )
            print(choice.logprobs.top_logprobs)
        else:
            print(choice.text, end="", flush=True)
        assert chunk.id
        assert chunk.usage.prompt_tokens > 0
        assert chunk.usage.completion_tokens > 0
        assert chunk.usage.total_tokens > 0
    print("=" * 100)
def test_chat_completion(args):
    """Issue a non-streaming chat completion and check the response fields."""
    client = openai.Client(api_key="EMPTY", base_url=args.base_url)
    conversation = [
        {"role": "system", "content": "You are a helpful AI assistant"},
        {"role": "user", "content": "What is the capital of France?"},
    ]
    response = client.chat.completions.create(
        model="default",
        messages=conversation,
        temperature=0,
        max_tokens=32,
    )
    print(response.choices[0].message.content)
    # Basic response-envelope sanity checks.
    assert response.id
    assert response.created
    usage = response.usage
    assert usage.prompt_tokens > 0
    assert usage.completion_tokens > 0
    assert usage.total_tokens > 0
    print("=" * 100)
def test_chat_completion_image(args):
    """Send a multimodal (text + image URL) chat request and check the reply."""
    client = openai.Client(api_key="EMPTY", base_url=args.base_url)
    # Mixed content: a text instruction plus an image fetched by URL.
    user_content = [
        {"type": "text", "text": "Describe this image"},
        {
            "type": "image_url",
            "image_url": {
                "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/mixtral_8x7b.jpg"
            },
        },
    ]
    response = client.chat.completions.create(
        model="default",
        messages=[
            {"role": "system", "content": "You are a helpful AI assistant"},
            {"role": "user", "content": user_content},
        ],
        temperature=0,
        max_tokens=32,
    )
    print(response.choices[0].message.content)
    # Basic response-envelope sanity checks.
    assert response.id
    assert response.created
    usage = response.usage
    assert usage.prompt_tokens > 0
    assert usage.completion_tokens > 0
    assert usage.total_tokens > 0
    print("=" * 100)
def test_chat_completion_stream(args):
    """Stream a chat completion, checking the role chunk then printing content."""
    client = openai.Client(api_key="EMPTY", base_url=args.base_url)
    stream = client.chat.completions.create(
        model="default",
        messages=[
            {"role": "system", "content": "You are a helpful AI assistant"},
            {"role": "user", "content": "List 3 countries and their capitals."},
        ],
        temperature=0,
        max_tokens=64,
        stream=True,
    )
    for index, chunk in enumerate(stream):
        delta = chunk.choices[0].delta
        if index == 0:
            # The opening chunk only announces the assistant role.
            assert delta.role == "assistant"
            continue
        if delta.content:
            print(delta.content, end="", flush=True)
    print("=" * 100)
def test_regex(args):
    """Constrain chat output with a regex and verify the result parses as JSON."""
    client = openai.Client(api_key="EMPTY", base_url=args.base_url)
    # Force a JSON object with a quoted name and an integer population.
    pattern = (
        r"""\{\n"""
        + r""" "name": "[\w]+",\n"""
        + r""" "population": [\d]+\n"""
        + r"""\}"""
    )
    response = client.chat.completions.create(
        model="default",
        messages=[
            {"role": "system", "content": "You are a helpful AI assistant"},
            {"role": "user", "content": "Introduce the capital of France."},
        ],
        temperature=0,
        max_tokens=128,
        extra_body={"regex": pattern},
    )
    answer = response.choices[0].message.content
    # json.loads raises if the regex-constrained output is not valid JSON.
    print(json.loads(answer))
    print("=" * 100)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--base-url", type=str, default="http://127.0.0.1:30000/v1")
    parser.add_argument(
        "--test-image", action="store_true", help="Enables testing image inputs"
    )
    args = parser.parse_args()

    # Sweep the same echo/logprobs combinations as before, in the same order:
    # logprobs varies in the outer loop, echo in the inner one.
    for logprobs in (False, True, 3):
        for echo in (False, True):
            test_completion(args, echo=echo, logprobs=logprobs)
    for logprobs in (False, True, 3):
        for echo in (False, True):
            test_completion_stream(args, echo=echo, logprobs=logprobs)

    test_chat_completion(args)
    test_chat_completion_stream(args)
    test_regex(args)
    # Image input testing is opt-in; it requires a multimodal model.
    if args.test_image:
        test_chat_completion_image(args)
""" import subprocess
First run the following command to launch the server. import time
Note that TinyLlama adopts different chat templates in different versions. import unittest
For v0.4, the chat template is chatml.
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 \
--port 30000 --chat-template chatml
Output example:
The capital of France is Paris.
The capital of the United States is Washington, D.C.
The capital of Canada is Ottawa.
The capital of Japan is Tokyo
"""
import argparse
import json
import openai import openai
import requests
def test_completion(args, echo, logprobs): from sglang.srt.utils import kill_child_process
client = openai.Client(api_key="EMPTY", base_url=args.base_url)
response = client.completions.create(
model="default", class TestOpenAIServer(unittest.TestCase):
prompt="The capital of France is",
temperature=0, @classmethod
max_tokens=32, def setUpClass(cls):
echo=echo, model = "meta-llama/Meta-Llama-3.1-8B-Instruct"
logprobs=logprobs, port = 30000
) timeout = 300
text = response.choices[0].text
print(response.choices[0].text) command = [
if echo: "python3", "-m", "sglang.launch_server",
assert text.startswith("The capital of France is") "--model-path", model,
if logprobs: "--host", "localhost",
print(response.choices[0].logprobs.top_logprobs) "--port", str(port),
assert response.choices[0].logprobs ]
cls.process = subprocess.Popen(command, stdout=None, stderr=None)
cls.base_url = f"http://localhost:{port}/v1"
cls.model = model
start_time = time.time()
while time.time() - start_time < timeout:
try:
response = requests.get(f"{cls.base_url}/models")
if response.status_code == 200:
return
except requests.RequestException:
pass
time.sleep(10)
raise TimeoutError("Server failed to start within the timeout period.")
@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid)
def run_completion(self, echo, logprobs):
client = openai.Client(api_key="EMPTY", base_url=self.base_url)
prompt = "The capital of France is"
response = client.completions.create(
model=self.model,
prompt=prompt,
temperature=0.1,
max_tokens=32,
echo=echo,
logprobs=logprobs,
)
text = response.choices[0].text
if echo: if echo:
assert response.choices[0].logprobs.token_logprobs[0] == None assert text.startswith(prompt)
else:
assert response.choices[0].logprobs.token_logprobs[0] != None
assert response.id
assert response.created
assert response.usage.prompt_tokens > 0
assert response.usage.completion_tokens > 0
assert response.usage.total_tokens > 0
print("=" * 100)
def test_completion_stream(args, echo, logprobs):
client = openai.Client(api_key="EMPTY", base_url=args.base_url)
response = client.completions.create(
model="default",
prompt="The capital of France is",
temperature=0,
max_tokens=32,
stream=True,
echo=echo,
logprobs=logprobs,
)
first = True
for r in response:
if first:
if echo:
assert r.choices[0].text.startswith("The capital of France is")
first = False
if logprobs: if logprobs:
print( assert response.choices[0].logprobs
f"{r.choices[0].text:12s}\t" f"{r.choices[0].logprobs.token_logprobs}", assert isinstance(response.choices[0].logprobs.tokens[0], str)
flush=True, assert isinstance(response.choices[0].logprobs.top_logprobs[1], dict)
) assert len(response.choices[0].logprobs.top_logprobs[1]) == logprobs
print(r.choices[0].logprobs.top_logprobs) if echo:
else: assert response.choices[0].logprobs.token_logprobs[0] == None
print(r.choices[0].text, end="", flush=True) else:
assert r.id assert response.choices[0].logprobs.token_logprobs[0] != None
assert r.usage.prompt_tokens > 0 assert response.id
assert r.usage.completion_tokens > 0 assert response.created
assert r.usage.total_tokens > 0 assert response.usage.prompt_tokens > 0
print("=" * 100) assert response.usage.completion_tokens > 0
assert response.usage.total_tokens > 0
def test_chat_completion(args): def run_completion_stream(self, echo, logprobs):
client = openai.Client(api_key="EMPTY", base_url=args.base_url) client = openai.Client(api_key="EMPTY", base_url=self.base_url)
response = client.chat.completions.create( prompt = "The capital of France is"
model="default", generator = client.completions.create(
messages=[ model=self.model,
{"role": "system", "content": "You are a helpful AI assistant"}, prompt=prompt,
{"role": "user", "content": "What is the capital of France?"}, temperature=0.1,
], max_tokens=32,
temperature=0, echo=echo,
max_tokens=32, logprobs=logprobs,
) stream=True,
print(response.choices[0].message.content) )
assert response.id
assert response.created first = True
assert response.usage.prompt_tokens > 0 for response in generator:
assert response.usage.completion_tokens > 0 if logprobs:
assert response.usage.total_tokens > 0 assert response.choices[0].logprobs
print("=" * 100) assert isinstance(response.choices[0].logprobs.tokens[0], str)
if not (first and echo):
assert isinstance(response.choices[0].logprobs.top_logprobs[0], dict)
def test_chat_completion_image(args): #assert len(response.choices[0].logprobs.top_logprobs[0]) == logprobs
client = openai.Client(api_key="EMPTY", base_url=args.base_url)
response = client.chat.completions.create( if first:
model="default", if echo:
messages=[ assert response.choices[0].text.startswith(prompt)
{"role": "system", "content": "You are a helpful AI assistant"}, first = False
{
"role": "user", assert response.id
"content": [ assert response.created
{"type": "text", "text": "Describe this image"}, assert response.usage.prompt_tokens > 0
{ assert response.usage.completion_tokens > 0
"type": "image_url", assert response.usage.total_tokens > 0
"image_url": {
"url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/mixtral_8x7b.jpg" def test_completion(self):
}, for echo in [False, True]:
}, for logprobs in [None, 5]:
], self.run_completion(echo, logprobs)
},
], def test_completion_stream(self):
temperature=0, for echo in [True]:
max_tokens=32, for logprobs in [5]:
) self.run_completion_stream(echo, logprobs)
print(response.choices[0].message.content)
assert response.id
assert response.created
assert response.usage.prompt_tokens > 0
assert response.usage.completion_tokens > 0
assert response.usage.total_tokens > 0
print("=" * 100)
def test_chat_completion_stream(args):
client = openai.Client(api_key="EMPTY", base_url=args.base_url)
response = client.chat.completions.create(
model="default",
messages=[
{"role": "system", "content": "You are a helpful AI assistant"},
{"role": "user", "content": "List 3 countries and their capitals."},
],
temperature=0,
max_tokens=64,
stream=True,
)
is_first = True
for chunk in response:
if is_first:
is_first = False
assert chunk.choices[0].delta.role == "assistant"
continue
data = chunk.choices[0].delta
if not data.content:
continue
print(data.content, end="", flush=True)
print("=" * 100)
def test_regex(args):
client = openai.Client(api_key="EMPTY", base_url=args.base_url)
regex = (
r"""\{\n"""
+ r""" "name": "[\w]+",\n"""
+ r""" "population": [\d]+\n"""
+ r"""\}"""
)
response = client.chat.completions.create(
model="default",
messages=[
{"role": "system", "content": "You are a helpful AI assistant"},
{"role": "user", "content": "Introduce the capital of France."},
],
temperature=0,
max_tokens=128,
extra_body={"regex": regex},
)
text = response.choices[0].message.content
print(json.loads(text))
print("=" * 100)
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() # unittest.main(warnings="ignore")
parser.add_argument("--base-url", type=str, default="http://127.0.0.1:30000/v1")
parser.add_argument( t = TestOpenAIServer()
"--test-image", action="store_true", help="Enables testing image inputs" t.setUpClass()
) t.test_completion_stream()
args = parser.parse_args() t.tearDownClass()
test_completion(args, echo=False, logprobs=False)
test_completion(args, echo=True, logprobs=False)
test_completion(args, echo=False, logprobs=True)
test_completion(args, echo=True, logprobs=True)
test_completion(args, echo=False, logprobs=3)
test_completion(args, echo=True, logprobs=3)
test_completion_stream(args, echo=False, logprobs=False)
test_completion_stream(args, echo=True, logprobs=False)
test_completion_stream(args, echo=False, logprobs=True)
test_completion_stream(args, echo=True, logprobs=True)
test_completion_stream(args, echo=False, logprobs=3)
test_completion_stream(args, echo=True, logprobs=3)
test_chat_completion(args)
test_chat_completion_stream(args)
test_regex(args)
if args.test_image:
test_chat_completion_image(args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment