Unverified Commit 8065a7e2 authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

[Frontend] Add FlexibleArgumentParser to support both underscore and dash in names (#5718)

parent 3f3b6b21
...@@ -13,6 +13,7 @@ from vllm import LLM, SamplingParams ...@@ -13,6 +13,7 @@ from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptStrictInputs from vllm.inputs import PromptStrictInputs
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import FlexibleArgumentParser
def main(args: argparse.Namespace): def main(args: argparse.Namespace):
...@@ -120,7 +121,7 @@ def main(args: argparse.Namespace): ...@@ -120,7 +121,7 @@ def main(args: argparse.Namespace):
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser( parser = FlexibleArgumentParser(
description='Benchmark the latency of processing a single batch of ' description='Benchmark the latency of processing a single batch of '
'requests till completion.') 'requests till completion.')
parser.add_argument('--model', type=str, default='facebook/opt-125m') parser.add_argument('--model', type=str, default='facebook/opt-125m')
......
import argparse
import time import time
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.utils import FlexibleArgumentParser
PROMPT = "You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as fellows. You need to answer my question about the table.\n# Table\n|Opening|Opening|Sl. No.|Film|Cast|Director|Music Director|Notes|\n|----|----|----|----|----|----|----|----|\n|J A N|9|1|Agni Pushpam|Jayabharathi, Kamalahasan|Jeassy|M. K. Arjunan||\n|J A N|16|2|Priyamvada|Mohan Sharma, Lakshmi, KPAC Lalitha|K. S. Sethumadhavan|V. Dakshinamoorthy||\n|J A N|23|3|Yakshagaanam|Madhu, Sheela|Sheela|M. S. Viswanathan||\n|J A N|30|4|Paalkkadal|Sheela, Sharada|T. K. Prasad|A. T. Ummer||\n|F E B|5|5|Amma|Madhu, Srividya|M. Krishnan Nair|M. K. Arjunan||\n|F E B|13|6|Appooppan|Thikkurissi Sukumaran Nair, Kamal Haasan|P. Bhaskaran|M. S. Baburaj||\n|F E B|20|7|Srishti|Chowalloor Krishnankutty, Ravi Alummoodu|K. T. Muhammad|M. S. Baburaj||\n|F E B|20|8|Vanadevatha|Prem Nazir, Madhubala|Yusufali Kechery|G. Devarajan||\n|F E B|27|9|Samasya|Madhu, Kamalahaasan|K. Thankappan|Shyam||\n|F E B|27|10|Yudhabhoomi|K. P. Ummer, Vidhubala|Crossbelt Mani|R. K. Shekhar||\n|M A R|5|11|Seemantha Puthran|Prem Nazir, Jayabharathi|A. B. Raj|M. K. Arjunan||\n|M A R|12|12|Swapnadanam|Rani Chandra, Dr. Mohandas|K. G. George|Bhaskar Chandavarkar||\n|M A R|19|13|Thulavarsham|Prem Nazir, sreedevi, Sudheer|N. Sankaran Nair|V. Dakshinamoorthy||\n|M A R|20|14|Aruthu|Kaviyoor Ponnamma, Kamalahasan|Ravi|G. Devarajan||\n|M A R|26|15|Swimming Pool|Kamal Haasan, M. G. Soman|J. Sasikumar|M. K. Arjunan||\n\n# Question\nWhat' s the content in the (1,1) cells\n" # noqa: E501 PROMPT = "You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as fellows. You need to answer my question about the table.\n# Table\n|Opening|Opening|Sl. No.|Film|Cast|Director|Music Director|Notes|\n|----|----|----|----|----|----|----|----|\n|J A N|9|1|Agni Pushpam|Jayabharathi, Kamalahasan|Jeassy|M. K. Arjunan||\n|J A N|16|2|Priyamvada|Mohan Sharma, Lakshmi, KPAC Lalitha|K. S. Sethumadhavan|V. Dakshinamoorthy||\n|J A N|23|3|Yakshagaanam|Madhu, Sheela|Sheela|M. S. Viswanathan||\n|J A N|30|4|Paalkkadal|Sheela, Sharada|T. K. Prasad|A. T. Ummer||\n|F E B|5|5|Amma|Madhu, Srividya|M. Krishnan Nair|M. K. Arjunan||\n|F E B|13|6|Appooppan|Thikkurissi Sukumaran Nair, Kamal Haasan|P. Bhaskaran|M. S. Baburaj||\n|F E B|20|7|Srishti|Chowalloor Krishnankutty, Ravi Alummoodu|K. T. Muhammad|M. S. Baburaj||\n|F E B|20|8|Vanadevatha|Prem Nazir, Madhubala|Yusufali Kechery|G. Devarajan||\n|F E B|27|9|Samasya|Madhu, Kamalahaasan|K. Thankappan|Shyam||\n|F E B|27|10|Yudhabhoomi|K. P. Ummer, Vidhubala|Crossbelt Mani|R. K. Shekhar||\n|M A R|5|11|Seemantha Puthran|Prem Nazir, Jayabharathi|A. B. Raj|M. K. Arjunan||\n|M A R|12|12|Swapnadanam|Rani Chandra, Dr. Mohandas|K. G. George|Bhaskar Chandavarkar||\n|M A R|19|13|Thulavarsham|Prem Nazir, sreedevi, Sudheer|N. Sankaran Nair|V. Dakshinamoorthy||\n|M A R|20|14|Aruthu|Kaviyoor Ponnamma, Kamalahasan|Ravi|G. Devarajan||\n|M A R|26|15|Swimming Pool|Kamal Haasan, M. G. Soman|J. Sasikumar|M. K. Arjunan||\n\n# Question\nWhat' s the content in the (1,1) cells\n" # noqa: E501
...@@ -44,7 +44,7 @@ def main(args): ...@@ -44,7 +44,7 @@ def main(args):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser( parser = FlexibleArgumentParser(
description='Benchmark the performance with or without automatic ' description='Benchmark the performance with or without automatic '
'prefix caching.') 'prefix caching.')
parser.add_argument('--model', parser.add_argument('--model',
......
...@@ -44,6 +44,11 @@ try: ...@@ -44,6 +44,11 @@ try:
except ImportError: except ImportError:
from backend_request_func import get_tokenizer from backend_request_func import get_tokenizer
try:
from vllm.utils import FlexibleArgumentParser
except ImportError:
from argparse import ArgumentParser as FlexibleArgumentParser
@dataclass @dataclass
class BenchmarkMetrics: class BenchmarkMetrics:
...@@ -511,7 +516,7 @@ def main(args: argparse.Namespace): ...@@ -511,7 +516,7 @@ def main(args: argparse.Namespace):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser( parser = FlexibleArgumentParser(
description="Benchmark the online serving throughput.") description="Benchmark the online serving throughput.")
parser.add_argument( parser.add_argument(
"--backend", "--backend",
......
...@@ -12,6 +12,7 @@ from transformers import (AutoModelForCausalLM, AutoTokenizer, ...@@ -12,6 +12,7 @@ from transformers import (AutoModelForCausalLM, AutoTokenizer,
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import FlexibleArgumentParser
def sample_requests( def sample_requests(
...@@ -261,7 +262,7 @@ def main(args: argparse.Namespace): ...@@ -261,7 +262,7 @@ def main(args: argparse.Namespace):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Benchmark the throughput.") parser = FlexibleArgumentParser(description="Benchmark the throughput.")
parser.add_argument("--backend", parser.add_argument("--backend",
type=str, type=str,
choices=["vllm", "hf", "mii"], choices=["vllm", "hf", "mii"],
......
...@@ -11,6 +11,7 @@ from torch.utils.benchmark import Measurement as TMeasurement ...@@ -11,6 +11,7 @@ from torch.utils.benchmark import Measurement as TMeasurement
from weight_shapes import WEIGHT_SHAPES from weight_shapes import WEIGHT_SHAPES
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.utils import FlexibleArgumentParser
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())[1:] DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())[1:]
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
...@@ -293,7 +294,7 @@ if __name__ == '__main__': ...@@ -293,7 +294,7 @@ if __name__ == '__main__':
return torch.float8_e4m3fn return torch.float8_e4m3fn
raise ValueError("unsupported dtype") raise ValueError("unsupported dtype")
parser = argparse.ArgumentParser( parser = FlexibleArgumentParser(
description=""" description="""
Benchmark Cutlass GEMM. Benchmark Cutlass GEMM.
......
import argparse
import os import os
import sys import sys
from typing import Optional from typing import Optional
...@@ -10,6 +9,7 @@ from vllm import _custom_ops as ops ...@@ -10,6 +9,7 @@ from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.aqlm import ( from vllm.model_executor.layers.quantization.aqlm import (
dequantize_weight, generic_dequantize_gemm, get_int_dtype, dequantize_weight, generic_dequantize_gemm, get_int_dtype,
optimized_dequantize_gemm) optimized_dequantize_gemm)
from vllm.utils import FlexibleArgumentParser
os.environ['CUDA_VISIBLE_DEVICES'] = '0' os.environ['CUDA_VISIBLE_DEVICES'] = '0'
...@@ -137,7 +137,7 @@ def dequant_test(k: int, parts: torch.Tensor, nbooks: int, bits: int) -> None: ...@@ -137,7 +137,7 @@ def dequant_test(k: int, parts: torch.Tensor, nbooks: int, bits: int) -> None:
def main(): def main():
parser = argparse.ArgumentParser(description="Benchmark aqlm performance.") parser = FlexibleArgumentParser(description="Benchmark aqlm performance.")
# Add arguments # Add arguments
parser.add_argument("--nbooks", parser.add_argument("--nbooks",
......
import argparse
from typing import List from typing import List
import torch import torch
...@@ -16,6 +15,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import ( ...@@ -16,6 +15,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
MarlinWorkspace, marlin_24_quantize, marlin_quantize) MarlinWorkspace, marlin_24_quantize, marlin_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
gptq_pack, quantize_weights, sort_weights) gptq_pack, quantize_weights, sort_weights)
from vllm.utils import FlexibleArgumentParser
DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"] DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"]
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
...@@ -211,7 +211,7 @@ def main(args): ...@@ -211,7 +211,7 @@ def main(args):
# python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 --limit-group-size 128 --limit-num-bits 4 --limit-act-order 0 --limit-k-full 1 # noqa E501 # python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 --limit-group-size 128 --limit-num-bits 4 --limit-act-order 0 --limit-k-full 1 # noqa E501
# #
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser( parser = FlexibleArgumentParser(
description="Benchmark Marlin across specified models/shapes/batches") description="Benchmark Marlin across specified models/shapes/batches")
parser.add_argument( parser.add_argument(
"--models", "--models",
......
...@@ -10,6 +10,7 @@ from ray.experimental.tqdm_ray import tqdm ...@@ -10,6 +10,7 @@ from ray.experimental.tqdm_ray import tqdm
from transformers import AutoConfig from transformers import AutoConfig
from vllm.model_executor.layers.fused_moe.fused_moe import * from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.utils import FlexibleArgumentParser
class BenchmarkConfig(TypedDict): class BenchmarkConfig(TypedDict):
...@@ -315,7 +316,7 @@ def main(args: argparse.Namespace): ...@@ -315,7 +316,7 @@ def main(args: argparse.Namespace):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = FlexibleArgumentParser()
parser.add_argument("--model", parser.add_argument("--model",
type=str, type=str,
default="mistralai/Mixtral-8x7B-Instruct-v0.1") default="mistralai/Mixtral-8x7B-Instruct-v0.1")
......
import argparse
import random import random
import time import time
from typing import List, Optional from typing import List, Optional
...@@ -6,7 +5,8 @@ from typing import List, Optional ...@@ -6,7 +5,8 @@ from typing import List, Optional
import torch import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
create_kv_caches_with_random)
NUM_BLOCKS = 1024 NUM_BLOCKS = 1024
PARTITION_SIZE = 512 PARTITION_SIZE = 512
...@@ -161,7 +161,7 @@ def main( ...@@ -161,7 +161,7 @@ def main(
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser( parser = FlexibleArgumentParser(
description="Benchmark the paged attention kernel.") description="Benchmark the paged attention kernel.")
parser.add_argument("--version", parser.add_argument("--version",
type=str, type=str,
......
import argparse
from itertools import accumulate from itertools import accumulate
from typing import List, Optional from typing import List, Optional
...@@ -7,6 +6,7 @@ import torch ...@@ -7,6 +6,7 @@ import torch
from vllm.model_executor.layers.rotary_embedding import (RotaryEmbedding, from vllm.model_executor.layers.rotary_embedding import (RotaryEmbedding,
get_rope) get_rope)
from vllm.utils import FlexibleArgumentParser
def benchmark_rope_kernels_multi_lora( def benchmark_rope_kernels_multi_lora(
...@@ -86,7 +86,7 @@ def benchmark_rope_kernels_multi_lora( ...@@ -86,7 +86,7 @@ def benchmark_rope_kernels_multi_lora(
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser( parser = FlexibleArgumentParser(
description="Benchmark the rotary embedding kernels.") description="Benchmark the rotary embedding kernels.")
parser.add_argument("--is-neox-style", type=bool, default=True) parser.add_argument("--is-neox-style", type=bool, default=True)
parser.add_argument("--batch-size", type=int, default=16) parser.add_argument("--batch-size", type=int, default=16)
......
import argparse
import cProfile import cProfile
import pstats import pstats
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.utils import FlexibleArgumentParser
# A very long prompt, total number of tokens is about 15k. # A very long prompt, total number of tokens is about 15k.
LONG_PROMPT = ["You are an expert in large language models, aren't you?" LONG_PROMPT = ["You are an expert in large language models, aren't you?"
...@@ -47,7 +47,7 @@ def main(args): ...@@ -47,7 +47,7 @@ def main(args):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser( parser = FlexibleArgumentParser(
description='Benchmark the performance of hashing function in' description='Benchmark the performance of hashing function in'
'automatic prefix caching.') 'automatic prefix caching.')
parser.add_argument('--model', type=str, default='lmsys/longchat-7b-16k') parser.add_argument('--model', type=str, default='lmsys/longchat-7b-16k')
......
import argparse
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.utils import FlexibleArgumentParser
def main(): def main():
parser = argparse.ArgumentParser(description='AQLM examples') parser = FlexibleArgumentParser(description='AQLM examples')
parser.add_argument('--model', parser.add_argument('--model',
'-m', '-m',
......
...@@ -2,6 +2,7 @@ import argparse ...@@ -2,6 +2,7 @@ import argparse
from typing import List, Tuple from typing import List, Tuple
from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
from vllm.utils import FlexibleArgumentParser
def create_test_prompts() -> List[Tuple[str, SamplingParams]]: def create_test_prompts() -> List[Tuple[str, SamplingParams]]:
...@@ -55,7 +56,7 @@ def main(args: argparse.Namespace): ...@@ -55,7 +56,7 @@ def main(args: argparse.Namespace):
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser( parser = FlexibleArgumentParser(
description='Demo on using the LLMEngine class directly') description='Demo on using the LLMEngine class directly')
parser = EngineArgs.add_cli_args(parser) parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args() args = parser.parse_args()
......
...@@ -20,15 +20,15 @@ llm = LLM( ...@@ -20,15 +20,15 @@ llm = LLM(
tensor_parallel_size=8, tensor_parallel_size=8,
) )
""" """
import argparse
import dataclasses import dataclasses
import os import os
import shutil import shutil
from pathlib import Path from pathlib import Path
from vllm import LLM, EngineArgs from vllm import LLM, EngineArgs
from vllm.utils import FlexibleArgumentParser
parser = argparse.ArgumentParser() parser = FlexibleArgumentParser()
EngineArgs.add_cli_args(parser) EngineArgs.add_cli_args(parser)
parser.add_argument("--output", parser.add_argument("--output",
"-o", "-o",
......
...@@ -9,6 +9,7 @@ from vllm.engine.arg_utils import EngineArgs ...@@ -9,6 +9,7 @@ from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.model_loader.tensorizer import (TensorizerArgs, from vllm.model_executor.model_loader.tensorizer import (TensorizerArgs,
TensorizerConfig, TensorizerConfig,
tensorize_vllm_model) tensorize_vllm_model)
from vllm.utils import FlexibleArgumentParser
# yapf conflicts with isort for this docstring # yapf conflicts with isort for this docstring
# yapf: disable # yapf: disable
...@@ -96,7 +97,7 @@ deserialization in this example script, although `--tensorizer-uri` and ...@@ -96,7 +97,7 @@ deserialization in this example script, although `--tensorizer-uri` and
def parse_args(): def parse_args():
parser = argparse.ArgumentParser( parser = FlexibleArgumentParser(
description="An example script that can be used to serialize and " description="An example script that can be used to serialize and "
"deserialize vLLM models. These models " "deserialize vLLM models. These models "
"can be loaded using tensorizer directly to the GPU " "can be loaded using tensorizer directly to the GPU "
......
"""vllm.entrypoints.api_server with some extra logging for testing.""" """vllm.entrypoints.api_server with some extra logging for testing."""
import argparse
from typing import Any, Dict from typing import Any, Dict
import uvicorn import uvicorn
...@@ -8,6 +7,7 @@ from fastapi.responses import JSONResponse, Response ...@@ -8,6 +7,7 @@ from fastapi.responses import JSONResponse, Response
import vllm.entrypoints.api_server import vllm.entrypoints.api_server
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.utils import FlexibleArgumentParser
app = vllm.entrypoints.api_server.app app = vllm.entrypoints.api_server.app
...@@ -33,7 +33,7 @@ def stats() -> Response: ...@@ -33,7 +33,7 @@ def stats() -> Response:
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = FlexibleArgumentParser()
parser.add_argument("--host", type=str, default="localhost") parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8000) parser.add_argument("--port", type=int, default=8000)
parser = AsyncEngineArgs.add_cli_args(parser) parser = AsyncEngineArgs.add_cli_args(parser)
......
...@@ -11,7 +11,7 @@ from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, ...@@ -11,7 +11,7 @@ from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
SpeculativeConfig, TokenizerPoolConfig, SpeculativeConfig, TokenizerPoolConfig,
VisionLanguageConfig) VisionLanguageConfig)
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import str_to_int_tuple from vllm.utils import FlexibleArgumentParser, str_to_int_tuple
def nullable_str(val: str): def nullable_str(val: str):
...@@ -110,7 +110,7 @@ class EngineArgs: ...@@ -110,7 +110,7 @@ class EngineArgs:
@staticmethod @staticmethod
def add_cli_args_for_vlm( def add_cli_args_for_vlm(
parser: argparse.ArgumentParser) -> argparse.ArgumentParser: parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser.add_argument('--image-input-type', parser.add_argument('--image-input-type',
type=nullable_str, type=nullable_str,
default=None, default=None,
...@@ -156,8 +156,7 @@ class EngineArgs: ...@@ -156,8 +156,7 @@ class EngineArgs:
return parser return parser
@staticmethod @staticmethod
def add_cli_args( def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
"""Shared CLI arguments for vLLM engine.""" """Shared CLI arguments for vLLM engine."""
# Model arguments # Model arguments
...@@ -800,8 +799,8 @@ class AsyncEngineArgs(EngineArgs): ...@@ -800,8 +799,8 @@ class AsyncEngineArgs(EngineArgs):
max_log_len: Optional[int] = None max_log_len: Optional[int] = None
@staticmethod @staticmethod
def add_cli_args(parser: argparse.ArgumentParser, def add_cli_args(parser: FlexibleArgumentParser,
async_args_only: bool = False) -> argparse.ArgumentParser: async_args_only: bool = False) -> FlexibleArgumentParser:
if not async_args_only: if not async_args_only:
parser = EngineArgs.add_cli_args(parser) parser = EngineArgs.add_cli_args(parser)
parser.add_argument('--engine-use-ray', parser.add_argument('--engine-use-ray',
...@@ -822,13 +821,13 @@ class AsyncEngineArgs(EngineArgs): ...@@ -822,13 +821,13 @@ class AsyncEngineArgs(EngineArgs):
# These functions are used by sphinx to build the documentation # These functions are used by sphinx to build the documentation
def _engine_args_parser(): def _engine_args_parser():
return EngineArgs.add_cli_args(argparse.ArgumentParser()) return EngineArgs.add_cli_args(FlexibleArgumentParser())
def _async_engine_args_parser(): def _async_engine_args_parser():
return AsyncEngineArgs.add_cli_args(argparse.ArgumentParser(), return AsyncEngineArgs.add_cli_args(FlexibleArgumentParser(),
async_args_only=True) async_args_only=True)
def _vlm_engine_args_parser(): def _vlm_engine_args_parser():
return EngineArgs.add_cli_args_for_vlm(argparse.ArgumentParser()) return EngineArgs.add_cli_args_for_vlm(FlexibleArgumentParser())
...@@ -6,7 +6,6 @@ We are also not going to accept PRs modifying this file, please ...@@ -6,7 +6,6 @@ We are also not going to accept PRs modifying this file, please
change `vllm/entrypoints/openai/api_server.py` instead. change `vllm/entrypoints/openai/api_server.py` instead.
""" """
import argparse
import json import json
import ssl import ssl
from typing import AsyncGenerator from typing import AsyncGenerator
...@@ -19,7 +18,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs ...@@ -19,7 +18,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import random_uuid from vllm.utils import FlexibleArgumentParser, random_uuid
TIMEOUT_KEEP_ALIVE = 5 # seconds. TIMEOUT_KEEP_ALIVE = 5 # seconds.
app = FastAPI() app = FastAPI()
...@@ -80,7 +79,7 @@ async def generate(request: Request) -> Response: ...@@ -80,7 +79,7 @@ async def generate(request: Request) -> Response:
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = FlexibleArgumentParser()
parser.add_argument("--host", type=str, default=None) parser.add_argument("--host", type=str, default=None)
parser.add_argument("--port", type=int, default=8000) parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--ssl-keyfile", type=str, default=None) parser.add_argument("--ssl-keyfile", type=str, default=None)
......
...@@ -10,6 +10,7 @@ import ssl ...@@ -10,6 +10,7 @@ import ssl
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.entrypoints.openai.serving_engine import LoRAModulePath from vllm.entrypoints.openai.serving_engine import LoRAModulePath
from vllm.utils import FlexibleArgumentParser
class LoRAParserAction(argparse.Action): class LoRAParserAction(argparse.Action):
...@@ -23,7 +24,7 @@ class LoRAParserAction(argparse.Action): ...@@ -23,7 +24,7 @@ class LoRAParserAction(argparse.Action):
def make_arg_parser(): def make_arg_parser():
parser = argparse.ArgumentParser( parser = FlexibleArgumentParser(
description="vLLM OpenAI-Compatible RESTful API server.") description="vLLM OpenAI-Compatible RESTful API server.")
parser.add_argument("--host", parser.add_argument("--host",
type=nullable_str, type=nullable_str,
......
import argparse
import asyncio import asyncio
import sys import sys
from io import StringIO from io import StringIO
...@@ -16,14 +15,14 @@ from vllm.entrypoints.openai.protocol import (BatchRequestInput, ...@@ -16,14 +15,14 @@ from vllm.entrypoints.openai.protocol import (BatchRequestInput,
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import random_uuid from vllm.utils import FlexibleArgumentParser, random_uuid
from vllm.version import __version__ as VLLM_VERSION from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__) logger = init_logger(__name__)
def parse_args(): def parse_args():
parser = argparse.ArgumentParser( parser = FlexibleArgumentParser(
description="vLLM OpenAI-Compatible batch runner.") description="vLLM OpenAI-Compatible batch runner.")
parser.add_argument( parser.add_argument(
"-i", "-i",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment