Unverified Commit 7adeb4bf authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Bugfix] Fix `max_model_len="auto"` handling (#31260)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent bd89ce16
...@@ -164,7 +164,7 @@ class ModelConfig: ...@@ -164,7 +164,7 @@ class ModelConfig:
"""The specific revision to use for the tokenizer on the Hugging Face Hub. """The specific revision to use for the tokenizer on the Hugging Face Hub.
It can be a branch name, a tag name, or a commit id. If unspecified, will It can be a branch name, a tag name, or a commit id. If unspecified, will
use the default version.""" use the default version."""
max_model_len: int = Field(default=None, gt=0) max_model_len: int = Field(default=None, ge=-1)
"""Model context length (prompt and output). If unspecified, will be """Model context length (prompt and output). If unspecified, will be
automatically derived from the model config. automatically derived from the model config.
......
...@@ -297,16 +297,14 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, dict[str, Any]]: ...@@ -297,16 +297,14 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, dict[str, Any]]:
elif contains_type(type_hints, set): elif contains_type(type_hints, set):
kwargs[name].update(collection_to_kwargs(type_hints, set)) kwargs[name].update(collection_to_kwargs(type_hints, set))
elif contains_type(type_hints, int): elif contains_type(type_hints, int):
kwargs[name]["type"] = int if name == "max_model_len":
# Special case for large integers kwargs[name]["type"] = human_readable_int_or_auto
human_readable_ints = { kwargs[name]["help"] += f"\n\n{human_readable_int_or_auto.__doc__}"
"max_model_len", elif name in ("max_num_batched_tokens", "kv_cache_memory_bytes"):
"max_num_batched_tokens",
"kv_cache_memory_bytes",
}
if name in human_readable_ints:
kwargs[name]["type"] = human_readable_int kwargs[name]["type"] = human_readable_int
kwargs[name]["help"] += f"\n\n{human_readable_int.__doc__}" kwargs[name]["help"] += f"\n\n{human_readable_int.__doc__}"
else:
kwargs[name]["type"] = int
elif contains_type(type_hints, float): elif contains_type(type_hints, float):
kwargs[name]["type"] = float kwargs[name]["type"] = float
elif contains_type(type_hints, dict) and ( elif contains_type(type_hints, dict) and (
...@@ -2042,23 +2040,17 @@ def _raise_unsupported_error(feature_name: str): ...@@ -2042,23 +2040,17 @@ def _raise_unsupported_error(feature_name: str):
raise NotImplementedError(msg) raise NotImplementedError(msg)
def human_readable_int(value): def human_readable_int(value: str) -> int:
"""Parse human-readable integers like '1k', '2M', etc. """Parse human-readable integers like '1k', '2M', etc.
Including decimal values with decimal multipliers. Including decimal values with decimal multipliers.
Also accepts -1 or 'auto' as a special value for auto-detection.
Examples: Examples:
- '1k' -> 1,000 - '1k' -> 1,000
- '1K' -> 1,024 - '1K' -> 1,024
- '25.6k' -> 25,600 - '25.6k' -> 25,600
- '-1' or 'auto' -> -1 (special value for auto-detection)
""" """
value = value.strip() value = value.strip()
# Handle -1 or 'auto' as a special value for auto-detection
if value == "-1" or value.lower() == "auto":
return -1
match = re.fullmatch(r"(\d+(?:\.\d+)?)([kKmMgGtT])", value) match = re.fullmatch(r"(\d+(?:\.\d+)?)([kKmMgGtT])", value)
if match: if match:
decimal_multiplier = { decimal_multiplier = {
...@@ -2092,3 +2084,22 @@ def human_readable_int(value): ...@@ -2092,3 +2084,22 @@ def human_readable_int(value):
# Regular plain number. # Regular plain number.
return int(value) return int(value)
def human_readable_int_or_auto(value: str) -> int:
"""Parse human-readable integers like '1k', '2M', etc.
Including decimal values with decimal multipliers.
Also accepts -1 or 'auto' as a special value for auto-detection.
Examples:
- '1k' -> 1,000
- '1K' -> 1,024
- '25.6k' -> 25,600
- '-1' or 'auto' -> -1 (special value for auto-detection)
"""
value = value.strip()
if value == "-1" or value.lower() == "auto":
return -1
return human_readable_int(value)
...@@ -606,6 +606,43 @@ def get_request_block_hasher( ...@@ -606,6 +606,43 @@ def get_request_block_hasher(
return request_block_hasher return request_block_hasher
def _check_enough_kv_cache_memory(
available_memory: int,
get_needed_memory: Callable[[], int],
max_model_len: int,
estimate_max_model_len: Callable[[int], int],
):
if available_memory <= 0:
raise ValueError(
"No available memory for the cache blocks. "
"Try increasing `gpu_memory_utilization` when initializing the engine. "
"See https://docs.vllm.ai/en/latest/configuration/conserving_memory/ "
"for more details."
)
needed_memory = get_needed_memory()
if needed_memory > available_memory:
estimated_max_len = estimate_max_model_len(available_memory)
estimated_msg = ""
if estimated_max_len > 0:
estimated_msg = (
"Based on the available memory, "
f"the estimated maximum model length is {estimated_max_len}. "
)
raise ValueError(
f"To serve at least one request with the models's max seq len "
f"({max_model_len}), ({needed_memory / GiB_bytes:.2f} GiB KV "
f"cache is needed, which is larger than the available KV cache "
f"memory ({available_memory / GiB_bytes:.2f} GiB). {estimated_msg}"
f"Try increasing `gpu_memory_utilization` or decreasing `max_model_len` "
f"when initializing the engine. "
f"See https://docs.vllm.ai/en/latest/configuration/conserving_memory/ "
f"for more details."
)
def max_memory_usage_bytes( def max_memory_usage_bytes(
vllm_config: VllmConfig, kv_cache_specs: Iterable[KVCacheSpec] vllm_config: VllmConfig, kv_cache_specs: Iterable[KVCacheSpec]
) -> int: ) -> int:
...@@ -688,43 +725,12 @@ def check_enough_kv_cache_memory( ...@@ -688,43 +725,12 @@ def check_enough_kv_cache_memory(
""" """
# No need to check for available memory if the kv_cache_spec is empty # No need to check for available memory if the kv_cache_spec is empty
if not kv_cache_spec: if kv_cache_spec:
return _check_enough_kv_cache_memory(
available_memory,
if available_memory <= 0: lambda: max_memory_usage_bytes(vllm_config, kv_cache_spec.values()),
raise ValueError( vllm_config.model_config.max_model_len,
"No available memory for the cache blocks. " lambda am: estimate_max_model_len(vllm_config, kv_cache_spec, am),
"Try increasing `gpu_memory_utilization` when "
"initializing the engine. "
"See https://docs.vllm.ai/en/latest/configuration/conserving_memory/ "
"for more details."
)
max_model_len = vllm_config.model_config.max_model_len
needed_memory = max_memory_usage_bytes(vllm_config, kv_cache_spec.values())
if needed_memory > available_memory:
# Estimate the maximum model length that can fit in the available memory
estimated_max_len = estimate_max_model_len(
vllm_config, kv_cache_spec, available_memory
)
estimated_msg = ""
if estimated_max_len > 0:
estimated_msg = (
"Based on the available memory, "
f"the estimated maximum model length is {estimated_max_len}."
)
raise ValueError(
f"To serve at least one request with the models's max seq len "
f"({max_model_len}), ({needed_memory / GiB_bytes:.2f} GiB KV "
f"cache is needed, which is larger than the available KV cache "
f"memory ({available_memory / GiB_bytes:.2f} GiB). "
f"{estimated_msg} "
f"Try increasing `gpu_memory_utilization` or decreasing `max_model_len` "
f"when initializing the engine. "
f"See https://docs.vllm.ai/en/latest/configuration/conserving_memory/ "
f"for more details."
) )
...@@ -1505,35 +1511,15 @@ def get_kv_cache_configs( ...@@ -1505,35 +1511,15 @@ def get_kv_cache_configs(
# Check if the available memory is enough (using min across all workers). # Check if the available memory is enough (using min across all workers).
# We use the global groups to correctly account for padding. # We use the global groups to correctly account for padding.
if global_kv_cache_groups: if global_kv_cache_groups:
min_available_memory = min(available_memory) _check_enough_kv_cache_memory(
if min_available_memory <= 0: min(available_memory),
raise ValueError( lambda: _max_memory_usage_bytes_from_groups(
"No available memory for the cache blocks. "
"Try increasing `gpu_memory_utilization` when "
"initializing the engine."
)
max_model_len = vllm_config.model_config.max_model_len
needed_memory = _max_memory_usage_bytes_from_groups(
vllm_config, global_kv_cache_groups vllm_config, global_kv_cache_groups
) ),
if needed_memory > min_available_memory: vllm_config.model_config.max_model_len,
estimated_max_len = _estimate_max_model_len_from_groups( lambda am: _estimate_max_model_len_from_groups(
vllm_config, global_kv_cache_groups, min_available_memory vllm_config, global_kv_cache_groups, am
) ),
estimated_msg = ""
if estimated_max_len > 0:
estimated_msg = (
f"Based on the available memory, the estimated maximum "
f"model length is {estimated_max_len}. "
)
raise ValueError(
f"To serve at least one request with the models's max seq len "
f"({max_model_len}), ({needed_memory / GiB_bytes:.2f} GiB KV "
f"cache is needed, which is larger than the available KV cache "
f"memory ({min_available_memory / GiB_bytes:.2f} GiB). "
f"{estimated_msg}"
f"Try increasing `gpu_memory_utilization` or decreasing "
f"`max_model_len` when initializing the engine."
) )
kv_cache_configs: list[KVCacheConfig] = [] kv_cache_configs: list[KVCacheConfig] = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment