Unverified Commit 2558d6a6 authored by Lianmin Zheng, committed by GitHub

Fix the default arguments of bench_offline_throughput.py & simplify detokenizer manager (#2042)

parent 29ebe3df
""" """
Benchmark the throughput of using the offline LLM engine. Benchmark the throughput of using the offline LLM engine.
This script does not launch a server. This script does not launch a server.
It accepts the same arguments as launch_server.py and additional benchmark arguments It accepts server arguments (the same as launch_server.py) and benchmark arguments (the same as bench_serving.py).
# Usage # Usage
## Sharegpt dataset with default args ## Sharegpt dataset with default args
python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3-8B-Instruct python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct
## Random dataset with default args ## Random dataset with default args
python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --dataset-name random
## Shared prefix dataset with default args ## Shared prefix dataset with default args
python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3-8B-Instruct --dataset-name generated-shared-prefix python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --dataset-name generated-shared-prefix
## Sharegpt dataset on runtime backend ## Sharegpt dataset on runtime backend
python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3-8B-Instruct --backend runtime python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --backend runtime
""" """
import argparse import argparse
@@ -23,7 +23,7 @@ import json
 import logging
 import random
 import time
-from typing import List, Tuple
+from typing import List, Optional, Tuple

 import numpy as np
@@ -45,14 +45,15 @@ class BenchArgs:
     dataset_name: str = "sharegpt"
     dataset_path: str = ""
     num_prompts: int = 1000
-    sharegpt_output_len: int = 256
-    random_input_len: int = 256
-    random_output_len: int = 256
+    sharegpt_output_len: Optional[int] = None
+    random_input_len: int = 1024
+    random_output_len: int = 1024
     random_range_ratio: float = 0.0
-    gen_num_groups: int = 8
+    gen_num_groups: int = 64
     gen_prompts_per_group: int = 16
-    gen_system_prompt_len: int = 128
-    gen_question_len: int = 256
+    gen_system_prompt_len: int = 2048
+    gen_question_len: int = 128
+    gen_output_len: int = 256
     disable_ignore_eos: bool = False
     seed: int = 1
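
The generated-shared-prefix defaults above now describe a much larger synthetic workload (64 groups of 16 prompts, 2048-token system prompts, 128-token questions, 256-token outputs). They remain overridable from the command line; for example, using only flags that appear in this commit (the --gen-output-len option is added in the next hunk), in the same style as the docstring usage lines:

    python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --dataset-name generated-shared-prefix --gen-output-len 512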
@@ -129,6 +130,12 @@ class BenchArgs:
             default=BenchArgs.gen_question_len,
             help="Question length, used" "only for generate-shared-prefix",
         )
+        parser.add_argument(
+            "--gen-output-len",
+            type=int,
+            default=BenchArgs.gen_output_len,
+            help="Target length in tokens for outputs in generated-shared-prefix dataset",
+        )
         parser.add_argument(
             "--disable-ignore-eos",
             type=bool,
@@ -139,12 +146,8 @@ class BenchArgs:
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
-        # use the default value's type to case the args into correct types.
-        attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
-        print(attrs)
-        return cls(
-            **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
-        )
+        attrs = [attr.name for attr in dataclasses.fields(cls)]
+        return cls(**{attr: getattr(args, attr) for attr in attrs})


 def throughput_test_once(
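
The old from_cli_args re-cast every argparse value with the type of the field's default, which breaks once a default is None (as sharegpt_output_len now is), since type(None) is NoneType and calling it on a value raises TypeError; it also left a stray debug print. A minimal, self-contained sketch of the simplified pattern (illustrative class, not the project's code):

    import argparse
    import dataclasses
    from typing import Optional

    @dataclasses.dataclass
    class DemoArgs:  # stand-in for BenchArgs
        num_prompts: int = 1000
        sharegpt_output_len: Optional[int] = None

        @classmethod
        def from_cli_args(cls, args: argparse.Namespace):
            # Take values verbatim from argparse; argparse already applied each
            # option's `type=`, so no re-casting by the default's type is needed.
            attrs = [attr.name for attr in dataclasses.fields(cls)]
            return cls(**{attr: getattr(args, attr) for attr in attrs})

    ns = argparse.Namespace(num_prompts=10, sharegpt_output_len=None)
    print(DemoArgs.from_cli_args(ns))  # DemoArgs(num_prompts=10, sharegpt_output_len=None)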
@@ -224,6 +227,7 @@ def throughput_test(
     random.seed(bench_args.seed)
     np.random.seed(bench_args.seed)

+    # Read dataset
     input_requests = get_dataset(bench_args, tokenizer)

     warmup_requests = sample_random_requests(
...
@@ -1241,10 +1241,12 @@ if __name__ == "__main__":
     parser.add_argument(
         "--random-input-len",
         type=int,
+        default=1024,
         help="Number of input tokens per request, used only for random dataset.",
     )
     parser.add_argument(
         "--random-output-len",
+        default=1024,
         type=int,
         help="Number of output tokens per request, used only for random dataset.",
     )
...
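These two options previously had no default, so omitting them on the command line left the parsed attribute as None, which downstream integer length arithmetic cannot use. A small, self-contained illustration of the behavior the added defaults provide (a toy parser, not the benchmark's full CLI):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--random-input-len", type=int, default=1024)
    parser.add_argument("--random-output-len", type=int, default=1024)

    args = parser.parse_args([])  # no flags passed on the command line
    # With the defaults in place both values are ints, so downstream length
    # arithmetic (e.g. scaling by a range ratio) works out of the box.
    print(args.random_input_len, args.random_output_len)  # 1024 1024
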
@@ -100,20 +100,6 @@ class DetokenizerManager:
             if isinstance(recv_obj, BatchEmbeddingOut):
                 # If it is embedding model, no detokenization is needed.
-                self.send_to_tokenizer.send_pyobj(
-                    BatchEmbeddingOut(
-                        rids=recv_obj.rids,
-                        embeddings=recv_obj.embeddings,
-                        meta_info=recv_obj.meta_info,
-                        finished_reason=recv_obj.finished_reason,
-                    )
-                )
-                continue
-            elif isinstance(recv_obj, UpdateWeightReqOutput):
-                # If it is a weight update request, no detokenization is needed.
-                self.send_to_tokenizer.send_pyobj(recv_obj)
-                continue
-            elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
                 self.send_to_tokenizer.send_pyobj(recv_obj)
                 continue
             else:
...
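The net effect of this hunk: the detokenizer no longer rebuilds BatchEmbeddingOut field by field, and it no longer needs branches for UpdateWeightReqOutput or GetMemPoolSizeReqOutput at all, because the scheduler now answers those directly (see the scheduler hunks below). A minimal, self-contained sketch of the pass-through pattern, using illustrative class and queue names rather than the project's own:

    from dataclasses import dataclass
    from queue import Queue
    from typing import Any, List

    @dataclass
    class EmbeddingBatch:  # stand-in for BatchEmbeddingOut
        rids: List[str]
        embeddings: List[List[float]]

    def forward_if_no_detokenize(recv_obj: Any, send_to_tokenizer: Queue) -> bool:
        """Forward objects that need no detokenization verbatim; return True if forwarded."""
        if isinstance(recv_obj, EmbeddingBatch):
            send_to_tokenizer.put(recv_obj)  # pass the object through unchanged
            return True
        return False  # token-id batches would go through the normal detokenize path

    q = Queue()
    assert forward_if_no_detokenize(EmbeddingBatch(["req-0"], [[0.1, 0.2]]), q)
    assert q.get().rids == ["req-0"]
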
@@ -114,6 +114,9 @@ class Scheduler:
             self.recv_from_tokenizer = get_zmq_socket(
                 context, zmq.PULL, port_args.scheduler_input_ipc_name
             )
+            self.send_to_tokenizer = get_zmq_socket(
+                context, zmq.PUSH, port_args.tokenizer_ipc_name
+            )

             if server_args.skip_tokenizer_init:
                 # Directly send to the tokenizer/api
@@ -127,6 +130,7 @@ class Scheduler:
             )
         else:
             self.recv_from_tokenizer = None
+            self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
             self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)

         # Init tokenizer
@@ -421,7 +425,7 @@ class Scheduler:
                 self.abort_request(recv_req)
             elif isinstance(recv_req, UpdateWeightReqInput):
                 success, message = self.update_weights(recv_req)
-                self.send_to_detokenizer.send_pyobj(
+                self.send_to_tokenizer.send_pyobj(
                     UpdateWeightReqOutput(success, message)
                 )
             elif isinstance(recv_req, ProfileReq):
@@ -430,7 +434,7 @@ class Scheduler:
                 else:
                     self.stop_profile()
             elif isinstance(recv_req, GetMemPoolSizeReq):
-                self.send_to_detokenizer.send_pyobj(
+                self.send_to_tokenizer.send_pyobj(
                     GetMemPoolSizeReqOutput(self.max_total_num_tokens)
                 )
             else:
...
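Previously the scheduler pushed UpdateWeightReqOutput and GetMemPoolSizeReqOutput to the detokenizer, which merely relayed them to the tokenizer; with the new send_to_tokenizer socket those replies skip the middle hop. A self-contained sketch of that PUSH/PULL hop using pyzmq directly (the address and payload are illustrative, not the project's get_zmq_socket helper or its IPC names):

    import zmq

    ctx = zmq.Context()
    addr = "ipc:///tmp/tokenizer_ipc_demo"  # illustrative address

    tokenizer_side = ctx.socket(zmq.PULL)   # TokenizerManager's receiving end
    tokenizer_side.bind(addr)

    scheduler_side = ctx.socket(zmq.PUSH)   # the scheduler's new send_to_tokenizer
    scheduler_side.connect(addr)

    # A control-plane reply goes straight to the tokenizer, with no detokenizer hop.
    scheduler_side.send_pyobj({"kind": "GetMemPoolSizeReqOutput", "size": 42})
    print(tokenizer_side.recv_pyobj())      # {'kind': 'GetMemPoolSizeReqOutput', 'size': 42}
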
@@ -11,16 +11,24 @@ from sglang.test.test_utils import run_mmlu_test
 class TestOverlapSchedule(unittest.TestCase):
     def test_no_radix_attention_chunked_prefill(self):
-        run_mmlu_test(disable_radix_cache=True, chunked_prefill_size=32)
+        run_mmlu_test(
+            disable_radix_cache=True, chunked_prefill_size=32, enable_overlap=True
+        )

     def test_no_radix_attention_no_chunked_prefill(self):
-        run_mmlu_test(disable_radix_cache=True, chunked_prefill_size=-1)
+        run_mmlu_test(
+            disable_radix_cache=True, chunked_prefill_size=-1, enable_overlap=True
+        )

     def test_radix_attention_chunked_prefill(self):
-        run_mmlu_test(disable_radix_cache=False, chunked_prefill_size=32)
+        run_mmlu_test(
+            disable_radix_cache=False, chunked_prefill_size=32, enable_overlap=True
+        )

     def test_radix_attention_no_chunked_prefill(self):
-        run_mmlu_test(disable_radix_cache=False, chunked_prefill_size=-1)
+        run_mmlu_test(
+            disable_radix_cache=False, chunked_prefill_size=-1, enable_overlap=True
+        )


 if __name__ == "__main__":
...
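These four cases now also exercise the overlap scheduler (enable_overlap=True) across the radix-cache and chunked-prefill combinations. Assuming the file keeps the conventional module name test_overlap_schedule (not shown in this diff), a single case can be run in isolation while iterating:

    import unittest

    # Module and class names assumed from this diff; adjust the dotted path if the
    # test file lives elsewhere in the repository's test tree.
    suite = unittest.defaultTestLoader.loadTestsFromName(
        "test_overlap_schedule.TestOverlapSchedule.test_radix_attention_chunked_prefill"
    )
    unittest.TextTestRunner(verbosity=2).run(suite)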