Unverified commit c5a66d16, authored by junuxyz and committed by GitHub
Browse files

[Core][BugFix] Fix PP KV cache sharding memory validation (#33698)


Signed-off-by: junuxyz <216036880+junuxyz@users.noreply.github.com>
parent afdce12c
...@@ -1046,6 +1046,99 @@ def test_get_kv_cache_configs_multiple_workers(): ...@@ -1046,6 +1046,99 @@ def test_get_kv_cache_configs_multiple_workers():
) )
@pytest.mark.parametrize(
    "asymmetric_memory",
    [False, True],
    ids=["symmetric", "asymmetric"],
)
def test_get_kv_cache_configs_pp_sharding(asymmetric_memory):
    """Verify per-worker KV cache configs when layers are split across PP workers.

    Two workers each own a single layer with the same KV cache spec. Every
    worker is given at least enough memory for its own layer; in the
    asymmetric case the second worker gets twice that amount. Both workers
    are expected to end up with the same number of blocks and a single
    tensor sized for their own layer only.
    """
    model_config = ModelConfig(max_model_len=512)
    vllm_config = VllmConfig(model_config=model_config)
    spec = new_kv_cache_spec()

    # One layer per pipeline-parallel worker, all sharing the same spec.
    layer_names = ["layer1", "layer2"]
    specs_per_worker = [{name: spec} for name in layer_names]

    # Size the baseline budget so it holds exactly `num_blocks` pages.
    num_blocks = model_config.max_model_len // spec.block_size + 1
    bytes_per_worker = spec.page_size_bytes * num_blocks

    # With per-worker validation, each worker only needs memory for its own
    # layers. Worker 2 having more memory shouldn't affect worker 1's config.
    if asymmetric_memory:
        memory_per_worker = [bytes_per_worker, bytes_per_worker * 2]
    else:
        memory_per_worker = [bytes_per_worker, bytes_per_worker]

    actual_configs = get_kv_cache_configs(
        vllm_config,
        specs_per_worker,
        memory_per_worker,
    )

    expected_configs = [
        KVCacheConfig(
            num_blocks=num_blocks,
            kv_cache_tensors=[
                KVCacheTensor(size=bytes_per_worker, shared_by=[name]),
            ],
            kv_cache_groups=[KVCacheGroupSpec([name], spec)],
        )
        for name in layer_names
    ]
    assert actual_configs == expected_configs
def test_project_kv_cache_groups_to_worker():
    """Exercise `_project_kv_cache_groups_to_worker` on three scenarios.

    Covers a worker owning a subset of a group's layers, a worker owning
    none of them, and a group whose spec is a `UniformTypeKVCacheSpecs`.
    """
    project = kv_cache_utils._project_kv_cache_groups_to_worker
    spec_a = new_kv_cache_spec()
    spec_b = new_kv_cache_spec(num_kv_heads=4)
    all_layers = ["layer1", "layer2", "layer3"]

    # Scenario 1: the worker owns two of the three layers in the group.
    plain_groups = [KVCacheGroupSpec(list(all_layers), spec_a)]
    result = project(plain_groups, {"layer1": spec_a, "layer2": spec_a})
    assert len(result) == 1
    only_group = result[0]
    assert only_group.layer_names == ["layer1", "layer2"]
    assert only_group.kv_cache_spec is spec_a

    # Scenario 2: the worker owns no layer of the group. The group is still
    # returned, with no layer names and the original spec object.
    result = project(plain_groups, {"layer4": spec_a})
    assert len(result) == 1
    only_group = result[0]
    assert only_group.layer_names == []
    assert only_group.kv_cache_spec is spec_a

    # Scenario 3: a uniform-type group spec is narrowed to the worker's layers.
    uniform_groups = [
        KVCacheGroupSpec(
            list(all_layers),
            UniformTypeKVCacheSpecs(
                block_size=16,
                kv_cache_specs={
                    "layer1": spec_a,
                    "layer2": spec_b,
                    "layer3": spec_a,
                },
            ),
        ),
    ]
    result = project(uniform_groups, {"layer1": spec_a, "layer3": spec_a})
    assert len(result) == 1
    only_group = result[0]
    assert only_group.layer_names == ["layer1", "layer3"]
    narrowed_spec = only_group.kv_cache_spec
    assert isinstance(narrowed_spec, UniformTypeKVCacheSpecs)
    assert set(narrowed_spec.kv_cache_specs.keys()) == {"layer1", "layer3"}
