Unverified Commit 1e9438e0 authored by wangxiyuan's avatar wangxiyuan Committed by GitHub
Browse files

[MISC] Move bind_kv_cache to worker module (#20900)


Signed-off-by: default avatarwangxiyuan <wangxiyuan1007@gmail.com>
parent 697ef765
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import torch import torch
from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.utils import bind_kv_cache
def test_bind_kv_cache(): def test_bind_kv_cache():
......
...@@ -4,7 +4,6 @@ import argparse ...@@ -4,7 +4,6 @@ import argparse
import multiprocessing import multiprocessing
import time import time
import weakref import weakref
from collections import defaultdict
from collections.abc import Sequence from collections.abc import Sequence
from multiprocessing import connection from multiprocessing import connection
from multiprocessing.process import BaseProcess from multiprocessing.process import BaseProcess
...@@ -14,14 +13,12 @@ from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar, ...@@ -14,14 +13,12 @@ from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar,
import torch import torch
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.models.utils import extract_layer_index
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message) usage_message)
from vllm.utils import (get_open_port, get_open_zmq_ipc_path, get_tcp_uri, from vllm.utils import (get_open_port, get_open_zmq_ipc_path, get_tcp_uri,
kill_process_tree) kill_process_tree)
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.layer import Attention
from vllm.v1.engine.coordinator import DPCoordinator from vllm.v1.engine.coordinator import DPCoordinator
from vllm.v1.engine.utils import (CoreEngineActorManager, from vllm.v1.engine.utils import (CoreEngineActorManager,
CoreEngineProcManager) CoreEngineProcManager)
...@@ -275,51 +272,6 @@ def shutdown(procs: list[BaseProcess]): ...@@ -275,51 +272,6 @@ def shutdown(procs: list[BaseProcess]):
kill_process_tree(pid) kill_process_tree(pid)
def bind_kv_cache(
kv_caches: dict[str, torch.Tensor],
forward_context: dict[str, "Attention"],
runner_kv_caches: list[torch.Tensor],
) -> None:
"""
Bind the allocated KV cache to both ModelRunner and forward context so
that the KV cache can be used in the forward pass.
This function:
1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
kv_caches.
2) Associates each attention layer in the `forward_context` with its
corresponding KV cache in kv_caches.
Args:
kv_caches: The allocated kv_caches with layer names as keys.
forward_context: The global forward context containing all Attention
layers with layer names as keys.
runner_kv_caches: The kv_cache declared by ModelRunner.
"""
# Bind kv_caches to ModelRunner
assert len(runner_kv_caches) == 0
# Convert kv_caches dict to a list of tensors in the order of layer_index.
index2name = defaultdict(list)
for layer_name in kv_caches:
index2name[extract_layer_index(layer_name)].append(layer_name)
for layer_index in sorted(index2name.keys()):
layer_names = index2name[layer_index]
if len(layer_names) > 1:
# One typical case is encoder-decoder model, e.g., bart.
# The cross attention and self attention in the same decoder layer
# has different layer_name but the same layer_index.
raise NotImplementedError
layer_name = layer_names[0]
runner_kv_caches.append(kv_caches[layer_name])
# Bind kv_caches to forward context
for layer_name, kv_cache in kv_caches.items():
# NOTE: Use list because of v0 PP virtual engine.
forward_context[layer_name].kv_cache = [kv_cache]
def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor, def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
length: int) -> torch.Tensor: length: int) -> torch.Tensor:
""" """
......
...@@ -62,13 +62,13 @@ from vllm.v1.spec_decode.eagle import EagleProposer ...@@ -62,13 +62,13 @@ from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.medusa import MedusaProposer from vllm.v1.spec_decode.medusa import MedusaProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.block_table import BlockTable
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from ..sample.logits_processor import LogitsProcessorManager from ..sample.logits_processor import LogitsProcessorManager
from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing, from .utils import (bind_kv_cache, gather_mm_placeholders,
initialize_kv_cache_for_kv_sharing,
sanity_check_mm_encoder_outputs, scatter_mm_placeholders) sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
if TYPE_CHECKING: if TYPE_CHECKING:
......
...@@ -42,11 +42,10 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsLists, ...@@ -42,11 +42,10 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsLists,
LogprobsTensors, ModelRunnerOutput) LogprobsTensors, ModelRunnerOutput)
from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm.v1.worker.tpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.tpu_input_batch import CachedRequestState, InputBatch
from .utils import (initialize_kv_cache_for_kv_sharing, from .utils import (bind_kv_cache, initialize_kv_cache_for_kv_sharing,
sanity_check_mm_encoder_outputs) sanity_check_mm_encoder_outputs)
if TYPE_CHECKING: if TYPE_CHECKING:
......
...@@ -25,8 +25,9 @@ from vllm.v1.core.sched.output import SchedulerOutput ...@@ -25,8 +25,9 @@ from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheConfig, from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheConfig,
KVCacheSpec) KVCacheSpec)
from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.utils import bind_kv_cache, report_usage_stats from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.tpu_model_runner import TPUModelRunner from vllm.v1.worker.tpu_model_runner import TPUModelRunner
from vllm.v1.worker.utils import bind_kv_cache
logger = init_logger(__name__) logger = init_logger(__name__)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional from collections import defaultdict
from typing import TYPE_CHECKING, Optional
import torch import torch
from vllm.model_executor.models.interfaces import MultiModalEmbeddings from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.utils import extract_layer_index
from vllm.v1.kv_cache_interface import KVCacheGroupSpec from vllm.v1.kv_cache_interface import KVCacheGroupSpec
if TYPE_CHECKING:
from vllm.attention.layer import Attention
def sanity_check_mm_encoder_outputs( def sanity_check_mm_encoder_outputs(
mm_embeddings: MultiModalEmbeddings, mm_embeddings: MultiModalEmbeddings,
...@@ -110,3 +115,48 @@ def initialize_kv_cache_for_kv_sharing( ...@@ -110,3 +115,48 @@ def initialize_kv_cache_for_kv_sharing(
kv_caches[layer_name] = kv_caches[target_layer_name] kv_caches[layer_name] = kv_caches[target_layer_name]
group_idx = layer_to_kv_cache_group_idx[target_layer_name] group_idx = layer_to_kv_cache_group_idx[target_layer_name]
kv_cache_groups[group_idx].layer_names.append(layer_name) kv_cache_groups[group_idx].layer_names.append(layer_name)
def bind_kv_cache(
kv_caches: dict[str, torch.Tensor],
forward_context: dict[str, "Attention"],
runner_kv_caches: list[torch.Tensor],
) -> None:
"""
Bind the allocated KV cache to both ModelRunner and forward context so
that the KV cache can be used in the forward pass.
This function:
1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
kv_caches.
2) Associates each attention layer in the `forward_context` with its
corresponding KV cache in kv_caches.
Args:
kv_caches: The allocated kv_caches with layer names as keys.
forward_context: The global forward context containing all Attention
layers with layer names as keys.
runner_kv_caches: The kv_cache declared by ModelRunner.
"""
# Bind kv_caches to ModelRunner
assert len(runner_kv_caches) == 0
# Convert kv_caches dict to a list of tensors in the order of layer_index.
index2name = defaultdict(list)
for layer_name in kv_caches:
index2name[extract_layer_index(layer_name)].append(layer_name)
for layer_index in sorted(index2name.keys()):
layer_names = index2name[layer_index]
if len(layer_names) > 1:
# One typical case is encoder-decoder model, e.g., bart.
# The cross attention and self attention in the same decoder layer
# has different layer_name but the same layer_index.
raise NotImplementedError
layer_name = layer_names[0]
runner_kv_caches.append(kv_caches[layer_name])
# Bind kv_caches to forward context
for layer_name, kv_cache in kv_caches.items():
# NOTE: Use list because of v0 PP virtual engine.
forward_context[layer_name].kv_cache = [kv_cache]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment