Unverified Commit 572b4329 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Stop bench CLI from recursively casting all configs to `dict` (#37559)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 9515c208
......@@ -40,9 +40,9 @@ LLM engine. You can refer to the `vllm.engine.arg_utils.EngineArgs` for more
details.
"""
import dataclasses
import random
import time
from dataclasses import fields
from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
......@@ -124,7 +124,7 @@ def main(args):
# Create the LLM engine
engine_args = EngineArgs.from_cli_args(args)
llm = LLM(**dataclasses.asdict(engine_args))
llm = LLM(**{f.name: getattr(engine_args, f.name) for f in fields(engine_args)})
sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)
print("------warm up------")
......
......@@ -32,6 +32,7 @@ import dataclasses
import json
import random
import time
from dataclasses import fields
from transformers import PreTrainedTokenizerBase
......@@ -196,7 +197,7 @@ def main(args):
engine_args = EngineArgs.from_cli_args(args)
llm = LLM(**dataclasses.asdict(engine_args))
llm = LLM(**{f.name: getattr(engine_args, f.name) for f in fields(engine_args)})
sampling_params = SamplingParams(
temperature=0,
......
......@@ -3,10 +3,10 @@
"""Benchmark offline prioritization."""
import argparse
import dataclasses
import json
import random
import time
from dataclasses import fields
from transformers import AutoTokenizer, PreTrainedTokenizerBase
......@@ -79,7 +79,7 @@ def run_vllm(
) -> float:
from vllm import LLM, SamplingParams
llm = LLM(**dataclasses.asdict(engine_args))
llm = LLM(**{f.name: getattr(engine_args, f.name) for f in fields(engine_args)})
assert all(
llm.llm_engine.model_config.max_model_len >= (request[1] + request[2])
......
......@@ -3,10 +3,10 @@
"""Benchmark the latency of processing a single batch of requests."""
import argparse
import dataclasses
import json
import os
import time
from dataclasses import fields
from typing import Any
import numpy as np
......@@ -85,7 +85,7 @@ def main(args: argparse.Namespace):
# NOTE(woosuk): If the request cannot be processed in a single batch,
# the engine will automatically process the request in multiple batches.
llm = LLM(**dataclasses.asdict(engine_args))
llm = LLM(**{f.name: getattr(engine_args, f.name) for f in fields(engine_args)})
assert llm.llm_engine.model_config.max_model_len >= (
args.input_len + args.output_len
), (
......
......@@ -14,10 +14,10 @@ Run:
"""
import argparse
import dataclasses
import json
import time
from collections import defaultdict
from dataclasses import fields
from datetime import datetime
from typing import TYPE_CHECKING, Any, Literal
......@@ -225,7 +225,7 @@ def benchmark_multimodal_processor(
args.seed = 0
engine_args = EngineArgs.from_cli_args(args)
llm = LLM(**dataclasses.asdict(engine_args))
llm = LLM(**{f.name: getattr(engine_args, f.name) for f in fields(engine_args)})
tokenizer = llm.get_tokenizer()
requests = get_requests(args, tokenizer)
......
......@@ -9,7 +9,6 @@ and cache operations) for both cold and warm scenarios:
"""
import argparse
import dataclasses
import json
import multiprocessing
import os
......@@ -17,6 +16,7 @@ import shutil
import tempfile
import time
from contextlib import contextmanager
from dataclasses import fields
from typing import Any
import numpy as np
......@@ -67,7 +67,7 @@ def run_startup_in_subprocess(engine_args, result_queue):
# Measure total startup time
start_time = time.perf_counter()
llm = LLM(**dataclasses.asdict(engine_args))
llm = LLM(**{f.name: getattr(engine_args, f.name) for f in fields(engine_args)})
total_startup_time = time.perf_counter() - start_time
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import math
from dataclasses import asdict, dataclass
from dataclasses import dataclass, fields
from pathlib import Path
from typing import ClassVar, Literal, get_args
......@@ -267,7 +267,7 @@ class SweepServeWorkloadArgs(SweepServeArgs):
base_args = SweepServeArgs.from_cli_args(args)
return cls(
**asdict(base_args),
**{f.name: getattr(base_args, f.name) for f in fields(base_args)},
workload_var=args.workload_var,
workload_iters=args.workload_iters,
)
......
......@@ -3,12 +3,12 @@
"""Benchmark offline inference throughput."""
import argparse
import dataclasses
import json
import os
import random
import time
import warnings
from dataclasses import fields
from typing import Any
import torch
......@@ -53,7 +53,7 @@ def run_vllm(
) -> tuple[float, list[RequestOutput] | None]:
from vllm import LLM, SamplingParams
llm = LLM(**dataclasses.asdict(engine_args))
llm = LLM(**{f.name: getattr(engine_args, f.name) for f in fields(engine_args)})
assert all(
llm.llm_engine.model_config.max_model_len
>= (request.prompt_len + request.expected_output_len)
......@@ -141,7 +141,7 @@ def run_vllm_chat(
"""
from vllm import LLM, SamplingParams
llm = LLM(**dataclasses.asdict(engine_args))
llm = LLM(**{f.name: getattr(engine_args, f.name) for f in fields(engine_args)})
assert all(
llm.llm_engine.model_config.max_model_len
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment