Unverified Commit 8f2c522a authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Improve benchmark scripts and error message printing (#2922)

parent 75964177
......@@ -39,14 +39,15 @@ class BenchArgs:
dataset_path: str = ""
num_prompts: int = 1000
sharegpt_output_len: Optional[int] = None
sharegpt_context_len: Optional[int] = None
random_input_len: int = 1024
random_output_len: int = 1024
random_range_ratio: float = 0.0
gen_num_groups: int = 64
gen_prompts_per_group: int = 16
gen_system_prompt_len: int = 2048
gen_question_len: int = 128
gen_output_len: int = 256
gsp_num_groups: int = 64
gsp_prompts_per_group: int = 16
gsp_system_prompt_len: int = 2048
gsp_question_len: int = 128
gsp_output_len: int = 256
disable_ignore_eos: bool = False
extra_request_body: Optional[str] = None
seed: int = 1
......@@ -82,6 +83,12 @@ class BenchArgs:
default=BenchArgs.sharegpt_output_len,
help="Output length for each request. Overrides the output length from the ShareGPT dataset.",
)
parser.add_argument(
"--sharegpt-context-len",
type=int,
default=BenchArgs.sharegpt_context_len,
help="The context length of the model for the ShareGPT dataset. Requests longer than the context length will be dropped.",
)
parser.add_argument(
"--random-input-len",
type=int,
......@@ -102,35 +109,35 @@ class BenchArgs:
"used only for random dataset.",
)
parser.add_argument(
"--gen-num-groups",
"--gsp-num-groups",
type=int,
default=BenchArgs.gen_num_groups,
default=BenchArgs.gsp_num_groups,
help="Number of groups with shared prefix, used"
"only for generate-shared-prefix",
)
parser.add_argument(
"--gen-prompts-per-group",
"--gsp-prompts-per-group",
type=int,
default=BenchArgs.gen_prompts_per_group,
default=BenchArgs.gsp_prompts_per_group,
help="Number of prompts per group of shared prefix, used"
"only for generate-shared-prefix",
)
parser.add_argument(
"--gen-system-prompt-len",
"--gsp-system-prompt-len",
type=int,
default=BenchArgs.gen_system_prompt_len,
default=BenchArgs.gsp_system_prompt_len,
help="System prompt length, used" "only for generate-shared-prefix",
)
parser.add_argument(
"--gen-question-len",
"--gsp-question-len",
type=int,
default=BenchArgs.gen_question_len,
default=BenchArgs.gsp_question_len,
help="Question length, used" "only for generate-shared-prefix",
)
parser.add_argument(
"--gen-output-len",
"--gsp-output-len",
type=int,
default=BenchArgs.gen_output_len,
default=BenchArgs.gsp_output_len,
help="Target length in tokens for outputs in generated-shared-prefix dataset",
)
parser.add_argument(
......
......@@ -452,6 +452,7 @@ def get_dataset(args, tokenizer):
num_requests=args.num_prompts,
tokenizer=tokenizer,
fixed_output_len=args.sharegpt_output_len,
context_len=args.sharegpt_context_len,
)
elif args.dataset_name == "random":
input_requests = sample_random_requests(
......@@ -464,11 +465,11 @@ def get_dataset(args, tokenizer):
)
elif args.dataset_name == "generated-shared-prefix":
input_requests = sample_generated_shared_prefix_requests(
num_groups=args.gen_num_groups,
prompts_per_group=args.gen_prompts_per_group,
system_prompt_len=args.gen_system_prompt_len,
question_len=args.gen_question_len,
output_len=args.gen_output_len,
num_groups=args.gsp_num_groups,
prompts_per_group=args.gsp_prompts_per_group,
system_prompt_len=args.gsp_system_prompt_len,
question_len=args.gsp_question_len,
output_len=args.gsp_output_len,
tokenizer=tokenizer,
)
else:
......@@ -560,6 +561,7 @@ def sample_sharegpt_requests(
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
fixed_output_len: Optional[int] = None,
context_len: Optional[int] = None,
) -> List[Tuple[str, int, int]]:
if fixed_output_len is not None and fixed_output_len < 4:
raise ValueError("output_len too small")
......@@ -597,14 +599,15 @@ def sample_sharegpt_requests(
output_len = (
len(completion_token_ids) if fixed_output_len is None else fixed_output_len
)
if prompt_len < 4 or output_len < 4:
if prompt_len < 1 or output_len < 1:
# Prune too short sequences.
continue
if prompt_len > 1024 or (
prompt_len + output_len > 2048 and fixed_output_len is None
):
if context_len and prompt_len + output_len > context_len:
# Prune too long sequences.
continue
filtered_dataset.append((prompt, prompt_len, output_len))
print(f"#Input tokens: {np.sum([x[1] for x in filtered_dataset])}")
......@@ -706,8 +709,8 @@ def get_gen_prefix_cache_path(args, tokenizer):
# Create a unique cache filename based on the generation parameters
cache_key = (
f"gen_prefix_{args.gen_num_groups}_{args.gen_prompts_per_group}_"
f"{args.gen_system_prompt_len}_{args.gen_question_len}_{args.gen_output_len}_"
f"gen_shared_prefix_{args.gsp_num_groups}_{args.gsp_prompts_per_group}_"
f"{args.gsp_system_prompt_len}_{args.gsp_question_len}_{args.gsp_output_len}_"
f"{tokenizer.__class__.__name__}.pkl"
)
return cache_dir / cache_key
......@@ -1374,6 +1377,12 @@ if __name__ == "__main__":
default=None,
help="Output length for each request. Overrides the output length from the ShareGPT dataset.",
)
parser.add_argument(
"--sharegpt-context-len",
type=int,
default=None,
help="The context length of the model for the ShareGPT dataset. Requests longer than the context length will be dropped.",
)
parser.add_argument(
"--random-input-len",
type=int,
......@@ -1453,49 +1462,49 @@ if __name__ == "__main__":
help="Append given JSON object to the request payload. You can use this to specify"
"additional generate params like sampling params.",
)
parser.add_argument(
"--profile",
action="store_true",
help="Use Torch Profiler. The endpoint must be launched with "
"SGLANG_TORCH_PROFILER_DIR to enable profiler.",
)
parser.add_argument(
"--lora-name",
type=str,
default=None,
help="The name of LoRA adapter",
)
group = parser.add_argument_group("generated-shared-prefix dataset arguments")
group.add_argument(
"--gen-num-groups",
"--gsp-num-groups",
type=int,
default=64,
help="Number of system prompt groups for generated-shared-prefix dataset",
)
group.add_argument(
"--gen-prompts-per-group",
"--gsp-prompts-per-group",
type=int,
default=16,
help="Number of prompts per system prompt group for generated-shared-prefix dataset",
)
group.add_argument(
"--gen-system-prompt-len",
"--gsp-system-prompt-len",
type=int,
default=2048,
help="Target length in tokens for system prompts in generated-shared-prefix dataset",
)
group.add_argument(
"--gen-question-len",
"--gsp-question-len",
type=int,
default=128,
help="Target length in tokens for questions in generated-shared-prefix dataset",
)
group.add_argument(
"--gen-output-len",
"--gsp-output-len",
type=int,
default=256,
help="Target length in tokens for outputs in generated-shared-prefix dataset",
)
parser.add_argument(
"--profile",
action="store_true",
help="Use Torch Profiler. The endpoint must be launched with "
"SGLANG_TORCH_PROFILER_DIR to enable profiler.",
)
parser.add_argument(
"--lora-name",
type=str,
default=None,
help="The name of LoRA adapter",
)
args = parser.parse_args()
run_benchmark(args)
......@@ -59,6 +59,9 @@ class GenerateReqInput:
return_text_in_logprobs: bool = False
# Whether to stream output.
stream: bool = False
# Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
log_metrics: bool = True
# The modalities of the image data [image, multi-images, video]
modalities: Optional[List[str]] = None
# LoRA related
......@@ -196,6 +199,7 @@ class GenerateReqInput:
top_logprobs_num=self.top_logprobs_num[i],
return_text_in_logprobs=self.return_text_in_logprobs,
stream=self.stream,
log_metrics=self.log_metrics,
modalities=self.modalities[i] if self.modalities else None,
lora_path=self.lora_path[i] if self.lora_path is not None else None,
)
......@@ -243,6 +247,8 @@ class EmbeddingReqInput:
sampling_params: Union[List[Dict], Dict] = None
# Dummy input embeds for compatibility
input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
# Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
log_metrics: bool = True
def normalize_batch_and_arguments(self):
if (self.text is None and self.input_ids is None) or (
......
......@@ -631,7 +631,8 @@ class Scheduler:
if len(req.origin_input_ids) > self.max_req_input_len:
logger.warning(
"Request length is longer than the KV cache pool size or "
"the max context length. Truncated!!!"
"the max context length. Truncated. "
f"{len(req.origin_input_ids)=}, {self.max_req_input_len=}."
)
req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
......
......@@ -79,6 +79,7 @@ from sglang.srt.utils import (
get_zmq_socket,
kill_process_tree,
)
from sglang.utils import get_exception_traceback
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
......@@ -640,7 +641,9 @@ class TokenizerManager:
self.to_create_loop = False
loop = asyncio.get_event_loop()
self.asyncio_tasks.add(loop.create_task(self.handle_loop()))
self.asyncio_tasks.add(
loop.create_task(print_exception_wrapper(self.handle_loop))
)
# We cannot add signal handler when the tokenizer manager is not in
# the main thread due to the CPython limitation.
......@@ -653,7 +656,9 @@ class TokenizerManager:
"not in the main thread. This disables graceful shutdown of the "
"tokenizer manager when SIGTERM is received."
)
self.asyncio_tasks.add(loop.create_task(self.sigterm_watchdog()))
self.asyncio_tasks.add(
loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
)
async def sigterm_watchdog(self):
while not self.gracefully_exit:
......@@ -738,9 +743,13 @@ class TokenizerManager:
state.finished = recv_obj.finished_reasons[i] is not None
state.event.set()
if self.enable_metrics:
if self.enable_metrics and state.obj.log_metrics:
self.collect_metrics(state, recv_obj, i)
if self.dump_requests_folder and state.finished:
if (
self.dump_requests_folder
and state.finished
and state.obj.log_metrics
):
self.dump_requests(state, out_dict)
elif isinstance(recv_obj, OpenSessionReqOutput):
self.session_futures[recv_obj.session_id].set_result(
......@@ -887,20 +896,38 @@ class TokenizerManager:
)
if len(self.dump_request_list) >= self.dump_requests_threshold:
filename = os.path.join(
self.dump_requests_folder,
datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".pkl",
)
logger.info(f"Dump {len(self.dump_request_list)} requests to {filename}")
to_dump = self.dump_request_list
self.dump_request_list = []
def background_task():
os.makedirs(self.dump_requests_folder, exist_ok=True)
current_time = datetime.now()
filename = current_time.strftime("%Y-%m-%d_%H-%M-%S") + ".pkl"
with open(os.path.join(self.dump_requests_folder, filename), "wb") as f:
with open(filename, "wb") as f:
pickle.dump(to_dump, f)
# Schedule the task to run in the background without awaiting it
asyncio.create_task(asyncio.to_thread(background_task))
async def print_exception_wrapper(func):
"""
Sometimes an asyncio function does not print exception.
We do another wrapper to handle the exception.
"""
try:
await func()
except Exception:
traceback = get_exception_traceback()
logger.error(f"TokenizerManager hit an exception: {traceback}")
kill_process_tree(os.getpid(), include_parent=True)
sys.exit(1)
class SignalHandler:
def __init__(self, tokenizer_manager):
self.tokenizer_manager = tokenizer_manager
......
......@@ -135,9 +135,13 @@ async def health_generate(request: Request) -> Response:
sampling_params = {"max_new_tokens": 1, "temperature": 0.7}
if tokenizer_manager.is_generation:
gri = GenerateReqInput(input_ids=[0], sampling_params=sampling_params)
gri = GenerateReqInput(
input_ids=[0], sampling_params=sampling_params, log_metrics=False
)
else:
gri = EmbeddingReqInput(input_ids=[0], sampling_params=sampling_params)
gri = EmbeddingReqInput(
input_ids=[0], sampling_params=sampling_params, log_metrics=False
)
try:
async for _ in tokenizer_manager.generate_request(gri, request):
......
......@@ -560,6 +560,7 @@ def run_bench_serving(
tokenizer=tokenizer,
num_prompts=num_prompts,
sharegpt_output_len=None,
sharegpt_context_len=None,
random_input_len=random_input_len,
random_output_len=random_output_len,
random_range_ratio=0.0,
......
......@@ -44,7 +44,7 @@ class TestEpMoE(unittest.TestCase):
)
metrics = run_eval(args)
assert metrics["score"] >= 0.5
self.assertGreater(metrics["score"], 0.5)
def test_mgsm_en(self):
args = SimpleNamespace(
......@@ -56,7 +56,7 @@ class TestEpMoE(unittest.TestCase):
)
metrics = run_eval(args)
assert metrics["score"] >= 0.8
self.assertGreater(metrics["score"], 0.8)
class TestEpMoEFP8(unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment