Unverified commit 8f2c522a, authored by Lianmin Zheng and committed by GitHub

Improve benchmark scripts and error message printing (#2922)

parent 75964177
@@ -39,14 +39,15 @@ class BenchArgs:
     dataset_path: str = ""
     num_prompts: int = 1000
     sharegpt_output_len: Optional[int] = None
+    sharegpt_context_len: Optional[int] = None
     random_input_len: int = 1024
     random_output_len: int = 1024
     random_range_ratio: float = 0.0
-    gen_num_groups: int = 64
-    gen_prompts_per_group: int = 16
-    gen_system_prompt_len: int = 2048
-    gen_question_len: int = 128
-    gen_output_len: int = 256
+    gsp_num_groups: int = 64
+    gsp_prompts_per_group: int = 16
+    gsp_system_prompt_len: int = 2048
+    gsp_question_len: int = 128
+    gsp_output_len: int = 256
     disable_ignore_eos: bool = False
     extra_request_body: Optional[str] = None
     seed: int = 1
@@ -82,6 +83,12 @@ class BenchArgs:
             default=BenchArgs.sharegpt_output_len,
             help="Output length for each request. Overrides the output length from the ShareGPT dataset.",
         )
+        parser.add_argument(
+            "--sharegpt-context-len",
+            type=int,
+            default=BenchArgs.sharegpt_context_len,
+            help="The context length of the model for the ShareGPT dataset. Requests longer than the context length will be dropped.",
+        )
         parser.add_argument(
             "--random-input-len",
             type=int,
@@ -102,35 +109,35 @@ class BenchArgs:
             "used only for random dataset.",
         )
         parser.add_argument(
-            "--gen-num-groups",
+            "--gsp-num-groups",
             type=int,
-            default=BenchArgs.gen_num_groups,
+            default=BenchArgs.gsp_num_groups,
             help="Number of groups with shared prefix, used"
             "only for generate-shared-prefix",
         )
         parser.add_argument(
-            "--gen-prompts-per-group",
+            "--gsp-prompts-per-group",
             type=int,
-            default=BenchArgs.gen_prompts_per_group,
+            default=BenchArgs.gsp_prompts_per_group,
             help="Number of prompts per group of shared prefix, used"
             "only for generate-shared-prefix",
         )
         parser.add_argument(
-            "--gen-system-prompt-len",
+            "--gsp-system-prompt-len",
             type=int,
-            default=BenchArgs.gen_system_prompt_len,
+            default=BenchArgs.gsp_system_prompt_len,
             help="System prompt length, used" "only for generate-shared-prefix",
         )
         parser.add_argument(
-            "--gen-question-len",
+            "--gsp-question-len",
             type=int,
-            default=BenchArgs.gen_question_len,
+            default=BenchArgs.gsp_question_len,
             help="Question length, used" "only for generate-shared-prefix",
         )
         parser.add_argument(
-            "--gen-output-len",
+            "--gsp-output-len",
             type=int,
-            default=BenchArgs.gen_output_len,
+            default=BenchArgs.gsp_output_len,
             help="Target length in tokens for outputs in generated-shared-prefix dataset",
         )
         parser.add_argument(
......
@@ -452,6 +452,7 @@ def get_dataset(args, tokenizer):
             num_requests=args.num_prompts,
             tokenizer=tokenizer,
             fixed_output_len=args.sharegpt_output_len,
+            context_len=args.sharegpt_context_len,
         )
     elif args.dataset_name == "random":
         input_requests = sample_random_requests(
@@ -464,11 +465,11 @@ def get_dataset(args, tokenizer):
         )
     elif args.dataset_name == "generated-shared-prefix":
         input_requests = sample_generated_shared_prefix_requests(
-            num_groups=args.gen_num_groups,
-            prompts_per_group=args.gen_prompts_per_group,
-            system_prompt_len=args.gen_system_prompt_len,
-            question_len=args.gen_question_len,
-            output_len=args.gen_output_len,
+            num_groups=args.gsp_num_groups,
+            prompts_per_group=args.gsp_prompts_per_group,
+            system_prompt_len=args.gsp_system_prompt_len,
+            question_len=args.gsp_question_len,
+            output_len=args.gsp_output_len,
             tokenizer=tokenizer,
         )
     else:
@@ -560,6 +561,7 @@ def sample_sharegpt_requests(
     num_requests: int,
     tokenizer: PreTrainedTokenizerBase,
     fixed_output_len: Optional[int] = None,
+    context_len: Optional[int] = None,
 ) -> List[Tuple[str, int, int]]:
     if fixed_output_len is not None and fixed_output_len < 4:
         raise ValueError("output_len too small")
@@ -597,14 +599,15 @@ def sample_sharegpt_requests(
         output_len = (
             len(completion_token_ids) if fixed_output_len is None else fixed_output_len
         )
-        if prompt_len < 4 or output_len < 4:
+        if prompt_len < 1 or output_len < 1:
             # Prune too short sequences.
             continue
-        if prompt_len > 1024 or (
-            prompt_len + output_len > 2048 and fixed_output_len is None
-        ):
+
+        if context_len and prompt_len + output_len > context_len:
             # Prune too long sequences.
             continue
+
         filtered_dataset.append((prompt, prompt_len, output_len))

     print(f"#Input tokens: {np.sum([x[1] for x in filtered_dataset])}")
@@ -706,8 +709,8 @@ def get_gen_prefix_cache_path(args, tokenizer):
     # Create a unique cache filename based on the generation parameters
     cache_key = (
-        f"gen_prefix_{args.gen_num_groups}_{args.gen_prompts_per_group}_"
-        f"{args.gen_system_prompt_len}_{args.gen_question_len}_{args.gen_output_len}_"
+        f"gen_shared_prefix_{args.gsp_num_groups}_{args.gsp_prompts_per_group}_"
+        f"{args.gsp_system_prompt_len}_{args.gsp_question_len}_{args.gsp_output_len}_"
         f"{tokenizer.__class__.__name__}.pkl"
     )
     return cache_dir / cache_key
@@ -1374,6 +1377,12 @@ if __name__ == "__main__":
         default=None,
         help="Output length for each request. Overrides the output length from the ShareGPT dataset.",
     )
+    parser.add_argument(
+        "--sharegpt-context-len",
+        type=int,
+        default=None,
+        help="The context length of the model for the ShareGPT dataset. Requests longer than the context length will be dropped.",
+    )
     parser.add_argument(
         "--random-input-len",
         type=int,
@@ -1453,49 +1462,49 @@ if __name__ == "__main__":
         help="Append given JSON object to the request payload. You can use this to specify"
         "additional generate params like sampling params.",
     )
-    parser.add_argument(
-        "--profile",
-        action="store_true",
-        help="Use Torch Profiler. The endpoint must be launched with "
-        "SGLANG_TORCH_PROFILER_DIR to enable profiler.",
-    )
-    parser.add_argument(
-        "--lora-name",
-        type=str,
-        default=None,
-        help="The name of LoRA adapter",
-    )

     group = parser.add_argument_group("generated-shared-prefix dataset arguments")
     group.add_argument(
-        "--gen-num-groups",
+        "--gsp-num-groups",
         type=int,
         default=64,
         help="Number of system prompt groups for generated-shared-prefix dataset",
     )
     group.add_argument(
-        "--gen-prompts-per-group",
+        "--gsp-prompts-per-group",
         type=int,
         default=16,
         help="Number of prompts per system prompt group for generated-shared-prefix dataset",
     )
     group.add_argument(
-        "--gen-system-prompt-len",
+        "--gsp-system-prompt-len",
         type=int,
         default=2048,
         help="Target length in tokens for system prompts in generated-shared-prefix dataset",
     )
     group.add_argument(
-        "--gen-question-len",
+        "--gsp-question-len",
         type=int,
         default=128,
         help="Target length in tokens for questions in generated-shared-prefix dataset",
     )
     group.add_argument(
-        "--gen-output-len",
+        "--gsp-output-len",
         type=int,
         default=256,
         help="Target length in tokens for outputs in generated-shared-prefix dataset",
     )
+    parser.add_argument(
+        "--profile",
+        action="store_true",
+        help="Use Torch Profiler. The endpoint must be launched with "
+        "SGLANG_TORCH_PROFILER_DIR to enable profiler.",
+    )
+    parser.add_argument(
+        "--lora-name",
+        type=str,
+        default=None,
+        help="The name of LoRA adapter",
+    )

     args = parser.parse_args()
     run_benchmark(args)
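Note: the generated-shared-prefix flags are renamed from --gen-* to --gsp-* with no aliases kept, so existing benchmark wrappers need the new spellings. A hypothetical invocation, assuming the script is still launched as python3 -m sglang.bench_serving against a running server (the values shown are the script defaults):

# hypothetical invocation, not part of this commit
python3 -m sglang.bench_serving --backend sglang --dataset-name generated-shared-prefix \
    --gsp-num-groups 64 --gsp-prompts-per-group 16 --gsp-system-prompt-len 2048 \
    --gsp-question-len 128 --gsp-output-len 256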
@@ -59,6 +59,9 @@ class GenerateReqInput:
     return_text_in_logprobs: bool = False
     # Whether to stream output.
     stream: bool = False
+    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
+    log_metrics: bool = True
+
     # The modalities of the image data [image, multi-images, video]
     modalities: Optional[List[str]] = None
     # LoRA related
@@ -196,6 +199,7 @@ class GenerateReqInput:
             top_logprobs_num=self.top_logprobs_num[i],
             return_text_in_logprobs=self.return_text_in_logprobs,
             stream=self.stream,
+            log_metrics=self.log_metrics,
             modalities=self.modalities[i] if self.modalities else None,
             lora_path=self.lora_path[i] if self.lora_path is not None else None,
         )
@@ -243,6 +247,8 @@ class EmbeddingReqInput:
     sampling_params: Union[List[Dict], Dict] = None
     # Dummy input embeds for compatibility
     input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
+    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
+    log_metrics: bool = True

     def normalize_batch_and_arguments(self):
         if (self.text is None and self.input_ids is None) or (
......
@@ -631,7 +631,8 @@ class Scheduler:
         if len(req.origin_input_ids) > self.max_req_input_len:
             logger.warning(
                 "Request length is longer than the KV cache pool size or "
-                "the max context length. Truncated!!!"
+                "the max context length. Truncated. "
+                f"{len(req.origin_input_ids)=}, {self.max_req_input_len=}."
             )
             req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
......
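Note: the rewritten warning drops the "!!!" and appends the offending sizes with the f-string "=" specifier (Python 3.8+), which prints the expression together with its value. A tiny illustration with made-up numbers:

# made-up values, for illustration of the f"{expr=}" form only
request_len, max_req_input_len = 5000, 4096
print(f"{request_len=}, {max_req_input_len=}.")
# prints: request_len=5000, max_req_input_len=4096.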
@@ -79,6 +79,7 @@ from sglang.srt.utils import (
     get_zmq_socket,
     kill_process_tree,
 )
+from sglang.utils import get_exception_traceback

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -640,7 +641,9 @@ class TokenizerManager:
         self.to_create_loop = False
         loop = asyncio.get_event_loop()
-        self.asyncio_tasks.add(loop.create_task(self.handle_loop()))
+        self.asyncio_tasks.add(
+            loop.create_task(print_exception_wrapper(self.handle_loop))
+        )

         # We cannot add signal handler when the tokenizer manager is not in
         # the main thread due to the CPython limitation.
@@ -653,7 +656,9 @@ class TokenizerManager:
                 "not in the main thread. This disables graceful shutdown of the "
                 "tokenizer manager when SIGTERM is received."
             )
-        self.asyncio_tasks.add(loop.create_task(self.sigterm_watchdog()))
+        self.asyncio_tasks.add(
+            loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
+        )

     async def sigterm_watchdog(self):
         while not self.gracefully_exit:
@@ -738,9 +743,13 @@ class TokenizerManager:
                 state.finished = recv_obj.finished_reasons[i] is not None
                 state.event.set()

-                if self.enable_metrics:
+                if self.enable_metrics and state.obj.log_metrics:
                     self.collect_metrics(state, recv_obj, i)
-                if self.dump_requests_folder and state.finished:
+                if (
+                    self.dump_requests_folder
+                    and state.finished
+                    and state.obj.log_metrics
+                ):
                     self.dump_requests(state, out_dict)
         elif isinstance(recv_obj, OpenSessionReqOutput):
             self.session_futures[recv_obj.session_id].set_result(
@@ -887,20 +896,38 @@ class TokenizerManager:
         )

         if len(self.dump_request_list) >= self.dump_requests_threshold:
+            filename = os.path.join(
+                self.dump_requests_folder,
+                datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".pkl",
+            )
+            logger.info(f"Dump {len(self.dump_request_list)} requests to {filename}")
+
             to_dump = self.dump_request_list
             self.dump_request_list = []

             def background_task():
                 os.makedirs(self.dump_requests_folder, exist_ok=True)
-                current_time = datetime.now()
-                filename = current_time.strftime("%Y-%m-%d_%H-%M-%S") + ".pkl"
-                with open(os.path.join(self.dump_requests_folder, filename), "wb") as f:
+                with open(filename, "wb") as f:
                     pickle.dump(to_dump, f)

             # Schedule the task to run in the background without awaiting it
             asyncio.create_task(asyncio.to_thread(background_task))

+
+async def print_exception_wrapper(func):
+    """
+    Sometimes an asyncio function does not print exception.
+    We do another wrapper to handle the exception.
+    """
+    try:
+        await func()
+    except Exception:
+        traceback = get_exception_traceback()
+        logger.error(f"TokenizerManager hit an exception: {traceback}")
+        kill_process_tree(os.getpid(), include_parent=True)
+        sys.exit(1)
+
+
 class SignalHandler:
     def __init__(self, tokenizer_manager):
         self.tokenizer_manager = tokenizer_manager
......
@@ -135,9 +135,13 @@ async def health_generate(request: Request) -> Response:
     sampling_params = {"max_new_tokens": 1, "temperature": 0.7}

     if tokenizer_manager.is_generation:
-        gri = GenerateReqInput(input_ids=[0], sampling_params=sampling_params)
+        gri = GenerateReqInput(
+            input_ids=[0], sampling_params=sampling_params, log_metrics=False
+        )
     else:
-        gri = EmbeddingReqInput(input_ids=[0], sampling_params=sampling_params)
+        gri = EmbeddingReqInput(
+            input_ids=[0], sampling_params=sampling_params, log_metrics=False
+        )

     try:
         async for _ in tokenizer_manager.generate_request(gri, request):
......
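Note: because the health probe now sets log_metrics=False, /health_generate requests should no longer inflate the Prometheus metrics or appear in the request dump files. A hedged manual check, assuming a local server on the default port 30000 and the requests package installed (neither is part of this commit):

import requests

# assumes a locally running server; adjust host/port as needed
resp = requests.get("http://localhost:30000/health_generate")
print(resp.status_code)  # expected 200 while the server can still generate a token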
@@ -560,6 +560,7 @@ def run_bench_serving(
         tokenizer=tokenizer,
         num_prompts=num_prompts,
         sharegpt_output_len=None,
+        sharegpt_context_len=None,
         random_input_len=random_input_len,
         random_output_len=random_output_len,
         random_range_ratio=0.0,
......
@@ -44,7 +44,7 @@ class TestEpMoE(unittest.TestCase):
         )
         metrics = run_eval(args)
-        assert metrics["score"] >= 0.5
+        self.assertGreater(metrics["score"], 0.5)

     def test_mgsm_en(self):
         args = SimpleNamespace(
@@ -56,7 +56,7 @@ class TestEpMoE(unittest.TestCase):
         )
         metrics = run_eval(args)
-        assert metrics["score"] >= 0.8
+        self.assertGreater(metrics["score"], 0.8)


 class TestEpMoEFP8(unittest.TestCase):
......