Unverified Commit 995af5a5 authored by Ying Sheng, committed by GitHub

Improve the structure of CI (#911)

parent 53985645
@@ -37,23 +37,12 @@ jobs:
           pip install --upgrade transformers
           pip install accelerate
-      - name: Test Frontend Language with SRT Backend
+      - name: Test Frontend Language
         run: |
           cd test/lang
-          python3 test_srt_backend.py
-      - name: Test OpenAI API Server
+          python3 run_suite.py --suite minimal
+      - name: Test Backend Runtime
         run: |
           cd test/srt
-          python3 test_openai_server.py
-      - name: Test Accuracy
-        run: |
-          cd test/srt
-          python3 test_eval_accuracy.py
-          python3 models/test_causal_models.py
-      - name: Test Frontend Language with OpenAI Backend
-        run: |
-          cd test/lang
-          python3 test_openai_backend.py
\ No newline at end of file
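Note: `run_suite.py` itself is not shown in this diff. A minimal version could be built on the `run_unittest_files` helper added later in this commit; the suite contents, flag names, and the `sglang.test.test_utils` module path below are assumptions, not part of the change.

```python
# Hypothetical sketch of test/lang/run_suite.py -- not part of this diff.
# Only run_unittest_files comes from this commit; the file list, flags,
# and module path are assumed for illustration.
import argparse
import sys

from sglang.test.test_utils import run_unittest_files

SUITES = {
    "minimal": [
        "test_srt_backend.py",
        "test_openai_backend.py",
    ],
}

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--suite", type=str, default="minimal", choices=list(SUITES))
    parser.add_argument("--timeout-per-file", type=float, default=1200)
    args = parser.parse_args()

    # run_unittest_files returns 0 on success and -1 on failure, so it can
    # be wired straight into the process exit code for CI.
    sys.exit(run_unittest_files(SUITES[args.suite], args.timeout_per_file))
```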
## SRT Unit Tests

### Latency Alignment
Make sure your changes do not slow down the following benchmarks:
```
# single gpu
python -m sglang.bench_latency --model-path meta-llama/Llama-2-7b-chat-hf --mem-fraction-static 0.8 --batch 32 --input-len 512 --output-len 256
python -m sglang.bench_latency --model-path meta-llama/Llama-2-7b-chat-hf --mem-fraction-static 0.8 --batch 1 --input-len 512 --output-len 256
# multiple gpu
python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-70B --tp 8 --mem-fraction-static 0.6 --batch 32 --input-len 8192 --output-len 1
python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-70B --tp 8 --mem-fraction-static 0.6 --batch 1 --input-len 8100 --output-len 32
# moe model
python -m sglang.bench_latency --model-path databricks/dbrx-base --tp 8 --mem-fraction-static 0.6 --batch 4 --input-len 1024 --output-len 32
```
### High-level API
Launch a server:
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```
Then run the frontend tests against it:
```
cd test/lang
python3 test_srt_backend.py
```
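Beyond the bundled test file, a minimal hand-written frontend program against the server above works as a smoke test (the prompt and generation parameters are illustrative):

```python
import sglang as sgl

# Connect the frontend to the server launched above (port must match).
sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))


@sgl.function
def few_shot_qa(s, question):
    s += "Q: " + question + "\n"
    s += "A: " + sgl.gen("answer", max_tokens=32, stop="\n")


state = few_shot_qa.run(question="What is the capital of France?")
print(state["answer"])
```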
### Performance
#### MMLU
```
cd benchmark/mmlu
```
Follow the README.md in that directory to download the data.
```
python3 bench_sglang.py --nsub 3
# Expected performance on A10G
# Total latency: 8.200
# Average accuracy: 0.413
```
#### GSM-8K
```
cd benchmark/gsm8k
```
Follow the README.md in that directory to download the data.
```
python3 bench_sglang.py --num-q 200
# Expected performance on A10G
# Latency: 32.103
# Accuracy: 0.250
```
#### More
Please also test `benchmark/hellaswag` and `benchmark/latency_throughput`.
### More Models
#### LLaVA
```
python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000
```
```
cd benchmark/llava_bench
python3 bench_sglang.py
# Expected performance on A10G
# Latency: 50.031
```
## SGLang Unit Tests
```
export ANTHROPIC_API_KEY=
export OPENAI_API_KEY=
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```
```
cd test/lang
python3 run_all.py
```
## OpenAI API Server
```
cd test/srt
python3 test_openai_server.py
```
## Code Formatting
```
pip3 install pre-commit
cd sglang
pre-commit install
pre-commit run --all-files
```
@@ -20,8 +20,10 @@ dependencies = [
 ]

 [project.optional-dependencies]
-srt = ["aiohttp", "fastapi", "hf_transfer", "huggingface_hub", "interegular", "packaging", "pillow",
-       "psutil", "pydantic", "torch", "uvicorn", "uvloop", "zmq", "vllm==0.5.3.post1", "outlines>=0.0.44", "python-multipart", "jsonlines"]
+srt = ["aiohttp", "fastapi", "hf_transfer", "huggingface_hub", "interegular", "jsonlines",
+       "packaging", "pillow", "psutil", "pydantic", "python-multipart",
+       "torch", "uvicorn", "uvloop", "zmq",
+       "vllm==0.5.3.post1", "outlines>=0.0.44"]
 openai = ["openai>=1.0", "tiktoken"]
 anthropic = ["anthropic>=0.20.0"]
 litellm = ["litellm>=1.0.0"]
...
@@ -10,7 +10,6 @@ import time

 from sglang.test.simple_eval_common import (
     ChatCompletionSampler,
-    download_dataset,
     make_report,
     set_ulimit,
 )
@@ -27,14 +26,26 @@ def run_eval(args):
     if args.eval_name == "mmlu":
         from sglang.test.simple_eval_mmlu import MMLUEval

-        dataset_path = "mmlu.csv"
-
-        if not os.path.exists(dataset_path):
-            download_dataset(
-                dataset_path,
-                "https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv",
-            )
-        eval_obj = MMLUEval(dataset_path, args.num_examples, args.num_threads)
+        filename = "https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv"
+        eval_obj = MMLUEval(filename, args.num_examples, args.num_threads)
+    elif args.eval_name == "math":
+        from sglang.test.simple_eval_math import MathEval
+
+        equality_checker = ChatCompletionSampler(model="gpt-4-turbo")
+
+        filename = (
+            "https://openaipublic.blob.core.windows.net/simple-evals/math_test.csv"
+        )
+        eval_obj = MathEval(
+            filename, equality_checker, args.num_examples, args.num_threads
+        )
+    elif args.eval_name == "gpqa":
+        from sglang.test.simple_eval_gpqa import GPQAEval
+
+        filename = (
+            "https://openaipublic.blob.core.windows.net/simple-evals/gpqa_diamond.csv"
+        )
+        eval_obj = GPQAEval(filename, args.num_examples, args.num_threads)
     elif args.eval_name == "humaneval":
         from sglang.test.simple_eval_humaneval import HumanEval
@@ -97,7 +108,7 @@ if __name__ == "__main__":
     )
     parser.add_argument("--eval-name", type=str, default="mmlu")
     parser.add_argument("--num-examples", type=int)
-    parser.add_argument("--num-threads", type=int, default=64)
+    parser.add_argument("--num-threads", type=int, default=512)

     set_ulimit()
     args = parser.parse_args()
...
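One note on the mmlu change above: the explicit `download_dataset` step could be dropped because `pandas.read_csv` accepts an HTTPS URL directly. A quick, network-dependent sanity check:

```python
import pandas

# pandas streams the CSV over HTTPS itself, so no local mmlu.csv is needed.
# This is the same URL the new code passes straight into MMLUEval.
df = pandas.read_csv(
    "https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv"
)
print(len(df))
```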
# Adapted from https://github.com/openai/simple-evals/

"""
GPQA: A Graduate-Level Google-Proof Q&A Benchmark
David Rein, Betty Li Hou, Asa Cooper Stickland, Jackson Petty, Richard Yuanzhe Pang, Julien Dirani, Julian Michael, Samuel R. Bowman
https://arxiv.org/abs/2311.12022
"""

import random
import re

import pandas

from sglang.test import simple_eval_common as common
from sglang.test.simple_eval_common import (
    ANSWER_PATTERN_MULTICHOICE,
    HTML_JINJA,
    Eval,
    EvalResult,
    MessageList,
    SamplerBase,
    SingleEvalResult,
    format_multichoice_question,
)


class GPQAEval(Eval):
    def __init__(
        self,
        filename: str,
        num_examples: int | None,
        num_threads: int,
        n_repeats: int = 1,
    ):
        df = pandas.read_csv(filename)
        examples = [row.to_dict() for _, row in df.iterrows()]
        rng = random.Random(0)
        if num_examples:
            assert n_repeats == 1, "n_repeats only supported for num_examples"
            examples = rng.sample(examples, num_examples)
        examples = examples * n_repeats
        examples = [
            example | {"permutation": rng.sample(range(4), 4)} for example in examples
        ]
        self.examples = examples
        self.n_repeats = n_repeats
        self.num_threads = num_threads

    def __call__(self, sampler: SamplerBase) -> EvalResult:
        def fn(row: dict):
            choices = [
                row["Correct Answer"],
                row["Incorrect Answer 1"],
                row["Incorrect Answer 2"],
                row["Incorrect Answer 3"],
            ]
            choices = [choices[i] for i in row["permutation"]]
            correct_index = choices.index(row["Correct Answer"])
            correct_answer = "ABCD"[correct_index]
            choices_dict = dict(
                A=choices[0],
                B=choices[1],
                C=choices[2],
                D=choices[3],
                Question=row["Question"],
            )
            prompt_messages = [
                sampler._pack_message(
                    content=format_multichoice_question(choices_dict), role="user"
                )
            ]
            response_text = sampler(prompt_messages)
            match = re.search(ANSWER_PATTERN_MULTICHOICE, response_text)
            extracted_answer = match.group(1) if match else None
            score = 1.0 if extracted_answer == correct_answer else 0.0
            html = common.jinja_env.from_string(HTML_JINJA).render(
                prompt_messages=prompt_messages,
                next_message=dict(content=response_text, role="assistant"),
                score=score,
                correct_answer=correct_answer,
                extracted_answer=extracted_answer,
            )
            convo = prompt_messages + [dict(content=response_text, role="assistant")]
            return SingleEvalResult(
                html=html,
                score=score,
                convo=convo,
                metrics={"chars": len(response_text)},
            )

        results = common.map_with_progress(fn, self.examples, self.num_threads)
        return common.aggregate_results(results)
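To make the choice-shuffling logic above concrete, here is a tiny self-contained check; the answer strings are made up, and only the permutation recipe mirrors `GPQAEval`:

```python
import random

# Mirror GPQAEval's shuffling: each example carries a random permutation of
# the four choice slots, and the correct letter is recovered by position.
rng = random.Random(0)
row = {
    "Correct Answer": "4",
    "Incorrect Answer 1": "3",
    "Incorrect Answer 2": "5",
    "Incorrect Answer 3": "6",
    "permutation": rng.sample(range(4), 4),
}
choices = [
    row["Correct Answer"],
    row["Incorrect Answer 1"],
    row["Incorrect Answer 2"],
    row["Incorrect Answer 3"],
]
choices = [choices[i] for i in row["permutation"]]
correct_answer = "ABCD"[choices.index(row["Correct Answer"])]
print(correct_answer)  # the letter the model must produce for score 1.0
```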
# Adapted from https://github.com/openai/simple-evals/

"""
Measuring Mathematical Problem Solving With the MATH Dataset
Dan Hendrycks, Collin Burns, Saurav Kadavath, Akul Arora, Steven Basart, Eric Tang, Dawn Song, Jacob Steinhardt
https://arxiv.org/abs/2103.03874
"""

import random
import re

import pandas

from sglang.test import simple_eval_common as common
from sglang.test.simple_eval_common import (
    ANSWER_PATTERN,
    HTML_JINJA,
    Eval,
    EvalResult,
    SamplerBase,
    SingleEvalResult,
    check_equality,
)

QUERY_TEMPLATE = """
Solve the following math problem step by step. The last line of your response should be of the form Answer: $ANSWER (without quotes) where $ANSWER is the answer to the problem.

{Question}

Remember to put your answer on its own line after "Answer:", and you do not need to use a \\boxed command.
""".strip()


class MathEval(Eval):
    def __init__(
        self,
        filename: str,
        equality_checker: SamplerBase,
        num_examples: int | None,
        num_threads: int,
    ):
        df = pandas.read_csv(filename)
        examples = [row.to_dict() for _, row in df.iterrows()]
        if num_examples:
            examples = random.Random(0).sample(examples, num_examples)
        self.examples = examples
        self.equality_checker = equality_checker
        self.num_threads = num_threads

    def __call__(self, sampler: SamplerBase) -> EvalResult:
        def fn(row: dict):
            prompt_messages = [
                sampler._pack_message(content=QUERY_TEMPLATE.format(**row), role="user")
            ]
            response_text = sampler(prompt_messages)
            match = re.search(ANSWER_PATTERN, response_text)
            extracted_answer = match.group(1) if match else None
            score = float(
                check_equality(self.equality_checker, row["Answer"], extracted_answer)
            )
            html = common.jinja_env.from_string(HTML_JINJA).render(
                prompt_messages=prompt_messages,
                next_message=dict(content=response_text, role="assistant"),
                score=score,
                correct_answer=row["Answer"],
                extracted_answer=extracted_answer,
            )
            convo = prompt_messages + [dict(content=response_text, role="assistant")]
            return SingleEvalResult(html=html, score=score, convo=convo)

        results = common.map_with_progress(fn, self.examples, self.num_threads)
        return common.aggregate_results(results)
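For intuition about the extraction step, a sketch with a stand-in regex; the real `ANSWER_PATTERN` constant lives in `sglang.test.simple_eval_common` and may differ in detail:

```python
import re

# Stand-in for simple_eval_common.ANSWER_PATTERN, assumed for illustration:
# capture whatever follows "Answer:" on the last line of the response.
ANSWER_PATTERN = r"(?i)Answer\s*:\s*([^\n]+)"

response_text = "The roots sum to -b/a, which gives 7.\nAnswer: 7"
match = re.search(ANSWER_PATTERN, response_text)
extracted_answer = match.group(1) if match else None
print(extracted_answer)  # "7"; MathEval then scores it via check_equality
```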
"""Common utilities for testing and benchmarking""" """Common utilities for testing and benchmarking"""
import argparse
import asyncio import asyncio
import multiprocessing
import subprocess import subprocess
import threading
import time import time
import unittest
from functools import partial from functools import partial
from typing import Callable, Optional
import numpy as np import numpy as np
import requests import requests
@@ -247,7 +252,7 @@ async def call_select_lmql(context, choices, temperature=0, max_len=4096, model=
     return choices.index(answer)


-def add_common_other_args_and_parse(parser):
+def add_common_other_args_and_parse(parser: argparse.ArgumentParser):
     parser.add_argument("--parallel", type=int, default=64)
     parser.add_argument("--host", type=str, default="http://127.0.0.1")
     parser.add_argument("--port", type=int, default=None)
@@ -286,7 +291,7 @@ def add_common_other_args_and_parse(parser):
     return args


-def add_common_sglang_args_and_parse(parser):
+def add_common_sglang_args_and_parse(parser: argparse.ArgumentParser):
     parser.add_argument("--parallel", type=int, default=64)
     parser.add_argument("--host", type=str, default="http://127.0.0.1")
     parser.add_argument("--port", type=int, default=30000)
@@ -296,7 +301,7 @@ def add_common_sglang_args_and_parse(parser):
     return args


-def select_sglang_backend(args):
+def select_sglang_backend(args: argparse.Namespace):
     if args.backend.startswith("srt"):
         if args.backend == "srt-no-parallel":
             global_config.enable_parallel_decoding = False
@@ -309,7 +314,7 @@ def select_sglang_backend(args):
     return backend


-def _get_call_generate(args):
+def _get_call_generate(args: argparse.Namespace):
     if args.backend == "lightllm":
         return partial(call_generate_lightllm, url=f"{args.host}:{args.port}/generate")
     elif args.backend == "vllm":
@@ -336,7 +341,7 @@ def _get_call_generate(args):
     raise ValueError(f"Invalid backend: {args.backend}")


-def _get_call_select(args):
+def _get_call_select(args: argparse.Namespace):
     if args.backend == "lightllm":
         return partial(call_select_lightllm, url=f"{args.host}:{args.port}/generate")
     elif args.backend == "vllm":
@@ -359,7 +364,7 @@ def _get_call_select(args):
     raise ValueError(f"Invalid backend: {args.backend}")


-def get_call_generate(args):
+def get_call_generate(args: argparse.Namespace):
     call_generate = _get_call_generate(args)

     def func(*args, **kwargs):
@@ -372,7 +377,7 @@ def get_call_generate(args):
     return func


-def get_call_select(args):
+def get_call_select(args: argparse.Namespace):
     call_select = _get_call_select(args)

     def func(*args, **kwargs):
@@ -385,7 +390,12 @@ def get_call_select(args):
     return func
-def popen_launch_server(model, port, timeout, *args):
+def popen_launch_server(
+    model: str, base_url: str, timeout: float, other_args: tuple = ()
+):
+    _, host, port = base_url.split(":")
+    host = host[2:]
+
     command = [
         "python3",
         "-m",
@@ -393,21 +403,81 @@ def popen_launch_server(model, port, timeout, *args):
         "--model-path",
         model,
         "--host",
-        "localhost",
+        host,
         "--port",
-        str(port),
-        *args,
+        port,
+        *other_args,
     ]
     process = subprocess.Popen(command, stdout=None, stderr=None)
-    base_url = f"http://localhost:{port}/v1"

     start_time = time.time()
     while time.time() - start_time < timeout:
         try:
-            response = requests.get(f"{base_url}/models")
+            response = requests.get(f"{base_url}/v1/models")
             if response.status_code == 200:
                 return process
         except requests.RequestException:
             pass
         time.sleep(10)
     raise TimeoutError("Server failed to start within the timeout period.")
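A usage sketch of the new signature; the model path, timeout, and extra server args are illustrative, not taken from this commit. Note that `base_url` must have the form `http://host:port`, since the function recovers host and port by splitting on `":"`:

```python
# Illustrative only: model path, timeout, and other_args are assumptions.
process = popen_launch_server(
    "meta-llama/Llama-2-7b-chat-hf",
    "http://127.0.0.1:30000",
    timeout=300,
    other_args=("--mem-fraction-static", "0.8"),
)
try:
    pass  # issue test requests against http://127.0.0.1:30000 here
finally:
    process.terminate()
```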
def run_with_timeout(
    func: Callable,
    args: tuple = (),
    kwargs: Optional[dict] = None,
    timeout: Optional[float] = None,
):
    """Run a function with timeout."""
    ret_value = []

    def _target_func():
        ret_value.append(func(*args, **(kwargs or {})))

    t = threading.Thread(target=_target_func)
    t.start()
    t.join(timeout=timeout)

    if t.is_alive():
        raise TimeoutError()

    if not ret_value:
        raise RuntimeError()

    return ret_value[0]
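A quick illustration of the helper's behavior (values are illustrative):

```python
import time


def slow_op():
    time.sleep(1)
    return 42


print(run_with_timeout(slow_op, timeout=5))  # prints 42

try:
    run_with_timeout(slow_op, timeout=0.1)
except TimeoutError:
    # The worker thread cannot be killed and keeps running; only the caller
    # is released. That is why run_unittest_files below pairs this helper
    # with a multiprocessing.Process it can terminate().
    print("timed out")
```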
def run_unittest_files(files: list[str], timeout_per_file: float):
    tic = time.time()
    success = True

    for filename in files:

        def func():
            print(f"\n\nRun {filename}\n\n")
            unittest.main(module=None, argv=["", "-vb"] + [filename])

        # Run each file in its own process so a hung test can be terminated.
        p = multiprocessing.Process(target=func)

        def run_one_file():
            p.start()
            p.join()

        try:
            run_with_timeout(run_one_file, timeout=timeout_per_file)
            if p.exitcode != 0:
                success = False
                break
        except TimeoutError:
            p.terminate()
            time.sleep(5)
            print(
                f"\nTimeout after {timeout_per_file} seconds when running {filename}\n"
            )
            return -1  # non-zero so callers treating this as an exit code see failure

    if success:
        print(f"Success. Time elapsed: {time.time() - tic:.2f}s")
    else:
        print(f"Fail. Time elapsed: {time.time() - tic:.2f}s")

    return 0 if success else -1
@@ -12,6 +12,7 @@ import urllib.request
 from concurrent.futures import ThreadPoolExecutor
 from io import BytesIO
 from json import dumps
+from typing import Union

 import numpy as np
 import requests
@@ -25,7 +26,7 @@ def get_exception_traceback():
     return err_str


-def is_same_type(values):
+def is_same_type(values: list):
     """Return whether the elements in values are of the same type."""
     if len(values) <= 1:
         return True
@@ -45,7 +46,7 @@ def read_jsonl(filename: str):
     return rets


-def dump_state_text(filename, states, mode="w"):
+def dump_state_text(filename: str, states: list, mode: str = "w"):
     """Dump program state in a text file."""
     from sglang.lang.interpreter import ProgramState
@@ -105,7 +106,7 @@ def http_request(
     return HttpResponse(e)


-def encode_image_base64(image_path):
+def encode_image_base64(image_path: Union[str, bytes]):
     """Encode an image in base64."""
     if isinstance(image_path, str):
         with open(image_path, "rb") as image_file:
@@ -144,7 +145,7 @@ def encode_frame(frame):
     return frame_bytes


-def encode_video_base64(video_path, num_frames=16):
+def encode_video_base64(video_path: str, num_frames: int = 16):
     import cv2  # pip install opencv-python-headless

     cap = cv2.VideoCapture(video_path)
@@ -190,7 +191,7 @@ def encode_video_base64(video_path, num_frames=16):
     return video_base64


-def _is_chinese_char(cp):
+def _is_chinese_char(cp: int):
     """Checks whether CP is the codepoint of a CJK character."""
     # This defines a "chinese character" as anything in the CJK Unicode block:
     #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
@@ -215,7 +216,7 @@ def _is_chinese_char(cp):
     return False


-def find_printable_text(text):
+def find_printable_text(text: str):
     """Returns the longest printable substring of text that contains only entire words."""
     # Borrowed from https://github.com/huggingface/transformers/blob/061580c82c2db1de9139528243e105953793f7a2/src/transformers/generation/streamers.py#L99
@@ -234,26 +235,7 @@ def find_printable_text(text):
     return text[: text.rfind(" ") + 1]


-def run_with_timeout(func, args=(), kwargs=None, timeout=None):
-    """Run a function with timeout."""
-    ret_value = []
-
-    def _target_func():
-        ret_value.append(func(*args, **(kwargs or {})))
-
-    t = threading.Thread(target=_target_func)
-    t.start()
-    t.join(timeout=timeout)
-
-    if t.is_alive():
-        raise TimeoutError()
-
-    if not ret_value:
-        raise RuntimeError()
-
-    return ret_value[0]
-
-
-def graceful_registry(sub_module_name):
+def graceful_registry(sub_module_name: str):
     def graceful_shutdown(signum, frame):
         logger.info(
             f"{sub_module_name} Received signal to shutdown. Performing graceful shutdown..."
@@ -265,7 +247,9 @@ def graceful_registry(sub_module_name):


 class LazyImport:
-    def __init__(self, module_name, class_name):
+    """Lazy import to make `import sglang` run faster."""
+
+    def __init__(self, module_name: str, class_name: str):
         self.module_name = module_name
         self.class_name = class_name
         self._module = None
@@ -276,7 +260,7 @@ class LazyImport:
         self._module = getattr(module, self.class_name)
         return self._module

-    def __getattr__(self, name):
+    def __getattr__(self, name: str):
         module = self._load()
         return getattr(module, name)
...
isort python
black python
isort test
black test
isort benchmark
black benchmark