"examples/vscode:/vscode.git/clone" did not exist on "e47cc1fc1a89a5375c322d296cd122fe71ab859f"
Unverified Commit e4d68afc authored by Lianmin Zheng, committed by GitHub

[Minor] Many cleanup (#1357)

parent c9b75917
## Download data
```
bash download_data.sh
```
## Run benchmark
### Benchmark sglang
......
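The exact commands for this section are elided in the diff view; as a rough, illustrative sketch (the script name `bench_sglang.py` and the model path are assumptions based on the usual layout of sglang's benchmark folders, not shown in this commit), the flow is to launch a server and then point the benchmark script at it:
```
# Illustrative only: start a local server, then run the sglang benchmark against it
python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000
python3 bench_sglang.py --num-questions 200 --num-shots 5 --parallel 64
```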
......@@ -10,7 +10,7 @@ import numpy as np
from tqdm import tqdm
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
from sglang.utils import dump_state_text, read_jsonl
from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl
INVALID = -9999999
......@@ -41,24 +41,28 @@ def get_answer_value(answer_str):
def main(args):
lines = read_jsonl(args.data_path)
# Select backend
call_generate = get_call_generate(args)
# Read data
url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
filename = download_and_cache_file(url)
lines = list(read_jsonl(filename))
# Construct prompts
k = args.num_shot
few_shot_examples = get_few_shot_examples(lines, k)
num_questions = args.num_questions
num_shots = args.num_shots
few_shot_examples = get_few_shot_examples(lines, num_shots)
questions = []
labels = []
for i in range(len(lines[: args.num_questions])):
for i in range(len(lines[:num_questions])):
questions.append(get_one_example(lines, i, False))
labels.append(get_answer_value(lines[i]["answer"]))
assert all(l != INVALID for l in labels)
states = [None] * len(labels)
# Select backend
call_generate = get_call_generate(args)
# Run requests
if args.backend != "lmql":
# Use thread pool
......@@ -113,11 +117,13 @@ def main(args):
# Compute accuracy
acc = np.mean(np.array(preds) == np.array(labels))
invalid = np.mean(np.array(preds) == INVALID)
print(f"Latency: {latency:.3f}")
print(f"Invalid: {invalid:.3f}")
# Print results
print(f"Accuracy: {acc:.3f}")
print(f"Invalid: {invalid:.3f}")
print(f"Latency: {latency:.3f} s")
# Write results
# Dump results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(args.result_file, "a") as fout:
......@@ -138,7 +144,7 @@ def main(args):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--num-shot", type=int, default=5)
parser.add_argument("--num-shots", type=int, default=5)
parser.add_argument("--data-path", type=str, default="test.jsonl")
parser.add_argument("--num-questions", type=int, default=200)
args = add_common_other_args_and_parse(parser)
......
......@@ -6,11 +6,12 @@ import time
import numpy as np
from sglang.api import set_default_backend
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
from sglang.utils import dump_state_text, read_jsonl
from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl
INVALID = -9999999
......@@ -41,15 +42,22 @@ def get_answer_value(answer_str):
def main(args):
lines = read_jsonl(args.data_path)
# Select backend
set_default_backend(select_sglang_backend(args))
# Read data
url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
filename = download_and_cache_file(url)
lines = list(read_jsonl(filename))
# Construct prompts
k = args.num_shot
few_shot_examples = get_few_shot_examples(lines, k)
num_questions = args.num_questions
num_shots = args.num_shots
few_shot_examples = get_few_shot_examples(lines, num_shots)
questions = []
labels = []
for i in range(len(lines[: args.num_questions])):
for i in range(len(lines[:num_questions])):
questions.append(get_one_example(lines, i, False))
labels.append(get_answer_value(lines[i]["answer"]))
assert all(l != INVALID for l in labels)
......@@ -72,15 +80,11 @@ def main(args):
########## SGL Program End ##########
#####################################
# Select backend
backend = select_sglang_backend(args)
# Run requests
tic = time.time()
states = few_shot_gsm8k.run_batch(
arguments,
temperature=0,
backend=backend,
num_threads=args.parallel,
progress_bar=True,
)
......@@ -96,11 +100,20 @@ def main(args):
# Compute accuracy
acc = np.mean(np.array(preds) == np.array(labels))
invalid = np.mean(np.array(preds) == INVALID)
print(f"Latency: {latency:.3f}")
print(f"Invalid: {invalid:.3f}")
# Compute speed
num_output_tokens = sum(
s.get_meta_info("answer")["completion_tokens"] for s in states
)
output_throughput = num_output_tokens / latency
# Print results
print(f"Accuracy: {acc:.3f}")
print(f"Invalid: {invalid:.3f}")
print(f"Latency: {latency:.3f} s")
print(f"Output throughput: {output_throughput:.3f} token/s")
# Write results
# Dump results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(args.result_file, "a") as fout:
......@@ -121,7 +134,7 @@ def main(args):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--num-shot", type=int, default=5)
parser.add_argument("--num-shots", type=int, default=5)
parser.add_argument("--data-path", type=str, default="test.jsonl")
parser.add_argument("--num-questions", type=int, default=200)
args = add_common_sglang_args_and_parse(parser)
......
wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/train.jsonl
wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl
\ No newline at end of file
## Download data
```
wget https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl
```
## Run benchmark
### Benchmark sglang
......
......@@ -8,7 +8,7 @@ import numpy as np
from tqdm import tqdm
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_select
from sglang.utils import read_jsonl
from sglang.utils import download_and_cache_file, read_jsonl
def get_one_example(lines, i, include_answer):
......@@ -26,25 +26,29 @@ def get_few_shot_examples(lines, k):
def main(args):
lines = read_jsonl(args.data_path)
# Select backend
call_select = get_call_select(args)
# Read data
url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl"
filename = download_and_cache_file(url)
lines = list(read_jsonl(filename))
# Construct prompts
k = args.num_shot
few_shot_examples = get_few_shot_examples(lines, k)
num_questions = args.num_questions
num_shots = args.num_shots
few_shot_examples = get_few_shot_examples(lines, num_shots)
questions = []
choices = []
labels = []
for i in range(len(lines[: args.num_questions])):
for i in range(len(lines[:num_questions])):
questions.append(get_one_example(lines, i, False))
choices.append(lines[i]["endings"])
labels.append(lines[i]["label"])
preds = [None] * len(labels)
# Select backend
call_select = get_call_select(args)
# Run requests
if args.backend != "lmql":
# Use thread pool
......@@ -65,7 +69,6 @@ def main(args):
total=len(questions),
)
)
else:
# Use asyncio
async def batched_call(batch_size):
......@@ -108,7 +111,7 @@ def main(args):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--num-shot", type=int, default=20)
parser.add_argument("--num-shots", type=int, default=20)
parser.add_argument("--data-path", type=str, default="hellaswag_val.jsonl")
parser.add_argument("--num-questions", type=int, default=200)
args = add_common_other_args_and_parse(parser)
......
......@@ -4,11 +4,12 @@ import time
import numpy as np
from sglang.api import set_default_backend
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
from sglang.utils import read_jsonl
from sglang.utils import download_and_cache_file, read_jsonl
def get_one_example(lines, i, include_answer):
......@@ -26,16 +27,23 @@ def get_few_shot_examples(lines, k):
def main(args):
lines = read_jsonl(args.data_path)
# Select backend
set_default_backend(select_sglang_backend(args))
# Read data
url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl"
filename = download_and_cache_file(url)
lines = list(read_jsonl(filename))
# Construct prompts
k = args.num_shot
few_shot_examples = get_few_shot_examples(lines, k)
num_questions = args.num_questions
num_shots = args.num_shots
few_shot_examples = get_few_shot_examples(lines, num_shots)
questions = []
choices = []
labels = []
for i in range(len(lines[: args.num_questions])):
for i in range(len(lines[:num_questions])):
questions.append(get_one_example(lines, i, False))
choices.append(lines[i]["endings"])
labels.append(lines[i]["label"])
......@@ -56,15 +64,11 @@ def main(args):
########## SGL Program End ##########
#####################################
# Select backend
backend = select_sglang_backend(args)
# Run requests
tic = time.time()
rets = few_shot_hellaswag.run_batch(
arguments,
temperature=0,
backend=backend,
num_threads=args.parallel,
progress_bar=True,
)
......@@ -95,7 +99,7 @@ def main(args):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--num-shot", type=int, default=20)
parser.add_argument("--num-shots", type=int, default=20)
parser.add_argument("--data-path", type=str, default="hellaswag_val.jsonl")
parser.add_argument("--num-questions", type=int, default=200)
args = add_common_sglang_args_and_parse(parser)
......
......@@ -7,6 +7,7 @@ python3 srt_example_llava_v.py
import argparse
import csv
import json
import os
import time
......@@ -223,7 +224,7 @@ if __name__ == "__main__":
tokenizer_path=tokenizer_path,
port=cur_port,
additional_ports=[cur_port + 1, cur_port + 2, cur_port + 3, cur_port + 4],
model_override_args=model_override_args,
json_model_override_args=json.dumps(model_override_args),
tp_size=1,
)
sgl.set_default_backend(runtime)
......
......@@ -298,34 +298,41 @@ class BenchmarkMetrics:
median_e2e_latency_ms: float
default_sharegpt_path = "ShareGPT_V3_unfiltered_cleaned_split.json"
SHAREGPT_URL = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
def download_sharegpt_dataset(path):
url = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
def download_and_cache_file(url: str, filename: Optional[str] = None):
"""Read and cache a file from a url."""
if filename is None:
filename = os.path.join("/tmp", url.split("/")[-1])
print(f"Downloading dataset from {url}")
try:
response = requests.get(url, stream=True)
response.raise_for_status()
# Check if the cache file already exists
if os.path.exists(filename):
return filename
print(f"Downloading from {url} to {filename}")
total_size = int(response.headers.get("content-length", 0))
block_size = 8192
# Stream the response to show the progress bar
response = requests.get(url, stream=True)
response.raise_for_status() # Check for request errors
with open(path, "wb") as f, tqdm(
desc="Downloading",
total=total_size,
unit="iB",
unit_scale=True,
unit_divisor=1024,
) as progress_bar:
for data in response.iter_content(block_size):
size = f.write(data)
progress_bar.update(size)
# Total size of the file in bytes
total_size = int(response.headers.get("content-length", 0))
chunk_size = 1024 # Download in chunks of 1KB
print(f"Dataset downloaded and saved to {path}")
except requests.RequestException as e:
raise Exception(f"Failed to download dataset: {e}")
# Use tqdm to display the progress bar
with open(filename, "wb") as f, tqdm(
desc=filename,
total=total_size,
unit="B",
unit_scale=True,
unit_divisor=1024,
) as bar:
for chunk in response.iter_content(chunk_size=chunk_size):
f.write(chunk)
bar.update(len(chunk))
return filename
def sample_sharegpt_requests(
......@@ -338,13 +345,8 @@ def sample_sharegpt_requests(
raise ValueError("output_len too small")
# Download sharegpt if necessary
if not os.path.isfile(dataset_path) and not os.path.isfile(default_sharegpt_path):
download_sharegpt_dataset(default_sharegpt_path)
dataset_path = default_sharegpt_path
else:
dataset_path = (
dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path
)
if not os.path.isfile(dataset_path):
dataset_path = download_and_cache_file(SHAREGPT_URL)
# Load the dataset.
with open(dataset_path) as f:
......@@ -412,15 +414,8 @@ def sample_random_requests(
# Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens
# Download sharegpt if necessary
if not os.path.isfile(dataset_path) and not os.path.isfile(
default_sharegpt_path
):
download_sharegpt_dataset(default_sharegpt_path)
dataset_path = default_sharegpt_path
else:
dataset_path = (
dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path
)
if not os.path.isfile(dataset_path):
dataset_path = download_and_cache_file(SHAREGPT_URL)
# Load the dataset.
with open(dataset_path) as f:
......
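For reference, a minimal usage sketch of the new `download_and_cache_file` helper introduced above (variable names are illustrative; the URL constant and the `open()` pattern come from the hunks in this file, and the dataset is assumed to be plain JSON, as ShareGPT is):
```
import json

# Cache the ShareGPT dataset under /tmp, keyed by the URL's basename,
# then load it the same way sample_sharegpt_requests does above.
dataset_path = download_and_cache_file(SHAREGPT_URL)
with open(dataset_path) as f:
    dataset = json.load(f)
```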
......@@ -9,10 +9,9 @@ from sglang.srt.utils import kill_child_process
if __name__ == "__main__":
server_args = prepare_server_args(sys.argv[1:])
model_override_args = server_args.json_model_override_args
try:
launch_server(server_args, model_override_args=model_override_args)
launch_server(server_args)
except Exception as e:
raise e
finally:
......
"""Launch the inference server for Llava-video model."""
import json
import sys
from sglang.srt.server import launch_server, prepare_server_args
......@@ -19,5 +20,6 @@ if __name__ == "__main__":
model_override_args["model_max_length"] = 4096 * 2
if "34b" in server_args.model_path.lower():
model_override_args["image_token_index"] = 64002
server_args.json_model_override_args = json.dumps(model_override_args)
launch_server(server_args, model_override_args, None)
launch_server(server_args)
......@@ -16,6 +16,7 @@ limitations under the License.
"""Cache for the compressed finite state machine."""
from outlines.fsm.json_schema import build_regex_from_schema
from transformers import AutoTokenizer
from sglang.srt.constrained import RegexGuide, TransformerTokenizer
from sglang.srt.constrained.base_tool_cache import BaseToolCache
......@@ -28,12 +29,9 @@ class FSMCache(BaseToolCache):
tokenizer_args_dict,
enable=True,
skip_tokenizer_init=False,
json_schema_mode=False,
):
super().__init__(enable=enable)
self.json_schema_mode = json_schema_mode
if (
skip_tokenizer_init
or tokenizer_path.endswith(".json")
......@@ -42,44 +40,37 @@ class FSMCache(BaseToolCache):
# Do not support TiktokenTokenizer or SentencePieceTokenizer
return
from importlib.metadata import version
tokenizer_args_dict.setdefault("padding_side", "left")
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_args_dict)
try:
self.outlines_tokenizer = TransformerTokenizer(tokenizer)
except AttributeError:
# FIXME: tmp fix for chatglm2 & chatglm3 (pad_token_id=0)
origin_pad_token_id = tokenizer.pad_token_id
if version("outlines") >= "0.0.35":
from transformers import AutoTokenizer
def fset(self, value):
self._value = value
tokenizer_args_dict.setdefault("padding_side", "left")
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_path, **tokenizer_args_dict
type(tokenizer).pad_token_id = property(
fget=type(tokenizer).pad_token_id.fget, fset=fset
)
try:
self.outlines_tokenizer = TransformerTokenizer(tokenizer)
except AttributeError:
# FIXME: tmp fix for chatglm2 & chatglm3 (pad_token_id=0)
origin_pad_token_id = tokenizer.pad_token_id
def fset(self, value):
self._value = value
type(tokenizer).pad_token_id = property(
fget=type(tokenizer).pad_token_id.fget, fset=fset
)
self.outlines_tokenizer = TransformerTokenizer(tokenizer)
self.outlines_tokenizer.tokenizer.pad_token_id = origin_pad_token_id
self.outlines_tokenizer.pad_token_id = origin_pad_token_id
self.outlines_tokenizer.pad_token = (
self.outlines_tokenizer.tokenizer.pad_token
)
self.outlines_tokenizer.vocabulary = (
self.outlines_tokenizer.tokenizer.get_vocab()
)
else:
self.outlines_tokenizer = TransformerTokenizer(
tokenizer_path, **tokenizer_args_dict
self.outlines_tokenizer = TransformerTokenizer(tokenizer)
self.outlines_tokenizer.tokenizer.pad_token_id = origin_pad_token_id
self.outlines_tokenizer.pad_token_id = origin_pad_token_id
self.outlines_tokenizer.pad_token = (
self.outlines_tokenizer.tokenizer.pad_token
)
self.outlines_tokenizer.vocabulary = (
self.outlines_tokenizer.tokenizer.get_vocab()
)
def init_value(self, value):
if self.json_schema_mode:
regex = build_regex_from_schema(value, whitespace_pattern=r"[\n\t ]*")
return RegexGuide(regex, self.outlines_tokenizer), regex
def init_value(self, key):
key_type, key_string = key
if key_type == "json":
regex = build_regex_from_schema(key_string, whitespace_pattern=r"[\n\t ]*")
elif key_type == "regex":
regex = key_string
else:
return RegexGuide(value, self.outlines_tokenizer)
raise ValueError(f"Invalid key_type: {key_type}")
return RegexGuide(regex, self.outlines_tokenizer), regex
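A short sketch of how callers query the reworked cache after this change (mirroring the tp_worker hunks further down; variable names are illustrative): the separate `json_schema_mode` flag is gone, and the key itself now carries a type tag.
```
# The cache key is now a (key_type, key_string) tuple; both variants return
# a (RegexGuide, computed_regex_string) pair, so the jump-forward map can be
# built from the computed regex in either case.
regex_fsm, computed_regex = regex_fsm_cache.query(("json", json_schema_string))
regex_fsm, computed_regex = regex_fsm_cache.query(("regex", raw_regex_string))
```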
......@@ -71,12 +71,10 @@ class ControllerMulti:
self,
server_args: ServerArgs,
port_args: PortArgs,
model_override_args,
):
# Parse args
self.server_args = server_args
self.port_args = port_args
self.model_override_args = model_override_args
self.load_balance_method = LoadBalanceMethod.from_str(
server_args.load_balance_method
)
......@@ -114,7 +112,6 @@ class ControllerMulti:
self.server_args,
self.port_args,
pipe_controller_writer,
self.model_override_args,
True,
gpu_ids,
dp_worker_id,
......@@ -189,14 +186,13 @@ def start_controller_process(
server_args: ServerArgs,
port_args: PortArgs,
pipe_writer,
model_override_args: dict,
):
"""Start a controller process."""
configure_logger(server_args)
try:
controller = ControllerMulti(server_args, port_args, model_override_args)
controller = ControllerMulti(server_args, port_args)
except Exception:
pipe_writer.send(get_exception_traceback())
raise
......
......@@ -40,7 +40,6 @@ class ControllerSingle:
self,
server_args: ServerArgs,
port_args: PortArgs,
model_override_args: dict,
gpu_ids: List[int],
is_data_parallel_worker: bool,
dp_worker_id: int,
......@@ -76,7 +75,6 @@ class ControllerSingle:
tp_rank_range,
server_args,
port_args.nccl_ports[dp_worker_id],
model_override_args,
)
# Launch tp rank 0
......@@ -85,7 +83,6 @@ class ControllerSingle:
0,
server_args,
port_args.nccl_ports[dp_worker_id],
model_override_args,
)
self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group
......@@ -126,7 +123,6 @@ def start_controller_process(
server_args: ServerArgs,
port_args: PortArgs,
pipe_writer: multiprocessing.connection.Connection,
model_override_args: dict,
is_data_parallel_worker: bool = False,
gpu_ids: List[int] = None,
dp_worker_id: int = None,
......@@ -149,7 +145,6 @@ def start_controller_process(
controller = ControllerSingle(
server_args,
port_args,
model_override_args,
gpu_ids,
is_data_parallel_worker,
dp_worker_id,
......
......@@ -18,6 +18,7 @@ limitations under the License.
import asyncio
import concurrent.futures
import dataclasses
import json
import logging
import multiprocessing as mp
import os
......@@ -77,7 +78,6 @@ class TokenizerManager:
self,
server_args: ServerArgs,
port_args: PortArgs,
model_override_args: dict = None,
):
self.server_args = server_args
......@@ -95,7 +95,7 @@ class TokenizerManager:
self.hf_config = get_config(
self.model_path,
trust_remote_code=server_args.trust_remote_code,
model_override_args=model_override_args,
model_override_args=json.loads(server_args.json_model_override_args),
)
self.is_generation = is_generation_model(
self.hf_config.architectures, self.server_args.is_embedding
......
......@@ -15,13 +15,14 @@ limitations under the License.
"""A tensor parallel worker."""
import json
import logging
import multiprocessing
import os
import pickle
import time
import warnings
from typing import Any, List, Optional, Union
from typing import Any, List, Optional
import torch
import torch.distributed
......@@ -66,6 +67,7 @@ from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__)
# Crash on warning if we are running CI tests
crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
......@@ -76,11 +78,10 @@ class ModelTpServer:
tp_rank: int,
server_args: ServerArgs,
nccl_port: int,
model_override_args: dict,
):
suppress_other_loggers()
# Copy arguments
# Parse arguments
self.gpu_id = gpu_id
self.tp_rank = tp_rank
self.tp_size = server_args.tp_size
......@@ -93,9 +94,8 @@ class ModelTpServer:
server_args.model_path,
server_args.trust_remote_code,
context_length=server_args.context_length,
model_override_args=model_override_args,
model_override_args=json.loads(server_args.json_model_override_args),
)
self.model_runner = ModelRunner(
model_config=self.model_config,
mem_fraction_static=server_args.mem_fraction_static,
......@@ -136,7 +136,7 @@ class ModelTpServer:
self.max_total_num_tokens - 1,
)
# Sync random seed
# Sync random seed across TP workers
server_args.random_seed = broadcast_recv_input(
[server_args.random_seed],
self.tp_rank,
......@@ -144,7 +144,7 @@ class ModelTpServer:
)[0]
set_random_seed(server_args.random_seed)
# Print info
# Print debug info
logger.info(
f"max_total_num_tokens={self.max_total_num_tokens}, "
f"max_prefill_tokens={self.max_prefill_tokens}, "
......@@ -181,7 +181,7 @@ class ModelTpServer:
self.num_generated_tokens = 0
self.last_stats_tic = time.time()
# Chunked prefill
# Init chunked prefill
self.chunked_prefill_size = server_args.chunked_prefill_size
self.current_inflight_req = None
self.is_mixed_chunk = (
......@@ -197,16 +197,6 @@ class ModelTpServer:
"trust_remote_code": server_args.trust_remote_code,
},
skip_tokenizer_init=server_args.skip_tokenizer_init,
json_schema_mode=False,
)
self.json_fsm_cache = FSMCache(
server_args.tokenizer_path,
{
"tokenizer_mode": server_args.tokenizer_mode,
"trust_remote_code": server_args.trust_remote_code,
},
skip_tokenizer_init=server_args.skip_tokenizer_init,
json_schema_mode=True,
)
self.jump_forward_cache = JumpForwardCache()
......@@ -227,11 +217,12 @@ class ModelTpServer:
try:
# Recv requests
for recv_req in recv_reqs:
if isinstance(
recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
):
if isinstance(recv_req, TokenizedGenerateReqInput):
self.handle_generate_request(recv_req)
self.do_not_get_new_batch = False
elif isinstance(recv_req, TokenizedEmbeddingReqInput):
self.handle_embedding_request(recv_req)
self.do_not_get_new_batch = False
elif isinstance(recv_req, FlushCacheReq):
self.flush_cache()
elif isinstance(recv_req, AbortReq):
......@@ -331,57 +322,56 @@ class ModelTpServer:
def handle_generate_request(
self,
recv_req: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
recv_req: TokenizedGenerateReqInput,
):
req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
req.tokenizer = self.tokenizer
req.sampling_params = recv_req.sampling_params
if self.model_runner.is_generation:
req.pixel_values = recv_req.pixel_values
if req.pixel_values is not None:
# Use image hash as fake token_ids, which is then used
# for prefix matching
image_hash = hash(tuple(recv_req.image_hashes))
req.pad_value = [
(image_hash) % self.model_config.vocab_size,
(image_hash >> 16) % self.model_config.vocab_size,
(image_hash >> 32) % self.model_config.vocab_size,
(image_hash >> 64) % self.model_config.vocab_size,
]
req.image_sizes = recv_req.image_sizes
(
req.origin_input_ids,
req.image_offsets,
) = self.model_runner.model.pad_input_ids(
req.origin_input_ids_unpadded,
req.pad_value,
req.pixel_values,
req.image_sizes,
)
# Only when pixel values is not None we have modalities
req.modalities = recv_req.modalites
req.return_logprob = recv_req.return_logprob
req.logprob_start_len = recv_req.logprob_start_len
req.top_logprobs_num = recv_req.top_logprobs_num
req.stream = recv_req.stream
# Init regex fsm fron json
req.pixel_values = recv_req.pixel_values
if req.pixel_values is not None:
# Use image hash as fake token_ids, which is then used
# for prefix matching
image_hash = hash(tuple(recv_req.image_hashes))
req.pad_value = [
(image_hash) % self.model_config.vocab_size,
(image_hash >> 16) % self.model_config.vocab_size,
(image_hash >> 32) % self.model_config.vocab_size,
(image_hash >> 64) % self.model_config.vocab_size,
]
req.image_sizes = recv_req.image_sizes
(
req.origin_input_ids,
req.image_offsets,
) = self.model_runner.model.pad_input_ids(
req.origin_input_ids_unpadded,
req.pad_value,
req.pixel_values,
req.image_sizes,
)
# Only when pixel values is not None we have modalities
req.modalities = recv_req.modalites
req.return_logprob = recv_req.return_logprob
req.logprob_start_len = recv_req.logprob_start_len
req.top_logprobs_num = recv_req.top_logprobs_num
req.stream = recv_req.stream
# Init regex FSM
if (
req.sampling_params.json_schema is not None
or req.sampling_params.regex is not None
):
if req.sampling_params.json_schema is not None:
req.regex_fsm, computed_regex_string = self.json_fsm_cache.query(
req.sampling_params.json_schema
req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query(
("json", req.sampling_params.json_schema)
)
if not self.disable_regex_jump_forward:
req.jump_forward_map = self.jump_forward_cache.query(
computed_regex_string
)
# Init regex fsm
elif req.sampling_params.regex is not None:
req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
if not self.disable_regex_jump_forward:
req.jump_forward_map = self.jump_forward_cache.query(
req.sampling_params.regex
)
req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query(
("regex", req.sampling_params.regex)
)
if not self.disable_regex_jump_forward:
req.jump_forward_map = self.jump_forward_cache.query(
computed_regex_string
)
# Truncate prompts that are too long
if len(req.origin_input_ids) >= self.max_req_input_len:
......@@ -390,16 +380,32 @@ class ModelTpServer:
"the max context length. Truncated!!!"
)
req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
req.sampling_params.max_new_tokens = min(
(
req.sampling_params.max_new_tokens
if req.sampling_params.max_new_tokens is not None
else 1 << 30
),
self.max_req_input_len - 1 - len(req.origin_input_ids),
)
if self.model_runner.is_generation:
req.sampling_params.max_new_tokens = min(
(
req.sampling_params.max_new_tokens
if req.sampling_params.max_new_tokens is not None
else 1 << 30
),
self.max_req_input_len - 1 - len(req.origin_input_ids),
self.waiting_queue.append(req)
def handle_embedding_request(
self,
recv_req: TokenizedEmbeddingReqInput,
):
req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
req.tokenizer = self.tokenizer
req.sampling_params = recv_req.sampling_params
# Truncate prompts that are too long
if len(req.origin_input_ids) >= self.max_req_input_len:
logger.warn(
"Request length is longer than the KV cache pool size or "
"the max context length. Truncated!!!"
)
req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
self.waiting_queue.append(req)
......@@ -892,7 +898,6 @@ def run_tp_server(
tp_rank: int,
server_args: ServerArgs,
nccl_port: int,
model_override_args: dict,
):
"""Run a tensor parallel model server."""
configure_logger(server_args, prefix=f" TP{tp_rank}")
......@@ -903,7 +908,6 @@ def run_tp_server(
tp_rank,
server_args,
nccl_port,
model_override_args,
)
tp_cpu_group = model_server.model_runner.tp_group.cpu_group
......@@ -920,14 +924,13 @@ def launch_tp_servers(
tp_rank_range: List[int],
server_args: ServerArgs,
nccl_port: int,
model_override_args: dict,
):
"""Launch multiple tensor parallel servers."""
procs = []
for i in tp_rank_range:
proc = multiprocessing.Process(
target=run_tp_server,
args=(gpu_ids[i], i, server_args, nccl_port, model_override_args),
args=(gpu_ids[i], i, server_args, nccl_port),
)
proc.start()
procs.append(proc)
......
......@@ -18,6 +18,7 @@ limitations under the License.
import gc
import importlib
import importlib.resources
import json
import logging
import pkgutil
from functools import lru_cache
......
......@@ -272,7 +272,6 @@ async def retrieve_file_content(file_id: str):
def launch_server(
server_args: ServerArgs,
model_override_args: Optional[dict] = None,
pipe_finish_writer: Optional[mp.connection.Connection] = None,
):
"""Launch an HTTP server."""
......@@ -317,7 +316,6 @@ def launch_server(
tp_rank_range,
server_args,
ports[3],
model_override_args,
)
try:
......@@ -328,7 +326,7 @@ def launch_server(
return
# Launch processes
tokenizer_manager = TokenizerManager(server_args, port_args, model_override_args)
tokenizer_manager = TokenizerManager(server_args, port_args)
if server_args.chat_template:
load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False)
......@@ -341,7 +339,7 @@ def launch_server(
proc_controller = mp.Process(
target=start_controller_process,
args=(server_args, port_args, pipe_controller_writer, model_override_args),
args=(server_args, port_args, pipe_controller_writer),
)
proc_controller.start()
......@@ -501,7 +499,6 @@ class Runtime:
def __init__(
self,
log_level: str = "error",
model_override_args: Optional[dict] = None,
*args,
**kwargs,
):
......@@ -525,7 +522,7 @@ class Runtime:
proc = mp.Process(
target=launch_server,
args=(self.server_args, model_override_args, pipe_writer),
args=(self.server_args, pipe_writer),
)
proc.start()
pipe_writer.close()
......
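With `model_override_args` dropped from `launch_server` and `Runtime`, overrides now travel on the server args as a JSON string. A minimal sketch of the updated calling convention (model path and override values are illustrative; the pattern follows srt_example_llava_v.py earlier in this commit):
```
import json
import sglang as sgl

# Model overrides are encoded as a JSON string on the server args,
# not passed as a separate dict argument.
runtime = sgl.Runtime(
    model_path="lmms-lab/LLaVA-NeXT-Video-7B",  # illustrative
    json_model_override_args=json.dumps({"model_max_length": 8192}),
)
sgl.set_default_backend(runtime)
```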
......@@ -76,6 +76,14 @@ class ServerArgs:
dp_size: int = 1
load_balance_method: str = "round_robin"
# Distributed args
nccl_init_addr: Optional[str] = None
nnodes: int = 1
node_rank: Optional[int] = None
# Model override args in JSON
json_model_override_args: str = "{}"
# Optimization/debug options
disable_flashinfer: bool = False
disable_flashinfer_sampling: bool = False
......@@ -91,14 +99,6 @@ class ServerArgs:
enable_mla: bool = False
triton_attention_reduce_in_fp32: bool = False
# Distributed args
nccl_init_addr: Optional[str] = None
nnodes: int = 1
node_rank: Optional[int] = None
# Model override args in JSON
json_model_override_args: Optional[dict] = None
def __post_init__(self):
if self.tokenizer_path is None:
self.tokenizer_path = self.model_path
......@@ -385,6 +385,14 @@ class ServerArgs:
)
parser.add_argument("--node-rank", type=int, help="The node rank.")
# Model override args
parser.add_argument(
"--json-model-override-args",
type=str,
help="A dictionary in JSON string format used to override default model configurations.",
default=ServerArgs.json_model_override_args,
)
# Optimization/debug options
parser.add_argument(
"--disable-flashinfer",
......@@ -459,22 +467,10 @@ class ServerArgs:
help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
)
# Model override args
parser.add_argument(
"--json-model-override-args",
type=str,
help="A dictionary in JSON string format used to override default model configurations.",
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
args.tp_size = args.tensor_parallel_size
args.dp_size = args.data_parallel_size
args.json_model_override_args = (
json.loads(args.json_model_override_args)
if args.json_model_override_args
else None
)
attrs = [attr.name for attr in dataclasses.fields(cls)]
return cls(**{attr: getattr(args, attr) for attr in attrs})
......@@ -498,7 +494,7 @@ class ServerArgs:
self.disable_flashinfer = False
def prepare_server_args(args: argparse.Namespace) -> ServerArgs:
def prepare_server_args(argv: List[str]) -> ServerArgs:
"""
Prepare the server arguments from the command line arguments.
......@@ -511,7 +507,7 @@ def prepare_server_args(args: argparse.Namespace) -> ServerArgs:
"""
parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
raw_args = parser.parse_args(args)
raw_args = parser.parse_args(argv)
server_args = ServerArgs.from_cli_args(raw_args)
return server_args
......
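On the command line, the same override is passed as a plain JSON string and only parsed with `json.loads` where it is consumed (tokenizer_manager and tp_worker above). For example (model path illustrative):
```
python3 -m sglang.launch_server \
  --model-path lmms-lab/LLaVA-NeXT-Video-7B \
  --json-model-override-args '{"model_max_length": 8192}'
```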
"""
Run few-shot GSM-8K evaluation.
Usage:
python3 -m sglang.test.few_shot_gsm8k --num-questions 200
"""
import argparse
import ast
import re
import time
import numpy as np
from sglang.api import set_default_backend
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl
INVALID = -9999999
def get_one_example(lines, i, include_answer):
ret = "Question: " + lines[i]["question"] + "\nAnswer:"
if include_answer:
ret += " " + lines[i]["answer"]
return ret
def get_few_shot_examples(lines, k):
ret = ""
for i in range(k):
ret += get_one_example(lines, i, True) + "\n\n"
return ret
def get_answer_value(answer_str):
answer_str = answer_str.replace(",", "")
numbers = re.findall(r"\d+", answer_str)
if len(numbers) < 1:
return INVALID
try:
return ast.literal_eval(numbers[-1])
except SyntaxError:
return INVALID
def main(args):
# Select backend
set_default_backend(RuntimeEndpoint(f"{args.host}:{args.port}"))
# Read data
url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
filename = download_and_cache_file(url)
lines = list(read_jsonl(filename))
# Construct prompts
num_questions = args.num_questions
num_shots = args.num_shots
few_shot_examples = get_few_shot_examples(lines, num_shots)
questions = []
labels = []
for i in range(len(lines[:num_questions])):
questions.append(get_one_example(lines, i, False))
labels.append(get_answer_value(lines[i]["answer"]))
assert all(l != INVALID for l in labels)
arguments = [{"question": q} for q in questions]
#####################################
######### SGL Program Begin #########
#####################################
import sglang as sgl
@sgl.function
def few_shot_gsm8k(s, question):
s += few_shot_examples + question
s += sgl.gen(
"answer", max_tokens=512, stop=["Question", "Assistant:", "<|separator|>"]
)
#####################################
########## SGL Program End ##########
#####################################
# Run requests
tic = time.time()
states = few_shot_gsm8k.run_batch(
arguments,
temperature=0,
num_threads=args.parallel,
progress_bar=True,
)
latency = time.time() - tic
preds = []
for i in range(len(states)):
preds.append(get_answer_value(states[i]["answer"]))
# print(f"{preds=}")
# print(f"{labels=}")
# Compute accuracy
acc = np.mean(np.array(preds) == np.array(labels))
invalid = np.mean(np.array(preds) == INVALID)
# Compute speed
num_output_tokens = sum(
s.get_meta_info("answer")["completion_tokens"] for s in states
)
output_throughput = num_output_tokens / latency
# Print results
print(f"Accuracy: {acc:.3f}")
print(f"Invalid: {invalid:.3f}")
print(f"Latency: {latency:.3f} s")
print(f"Output throughput: {output_throughput:.3f} token/s")
# Dump results
dump_state_text("tmp_output_gsm8k.txt", states)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--num-shots", type=int, default=5)
parser.add_argument("--data-path", type=str, default="test.jsonl")
parser.add_argument("--num-questions", type=int, default=200)
parser.add_argument("--parallel", type=int, default=128)
parser.add_argument("--host", type=str, default="http://127.0.0.1")
parser.add_argument("--port", type=int, default=30000)
args = parser.parse_args()
main(args)
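As the module docstring notes, this eval talks to an already-running server at `--host`/`--port`. A typical session (model path illustrative) looks like:
```
# Terminal 1: start a server
python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000

# Terminal 2: run the few-shot GSM-8K eval against it
python3 -m sglang.test.few_shot_gsm8k --num-questions 200 --parallel 128
```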