def test_merge_kv_cache_spec(): def test_merge_kv_cache_spec():
same_layer_specs = [ same_layer_specs = [
new_kv_cache_spec(num_kv_heads=32), new_kv_cache_spec(num_kv_heads=32),
......
...@@ -7,6 +7,7 @@ import os ...@@ -7,6 +7,7 @@ import os
from collections import defaultdict from collections import defaultdict
from collections.abc import Callable, Iterable, Iterator, Sequence from collections.abc import Callable, Iterable, Iterator, Sequence
from dataclasses import dataclass, replace from dataclasses import dataclass, replace
from functools import partial
from typing import Any, NewType, TypeAlias, overload from typing import Any, NewType, TypeAlias, overload
from vllm import envs from vllm import envs
...@@ -1390,7 +1391,7 @@ def _estimate_max_model_len_from_groups( ...@@ -1390,7 +1391,7 @@ def _estimate_max_model_len_from_groups(
def _auto_fit_max_model_len( def _auto_fit_max_model_len(
vllm_config: VllmConfig, vllm_config: VllmConfig,
kv_cache_groups: list[KVCacheGroupSpec], projected_groups_per_worker: list[list[KVCacheGroupSpec]],
available_memory: list[int], available_memory: list[int],
) -> None: ) -> None:
""" """
...@@ -1401,14 +1402,13 @@ def _auto_fit_max_model_len( ...@@ -1401,14 +1402,13 @@ def _auto_fit_max_model_len(
Args: Args:
vllm_config: The global VllmConfig (will be modified in-place) vllm_config: The global VllmConfig (will be modified in-place)
kv_cache_groups: The global KV cache groups (from get_kv_cache_groups). projected_groups_per_worker: KV cache groups projected to each worker.
This correctly accounts for padding in hybrid models.
available_memory: Memory available for KV cache in bytes for each available_memory: Memory available for KV cache in bytes for each
worker. worker.
""" """
original_max = vllm_config.model_config.max_model_len original_max = vllm_config.model_config.max_model_len
if not kv_cache_groups: if all(not groups for groups in projected_groups_per_worker):
# All workers have empty specs (attention-free model) # All workers have empty specs (attention-free model)
logger.info_once( logger.info_once(
"Auto-fit max_model_len: attention-free model, " "Auto-fit max_model_len: attention-free model, "
...@@ -1418,11 +1418,16 @@ def _auto_fit_max_model_len( ...@@ -1418,11 +1418,16 @@ def _auto_fit_max_model_len(
) )
return return
# Use minimum available memory across all workers # Find the max_model_len that fits across all workers.
min_available_memory = min(available_memory) auto_fit_max = original_max
auto_fit_max = _estimate_max_model_len_from_groups( limiting_worker_mem = available_memory[0]
vllm_config, kv_cache_groups, min_available_memory for groups, avail_mem in zip(projected_groups_per_worker, available_memory):
) if not groups:
continue
worker_max = _estimate_max_model_len_from_groups(vllm_config, groups, avail_mem)
if worker_max < auto_fit_max:
auto_fit_max = worker_max
limiting_worker_mem = avail_mem
if auto_fit_max <= 0: if auto_fit_max <= 0:
raise ValueError( raise ValueError(
...@@ -1446,11 +1451,47 @@ def _auto_fit_max_model_len( ...@@ -1446,11 +1451,47 @@ def _auto_fit_max_model_len(
"available GPU memory (%s GiB available for KV cache)", "available GPU memory (%s GiB available for KV cache)",
original_max, original_max,
auto_fit_max, auto_fit_max,
format_gib(min_available_memory), format_gib(limiting_worker_mem),
scope="local", scope="local",
) )
def _project_kv_cache_groups_to_worker(
    global_kv_cache_groups: list[KVCacheGroupSpec],
    worker_spec: dict[str, KVCacheSpec],
) -> list[KVCacheGroupSpec]:
    """
    Restrict the global KV cache groups to the layers held by one worker.

    Under pipeline parallelism a worker owns only part of the model's
    layers. Each global group is mapped to a group listing just the layers
    found in `worker_spec`, keeping the group's original layer order.

    A group with no layer on this worker is still emitted, with an empty
    layer list and its original spec, so the result always has one entry
    per global group, in the same order. When a group's spec is a
    `UniformTypeKVCacheSpecs` and the worker owns at least one of its
    layers, the spec is rebuilt with only those layers' entries.

    Args:
        global_kv_cache_groups: The global KV cache groups for the whole model.
        worker_spec: The KV cache spec of each layer on this worker.

    Returns:
        The projected KV cache groups containing only this worker's layers.
    """

    def _project_one(group: KVCacheGroupSpec) -> KVCacheGroupSpec:
        # Keep only the layers this worker owns, in the group's order.
        owned = [name for name in group.layer_names if name in worker_spec]
        spec = group.kv_cache_spec
        # Nothing to narrow: either no layer is owned here, or the spec is
        # not a per-layer mapping.
        if not owned or not isinstance(spec, UniformTypeKVCacheSpecs):
            return KVCacheGroupSpec(owned, spec)
        narrowed = UniformTypeKVCacheSpecs(
            block_size=spec.block_size,
            kv_cache_specs={name: spec.kv_cache_specs[name] for name in owned},
        )
        return KVCacheGroupSpec(owned, narrowed)

    return [_project_one(group) for group in global_kv_cache_groups]
def get_kv_cache_configs( def get_kv_cache_configs(
vllm_config: VllmConfig, vllm_config: VllmConfig,
kv_cache_specs: list[dict[str, KVCacheSpec]], kv_cache_specs: list[dict[str, KVCacheSpec]],
...@@ -1468,7 +1509,8 @@ def get_kv_cache_configs( ...@@ -1468,7 +1509,8 @@ def get_kv_cache_configs(
the whole model. the whole model.
2. Generate the KV cache groups based on the layer ratio of the whole model. 2. Generate the KV cache groups based on the layer ratio of the whole model.
This also handles spec unification for hybrid models. This also handles spec unification for hybrid models.
3. Handle auto-fit max_model_len and memory checks using the unified specs. 3. Handle auto-fit max_model_len and memory checks using per-worker
projected groups to account for PP sharding.
4. Generate the KV cache configs for each worker based on the KV cache 4. Generate the KV cache configs for each worker based on the KV cache
grouping strategy. (This is reasonable because the layer ratio of grouping strategy. (This is reasonable because the layer ratio of
different PP stages are similar.) different PP stages are similar.)
...@@ -1506,44 +1548,38 @@ def get_kv_cache_configs( ...@@ -1506,44 +1548,38 @@ def get_kv_cache_configs(
# If original_max_model_len was -1, automatically # If original_max_model_len was -1, automatically
# determine the maximum model length that fits in available GPU memory. # determine the maximum model length that fits in available GPU memory.
# We use the global groups here to correctly account for padding. # We use per-worker projected groups to account for PP sharding.
projected_groups_per_worker = [
_project_kv_cache_groups_to_worker(global_kv_cache_groups, worker_spec)
for worker_spec in kv_cache_specs
]
if vllm_config.model_config.original_max_model_len == -1: if vllm_config.model_config.original_max_model_len == -1:
_auto_fit_max_model_len(vllm_config, global_kv_cache_groups, available_memory) _auto_fit_max_model_len(
vllm_config, projected_groups_per_worker, available_memory
)
# Check if the available memory is enough (using min across all workers). # Check if the available memory is enough per worker.
# We use the global groups to correctly account for padding. for groups, avail_mem in zip(projected_groups_per_worker, available_memory):
if global_kv_cache_groups: if not groups:
continue
_check_enough_kv_cache_memory( _check_enough_kv_cache_memory(
min(available_memory), avail_mem,
lambda: _max_memory_usage_bytes_from_groups( partial(_max_memory_usage_bytes_from_groups, vllm_config, groups),
vllm_config, global_kv_cache_groups
),
vllm_config.model_config.max_model_len, vllm_config.model_config.max_model_len,
lambda am: _estimate_max_model_len_from_groups( partial(_estimate_max_model_len_from_groups, vllm_config, groups),
vllm_config, global_kv_cache_groups, am
),
) )
kv_cache_configs: list[KVCacheConfig] = [] kv_cache_configs: list[KVCacheConfig] = []
for kv_cache_spec_one_worker, available_memory_one_worker in zip( for projected_groups, kv_cache_spec_one_worker, available_memory_one_worker in zip(
kv_cache_specs, available_memory projected_groups_per_worker, kv_cache_specs, available_memory
): ):
kv_cache_groups_one_worker: list[KVCacheGroupSpec] = [] assert sum(len(group.layer_names) for group in projected_groups) == len(
for group in global_kv_cache_groups: kv_cache_spec_one_worker
group_layer_names_one_worker = [ ), "Some layers are not assigned to any group."
layer_name
for layer_name in group.layer_names
if layer_name in kv_cache_spec_one_worker
]
kv_cache_groups_one_worker.append(
KVCacheGroupSpec(group_layer_names_one_worker, group.kv_cache_spec)
)
assert sum(
len(group.layer_names) for group in kv_cache_groups_one_worker
) == len(kv_cache_spec_one_worker), "Some layers are not assigned to any group."
kv_cache_configs.append( kv_cache_configs.append(
get_kv_cache_config_from_groups( get_kv_cache_config_from_groups(
vllm_config, kv_cache_groups_one_worker, available_memory_one_worker vllm_config, projected_groups, available_memory_one_worker
) )
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment