Unverified Commit 572b4329 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Stop bench CLI from recursively casting all configs to `dict` (#37559)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 9515c208
...@@ -40,9 +40,9 @@ LLM engine. You can refer to the `vllm.engine.arg_utils.EngineArgs` for more ...@@ -40,9 +40,9 @@ LLM engine. You can refer to the `vllm.engine.arg_utils.EngineArgs` for more
details. details.
""" """
import dataclasses
import random import random
import time import time
from dataclasses import fields
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
...@@ -124,7 +124,7 @@ def main(args): ...@@ -124,7 +124,7 @@ def main(args):
# Create the LLM engine # Create the LLM engine
engine_args = EngineArgs.from_cli_args(args) engine_args = EngineArgs.from_cli_args(args)
llm = LLM(**dataclasses.asdict(engine_args)) llm = LLM(**{f.name: getattr(engine_args, f.name) for f in fields(engine_args)})
sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len) sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)
print("------warm up------") print("------warm up------")
......
...@@ -32,6 +32,7 @@ import dataclasses ...@@ -32,6 +32,7 @@ import dataclasses
import json import json
import random import random
import time import time
from dataclasses import fields
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
...@@ -196,7 +197,7 @@ def main(args): ...@@ -196,7 +197,7 @@ def main(args):
engine_args = EngineArgs.from_cli_args(args) engine_args = EngineArgs.from_cli_args(args)
llm = LLM(**dataclasses.asdict(engine_args)) llm = LLM(**{f.name: getattr(engine_args, f.name) for f in fields(engine_args)})
sampling_params = SamplingParams( sampling_params = SamplingParams(
temperature=0, temperature=0,
......
...@@ -3,10 +3,10 @@ ...@@ -3,10 +3,10 @@
"""Benchmark offline prioritization.""" """Benchmark offline prioritization."""
import argparse import argparse
import dataclasses
import json import json
import random import random
import time import time
from dataclasses import fields
from transformers import AutoTokenizer, PreTrainedTokenizerBase from transformers import AutoTokenizer, PreTrainedTokenizerBase
...@@ -79,7 +79,7 @@ def run_vllm( ...@@ -79,7 +79,7 @@ def run_vllm(
) -> float: ) -> float:
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
llm = LLM(**dataclasses.asdict(engine_args)) llm = LLM(**{f.name: getattr(engine_args, f.name) for f in fields(engine_args)})
assert all( assert all(
llm.llm_engine.model_config.max_model_len >= (request[1] + request[2]) llm.llm_engine.model_config.max_model_len >= (request[1] + request[2])
......
...@@ -3,10 +3,10 @@ ...@@ -3,10 +3,10 @@
"""Benchmark the latency of processing a single batch of requests.""" """Benchmark the latency of processing a single batch of requests."""
import argparse import argparse
import dataclasses
import json import json
import os import os
import time import time
from dataclasses import fields
from typing import Any from typing import Any
import numpy as np import numpy as np
...@@ -85,7 +85,7 @@ def main(args: argparse.Namespace): ...@@ -85,7 +85,7 @@ def main(args: argparse.Namespace):
# NOTE(woosuk): If the request cannot be processed in a single batch, # NOTE(woosuk): If the request cannot be processed in a single batch,
# the engine will automatically process the request in multiple batches. # the engine will automatically process the request in multiple batches.
llm = LLM(**dataclasses.asdict(engine_args)) llm = LLM(**{f.name: getattr(engine_args, f.name) for f in fields(engine_args)})
assert llm.llm_engine.model_config.max_model_len >= ( assert llm.llm_engine.model_config.max_model_len >= (
args.input_len + args.output_len args.input_len + args.output_len
), ( ), (
......
...@@ -14,10 +14,10 @@ Run: ...@@ -14,10 +14,10 @@ Run:
""" """
import argparse import argparse
import dataclasses
import json import json
import time import time
from collections import defaultdict from collections import defaultdict
from dataclasses import fields
from datetime import datetime from datetime import datetime
from typing import TYPE_CHECKING, Any, Literal from typing import TYPE_CHECKING, Any, Literal
...@@ -225,7 +225,7 @@ def benchmark_multimodal_processor( ...@@ -225,7 +225,7 @@ def benchmark_multimodal_processor(
args.seed = 0 args.seed = 0
engine_args = EngineArgs.from_cli_args(args) engine_args = EngineArgs.from_cli_args(args)
llm = LLM(**dataclasses.asdict(engine_args)) llm = LLM(**{f.name: getattr(engine_args, f.name) for f in fields(engine_args)})
tokenizer = llm.get_tokenizer() tokenizer = llm.get_tokenizer()
requests = get_requests(args, tokenizer) requests = get_requests(args, tokenizer)
......
...@@ -9,7 +9,6 @@ and cache operations) for both cold and warm scenarios: ...@@ -9,7 +9,6 @@ and cache operations) for both cold and warm scenarios:
""" """
import argparse import argparse
import dataclasses
import json import json
import multiprocessing import multiprocessing
import os import os
...@@ -17,6 +16,7 @@ import shutil ...@@ -17,6 +16,7 @@ import shutil
import tempfile import tempfile
import time import time
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import fields
from typing import Any from typing import Any
import numpy as np import numpy as np
...@@ -67,7 +67,7 @@ def run_startup_in_subprocess(engine_args, result_queue): ...@@ -67,7 +67,7 @@ def run_startup_in_subprocess(engine_args, result_queue):
# Measure total startup time # Measure total startup time
start_time = time.perf_counter() start_time = time.perf_counter()
llm = LLM(**dataclasses.asdict(engine_args)) llm = LLM(**{f.name: getattr(engine_args, f.name) for f in fields(engine_args)})
total_startup_time = time.perf_counter() - start_time total_startup_time = time.perf_counter() - start_time
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse import argparse
import math import math
from dataclasses import asdict, dataclass from dataclasses import dataclass, fields
from pathlib import Path from pathlib import Path
from typing import ClassVar, Literal, get_args from typing import ClassVar, Literal, get_args
...@@ -267,7 +267,7 @@ class SweepServeWorkloadArgs(SweepServeArgs): ...@@ -267,7 +267,7 @@ class SweepServeWorkloadArgs(SweepServeArgs):
base_args = SweepServeArgs.from_cli_args(args) base_args = SweepServeArgs.from_cli_args(args)
return cls( return cls(
**asdict(base_args), **{f.name: getattr(base_args, f.name) for f in fields(base_args)},
workload_var=args.workload_var, workload_var=args.workload_var,
workload_iters=args.workload_iters, workload_iters=args.workload_iters,
) )
......
...@@ -3,12 +3,12 @@ ...@@ -3,12 +3,12 @@
"""Benchmark offline inference throughput.""" """Benchmark offline inference throughput."""
import argparse import argparse
import dataclasses
import json import json
import os import os
import random import random
import time import time
import warnings import warnings
from dataclasses import fields
from typing import Any from typing import Any
import torch import torch
...@@ -53,7 +53,7 @@ def run_vllm( ...@@ -53,7 +53,7 @@ def run_vllm(
) -> tuple[float, list[RequestOutput] | None]: ) -> tuple[float, list[RequestOutput] | None]:
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
llm = LLM(**dataclasses.asdict(engine_args)) llm = LLM(**{f.name: getattr(engine_args, f.name) for f in fields(engine_args)})
assert all( assert all(
llm.llm_engine.model_config.max_model_len llm.llm_engine.model_config.max_model_len
>= (request.prompt_len + request.expected_output_len) >= (request.prompt_len + request.expected_output_len)
...@@ -141,7 +141,7 @@ def run_vllm_chat( ...@@ -141,7 +141,7 @@ def run_vllm_chat(
""" """
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
llm = LLM(**dataclasses.asdict(engine_args)) llm = LLM(**{f.name: getattr(engine_args, f.name) for f in fields(engine_args)})
assert all( assert all(
llm.llm_engine.model_config.max_model_len llm.llm_engine.model_config.max_model_len
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment