Unverified Commit e51929eb authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Improve configs - `SchedulerConfig` (#16533)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent dc1b4a6f
......@@ -1522,6 +1522,9 @@ class LoadConfig:
self.ignore_patterns = ["original/**/*"]
DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"]
@config
@dataclass
class ParallelConfig:
......@@ -1563,7 +1566,7 @@ class ParallelConfig:
placement_group: Optional["PlacementGroup"] = None
"""ray distributed model workers placement group."""
distributed_executor_backend: Optional[Union[str,
distributed_executor_backend: Optional[Union[DistributedExecutorBackend,
type["ExecutorBase"]]] = None
"""Backend to use for distributed model
workers, either "ray" or "mp" (multiprocessing). If the product
......@@ -1687,7 +1690,7 @@ class ParallelConfig:
# current node and we aren't in a ray placement group.
from vllm.executor import ray_utils
backend = "mp"
backend: DistributedExecutorBackend = "mp"
ray_found = ray_utils.ray_is_available()
if current_platform.is_neuron():
# neuron uses single process to control multiple devices
......@@ -1755,92 +1758,124 @@ class ParallelConfig:
"worker_extension_cls must be a string (qualified class name).")
SchedulerPolicy = Literal["fcfs", "priority"]
@config
@dataclass
class SchedulerConfig:
"""Scheduler configuration."""
runner_type: str = "generate" # The runner type to launch for the model.
runner_type: RunnerType = "generate"
"""The runner type to launch for the model."""
# Maximum number of tokens to be processed in a single iteration.
max_num_batched_tokens: int = field(default=None) # type: ignore
max_num_batched_tokens: int = None # type: ignore
"""Maximum number of tokens to be processed in a single iteration.
# Maximum number of sequences to be processed in a single iteration.
max_num_seqs: int = 128
This config has no static default. If left unspecified by the user, it will
be set in `EngineArgs.create_engine_config` based on the usage context."""
# Maximum length of a sequence (including prompt and generated text).
max_model_len: int = 8192
max_num_seqs: int = None # type: ignore
"""Maximum number of sequences to be processed in a single iteration.
This config has no static default. If left unspecified by the user, it will
be set in `EngineArgs.create_engine_config` based on the usage context."""
max_model_len: int = None # type: ignore
"""Maximum length of a sequence (including prompt and generated text). This
is primarily set in `ModelConfig` and that value should be manually
duplicated here."""
# Maximum number of sequences that can be partially prefilled concurrently
max_num_partial_prefills: int = 1
"""For chunked prefill, the maximum number of sequences that can be
partially prefilled concurrently."""
# Maximum number of "very long prompt" sequences that can be prefilled
# concurrently (long is defined by long_prefill_threshold)
max_long_partial_prefills: int = 1
"""For chunked prefill, the maximum number of prompts longer than
long_prefill_token_threshold that will be prefilled concurrently. Setting
this less than max_num_partial_prefills will allow shorter prompts to jump
the queue in front of longer prompts in some cases, improving latency."""
# calculate context length that determines which sequences are
# considered "long"
long_prefill_token_threshold: int = 0
"""For chunked prefill, a request is considered long if the prompt is
longer than this number of tokens."""
# The number of slots to allocate per sequence per
# step, beyond the known token ids. This is used in speculative
# decoding to store KV activations of tokens which may or may not be
# accepted.
num_lookahead_slots: int = 0
"""The number of slots to allocate per sequence per
step, beyond the known token ids. This is used in speculative
decoding to store KV activations of tokens which may or may not be
accepted.
NOTE: This will be replaced by speculative config in the future; it is
present to enable correctness tests until then."""
# Apply a delay (of delay factor multiplied by previous
# prompt latency) before scheduling next prompt.
delay_factor: float = 0.0
"""Apply a delay (of delay factor multiplied by previous
prompt latency) before scheduling next prompt."""
# If True, prefill requests can be chunked based
# on the remaining max_num_batched_tokens.
enable_chunked_prefill: bool = False
enable_chunked_prefill: bool = None # type: ignore
"""If True, prefill requests can be chunked based
on the remaining max_num_batched_tokens."""
is_multimodal_model: bool = False
"""True if the model is multimodal."""
# TODO (ywang96): Make this configurable.
max_num_encoder_input_tokens: int = field(init=False)
"""Multimodal encoder compute budget, only used in V1.
NOTE: This is not currently configurable. It will be overridden by
max_num_batched_tokens in case max multimodal embedding size is larger."""
# TODO (ywang96): Make this configurable.
encoder_cache_size: int = field(init=False)
"""Multimodal encoder cache size, only used in V1.
NOTE: This is not currently configurable. It will be overridden by
max_num_batched_tokens in case max multimodal embedding size is larger."""
# NOTE: The following multimodal encoder budget will be initialized to
# max_num_batched_tokens and overridden in case max multimodal embedding
# size is larger.
# TODO (ywang96): Make these configurable.
# Multimodal encoder compute budget, only used in V1
max_num_encoder_input_tokens: int = field(default=None) # type: ignore
# Multimodal encoder cache size, only used in V1
encoder_cache_size: int = field(default=None) # type: ignore
# Whether to perform preemption by swapping or
# recomputation. If not specified, we determine the mode as follows:
# We use recomputation by default since it incurs lower overhead than
# swapping. However, when the sequence group has multiple sequences
# (e.g., beam search), recomputation is not currently supported. In
# such a case, we use swapping instead.
preemption_mode: Optional[str] = None
"""Whether to perform preemption by swapping or
recomputation. If not specified, we determine the mode as follows:
We use recomputation by default since it incurs lower overhead than
swapping. However, when the sequence group has multiple sequences
(e.g., beam search), recomputation is not currently supported. In
such a case, we use swapping instead."""
num_scheduler_steps: int = 1
"""Maximum number of forward steps per scheduler call."""
multi_step_stream_outputs: bool = False
multi_step_stream_outputs: bool = True
"""If False, then multi-step will stream outputs at the end of all steps"""
# Private API. If used, scheduler sends delta data to
# workers instead of an entire data. It should be enabled only
# when SPMD worker architecture is enabled. I.e.,
# VLLM_USE_RAY_SPMD_WORKER=1
send_delta_data: bool = False
# The scheduling policy to use. "fcfs" (default) or "priority".
policy: str = "fcfs"
"""Private API. If used, scheduler sends delta data to
workers instead of an entire data. It should be enabled only
when SPMD worker architecture is enabled. I.e.,
VLLM_USE_RAY_SPMD_WORKER=1"""
policy: SchedulerPolicy = "fcfs"
"""The scheduling policy to use:\n
- "fcfs" means first come first served, i.e. requests are handled in order
of arrival.\n
- "priority" means requests are handled based on given priority (lower
value means earlier handling) and time of arrival deciding any ties)."""
chunked_prefill_enabled: bool = field(init=False)
"""True if chunked prefill is enabled."""
# If set to true and chunked prefill is enabled, we do not want to
# partially schedule a multimodal item. Only used in V1
# This ensures that if a request has a mixed prompt
# (like text tokens TTTT followed by image tokens IIIIIIIIII) where only
# some image tokens can be scheduled (like TTTTIIIII, leaving IIIII),
# it will be scheduled as TTTT in one step and IIIIIIIIII in the next.
disable_chunked_mm_input: bool = False
"""If set to true and chunked prefill is enabled, we do not want to
partially schedule a multimodal item. Only used in V1
This ensures that if a request has a mixed prompt
(like text tokens TTTT followed by image tokens IIIIIIIIII) where only
some image tokens can be scheduled (like TTTTIIIII, leaving IIIII),
it will be scheduled as TTTT in one step and IIIIIIIIII in the next."""
# scheduler class or path. "vllm.core.scheduler.Scheduler" (default)
# or "mod.custom_class".
scheduler_cls: Union[str, type[object]] = "vllm.core.scheduler.Scheduler"
"""The scheduler class to use. "vllm.core.scheduler.Scheduler" is the
default scheduler. Can be a class directly or the path to a class of form
"mod.custom_class"."""
def compute_hash(self) -> str:
"""
......@@ -1862,6 +1897,18 @@ class SchedulerConfig:
return hash_str
def __post_init__(self) -> None:
if self.max_model_len is None:
self.max_model_len = 8192
logger.warning(
"max_model_len was is not set. Defaulting to arbitrary value "
"of %d.", self.max_model_len)
if self.max_num_seqs is None:
self.max_num_seqs = 128
logger.warning(
"max_num_seqs was is not set. Defaulting to arbitrary value "
"of %d.", self.max_num_seqs)
if self.max_num_batched_tokens is None:
if self.enable_chunked_prefill:
if self.num_scheduler_steps > 1:
......
This diff is collapsed.
......@@ -11,7 +11,7 @@ import ssl
from collections.abc import Sequence
from typing import Optional, Union, get_args
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.engine.arg_utils import AsyncEngineArgs, optional_str
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
validate_chat_template)
from vllm.entrypoints.openai.serving_models import (LoRAModulePath,
......@@ -79,7 +79,7 @@ class PromptAdapterParserAction(argparse.Action):
def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser.add_argument("--host",
type=nullable_str,
type=optional_str,
default=None,
help="Host name.")
parser.add_argument("--port", type=int, default=8000, help="Port number.")
......@@ -108,13 +108,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
default=["*"],
help="Allowed headers.")
parser.add_argument("--api-key",
type=nullable_str,
type=optional_str,
default=None,
help="If provided, the server will require this key "
"to be presented in the header.")
parser.add_argument(
"--lora-modules",
type=nullable_str,
type=optional_str,
default=None,
nargs='+',
action=LoRAParserAction,
......@@ -126,14 +126,14 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"\"base_model_name\": \"id\"}``")
parser.add_argument(
"--prompt-adapters",
type=nullable_str,
type=optional_str,
default=None,
nargs='+',
action=PromptAdapterParserAction,
help="Prompt adapter configurations in the format name=path. "
"Multiple adapters can be specified.")
parser.add_argument("--chat-template",
type=nullable_str,
type=optional_str,
default=None,
help="The file path to the chat template, "
"or the template in single-line form "
......@@ -151,20 +151,20 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
'similar to OpenAI schema. '
'Example: ``[{"type": "text", "text": "Hello world!"}]``')
parser.add_argument("--response-role",
type=nullable_str,
type=optional_str,
default="assistant",
help="The role name to return if "
"``request.add_generation_prompt=true``.")
parser.add_argument("--ssl-keyfile",
type=nullable_str,
type=optional_str,
default=None,
help="The file path to the SSL key file.")
parser.add_argument("--ssl-certfile",
type=nullable_str,
type=optional_str,
default=None,
help="The file path to the SSL cert file.")
parser.add_argument("--ssl-ca-certs",
type=nullable_str,
type=optional_str,
default=None,
help="The CA certificates file.")
parser.add_argument(
......@@ -180,13 +180,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
)
parser.add_argument(
"--root-path",
type=nullable_str,
type=optional_str,
default=None,
help="FastAPI root_path when app is behind a path based routing proxy."
)
parser.add_argument(
"--middleware",
type=nullable_str,
type=optional_str,
action="append",
default=[],
help="Additional ASGI middleware to apply to the app. "
......
......@@ -12,7 +12,7 @@ import torch
from prometheus_client import start_http_server
from tqdm import tqdm
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.engine.arg_utils import AsyncEngineArgs, optional_str
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.logger import RequestLogger, logger
# yapf: disable
......@@ -61,7 +61,7 @@ def parse_args():
"to the output URL.",
)
parser.add_argument("--response-role",
type=nullable_str,
type=optional_str,
default="assistant",
help="The role name to return if "
"`request.add_generation_prompt=True`.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment