Unverified Commit c687bf22 authored by chunxiaozheng's avatar chunxiaozheng Committed by GitHub
Browse files

[LMCache][MP] optimize save when mla enabled (#38810)


Signed-off-by: idellzheng <idellzheng@tencent.com>
Co-authored-by: Yihua Cheng <yihua98@uchicago.edu>
parent ccf90ba7
......@@ -7,6 +7,7 @@ from .multi_process_adapter import (
LMCacheMPSchedulerAdapter,
LMCacheMPWorkerAdapter,
LoadStoreOp,
ParallelStrategy,
)
__all__ = [
......@@ -15,4 +16,5 @@ __all__ = [
"LMCacheMPSchedulerAdapter",
"LMCacheMPWorkerAdapter",
"LoadStoreOp",
"ParallelStrategy",
]
......@@ -79,6 +79,39 @@ def get_lmcache_chunk_size(
return chunk_size
@dataclass
class ParallelStrategy:
    """Parallel layout of one worker, handed to the LMCache adapters.

    Callers build this positionally, so the declaration order of the
    fields is part of the interface and must not be changed.
    """

    use_mla: bool
    """True when MLA is in use."""

    kv_world_size: int
    """
    World size used on the KV side. It is not necessarily the same as
    `actual_world_size`: with MLA the effect of TP is 'excluded'. The value
    is produced by `extract_world_size_and_kv_rank` in
    `lmcache_mp_connector.py`.
    """

    kv_worker_id: int
    """
    Worker id of the sub-process on the KV side. It is not necessarily the
    same as `actual_worker_id`: with MLA the effect of TP is 'excluded'. The
    value is produced by `extract_world_size_and_kv_rank` in
    `lmcache_mp_connector.py`.
    """

    actual_world_size: int
    """World size as it really is (no TP adjustment)."""

    actual_worker_id: int
    """Worker id of the sub-process as it really is (no TP adjustment)."""

    tp_size: int
    """Tensor parallel size."""

    pp_size: int
    """Pipeline parallel size."""
@dataclass
class LoadStoreOp:
block_ids: list[int]
......@@ -111,10 +144,8 @@ class LMCacheMPSchedulerAdapter:
server_url: str,
context: zmq.Context,
model_name: str,
world_size: int,
kv_rank: int,
vllm_block_size: int,
tp_size: int = 1,
parallel_strategy: ParallelStrategy,
):
"""
Args:
......@@ -122,11 +153,10 @@ class LMCacheMPSchedulerAdapter:
context: The ZMQ context
model_name: The model name used for LMCache keys
world_size: The world size used for LMCache keys
kv_rank: The kv rank used for LMCache keys
vllm_block_size: The block size used in vLLM
tp_size: Tensor-parallel size for MLA
multi-reader locking (default 1).
parallel_strategy:
The parallel strategy, which includes `use_mla`,
`kv_world_size`, `kv_worker_id` and so on
"""
self.mq_client = MessageQueueClient(server_url, context)
......@@ -134,9 +164,7 @@ class LMCacheMPSchedulerAdapter:
self.lookup_futures: dict[str, MessagingFuture[LookupResult]] = {}
self.model_name = model_name
self.world_size = world_size
self.worker_id = kv_rank
self.tp_size = tp_size
self.parallel_strategy = parallel_strategy
# Read chunk size from lmcache
self.chunk_size = get_lmcache_chunk_size(self.mq_client)
......@@ -145,6 +173,21 @@ class LMCacheMPSchedulerAdapter:
)
self.blocks_in_chunk = self.chunk_size // vllm_block_size
@property
def world_size(self) -> int:
    """World size used for LMCache, read from `parallel_strategy.kv_world_size`."""
    strategy = self.parallel_strategy
    return strategy.kv_world_size
@property
def worker_id(self) -> int:
    """Worker id used for LMCache, read from `parallel_strategy.kv_worker_id`."""
    strategy = self.parallel_strategy
    return strategy.kv_worker_id
@property
def tp_size(self) -> int:
    """Tensor parallel size, read from `parallel_strategy.tp_size`."""
    strategy = self.parallel_strategy
    return strategy.tp_size
@_lmcache_nvtx_annotate
def maybe_submit_lookup_request(
self,
......@@ -308,9 +351,8 @@ class LMCacheMPWorkerAdapter:
server_url: str,
context: zmq.Context,
model_name: str,
world_size: int,
kv_rank: int,
vllm_block_size: int,
parallel_strategy: ParallelStrategy,
):
self.mq_client = MessageQueueClient(server_url, context)
......@@ -336,8 +378,7 @@ class LMCacheMPWorkerAdapter:
self.previously_finished: set[str] = set()
self.model_name = model_name
self.world_size = world_size
self.worker_id = kv_rank
self.parallel_strategy = parallel_strategy
# Read chunk size from lmcache
chunk_size = get_lmcache_chunk_size(self.mq_client)
......@@ -346,6 +387,29 @@ class LMCacheMPWorkerAdapter:
)
self.blocks_in_chunk = chunk_size // vllm_block_size
@property
def world_size(self) -> int:
    """World size used for LMCache, read from `parallel_strategy.kv_world_size`."""
    strategy = self.parallel_strategy
    return strategy.kv_world_size
@property
def worker_id(self) -> int:
    """Worker id used for LMCache, read from `parallel_strategy.kv_worker_id`."""
    strategy = self.parallel_strategy
    return strategy.kv_worker_id
@property
def use_mla(self) -> bool:
    """MLA flag, read from `parallel_strategy.use_mla`."""
    strategy = self.parallel_strategy
    return strategy.use_mla
@property
def is_first_rank_of_pp_group(self) -> bool:
    """Whether this worker's actual id is an exact multiple of the TP size.

    NOTE(review): presumably "pp group" here means the block of `tp_size`
    consecutive ranks belonging to one pipeline stage, so this selects the
    first rank of each such block - confirm against the rank layout.
    """
    strategy = self.parallel_strategy
    remainder = strategy.actual_worker_id % strategy.tp_size
    return remainder == 0
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""
Register the kv caches with LMCache server
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
import inspect
from collections.abc import Iterable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Literal
......@@ -28,6 +27,7 @@ try:
LMCacheMPSchedulerAdapter,
LMCacheMPWorkerAdapter,
LoadStoreOp,
ParallelStrategy,
)
try:
......@@ -45,6 +45,7 @@ except ImportError:
LMCacheMPSchedulerAdapter,
LMCacheMPWorkerAdapter,
LoadStoreOp,
ParallelStrategy,
)
if TYPE_CHECKING:
......@@ -64,12 +65,6 @@ if TYPE_CHECKING:
logger = lmcache_init_logger(__name__)
def _adapter_accepts_tp_size() -> bool:
    """Report whether `LMCacheMPSchedulerAdapter.__init__` takes `tp_size`.

    The check is done by inspecting the constructor's signature rather than
    by trying to call it.
    """
    init_params = inspect.signature(LMCacheMPSchedulerAdapter.__init__).parameters
    return "tp_size" in init_params
# Helper functions
def reformat_block_ids(block_ids: tuple[list[int], ...] | None) -> list[int]:
if block_ids is None:
......@@ -105,8 +100,8 @@ def extract_world_size_and_kv_rank(
# vLLM constructs TP groups first, and then construct other
# parallel groups on top of TP groups.
# for example, TP=4, PP=2,
# TP group: [0, 1, 2, 3], [4, 5, 6, 7]
# PP group: [0, 4], [1, 5], [2, 6], [3, 7]
# PP group: [0, 1, 2, 3], [4, 5, 6, 7]
# TP group: [0, 4], [1, 5], [2, 6], [3, 7]
# So we can "exclude" the effect of TP by rank // tp_size.
return world_size // tp_size, rank // tp_size
......@@ -123,24 +118,24 @@ def create_scheduler_adapter(
vllm_config.parallel_config.rank,
vllm_config,
)
tp_size = vllm_config.parallel_config.tensor_parallel_size
# Pass tp_size only when the adapter accepts it so that
# a newer vllm can still work with an older LMCache.
kwargs: dict[str, Any] = {}
if _adapter_accepts_tp_size():
kwargs["tp_size"] = tp_size
return LMCacheMPSchedulerAdapter(
server_url,
zmq_context,
vllm_config.model_config.model,
parallel_strategy = ParallelStrategy(
mla_enabled(vllm_config.model_config),
world_size,
kv_rank,
vllm_config.cache_config.block_size,
vllm_config.parallel_config.world_size,
vllm_config.parallel_config.rank,
vllm_config.parallel_config.tensor_parallel_size,
vllm_config.parallel_config.pipeline_parallel_size,
)
return LMCacheMPSchedulerAdapter(
server_url=server_url,
context=zmq_context,
model_name=vllm_config.model_config.model,
vllm_block_size=vllm_config.cache_config.block_size,
parallel_strategy=parallel_strategy,
mq_timeout=mq_timeout,
heartbeat_interval=heartbeat_interval,
**kwargs,
)
......@@ -156,13 +151,22 @@ def create_worker_adapter(
vllm_config.parallel_config.rank,
vllm_config,
)
return LMCacheMPWorkerAdapter(
server_url,
zmq_context,
vllm_config.model_config.model,
parallel_strategy = ParallelStrategy(
mla_enabled(vllm_config.model_config),
world_size,
kv_rank,
vllm_config.cache_config.block_size,
vllm_config.parallel_config.world_size,
vllm_config.parallel_config.rank,
vllm_config.parallel_config.tensor_parallel_size,
vllm_config.parallel_config.pipeline_parallel_size,
)
return LMCacheMPWorkerAdapter(
server_url=server_url,
context=zmq_context,
model_name=vllm_config.model_config.model,
vllm_block_size=vllm_config.cache_config.block_size,
parallel_strategy=parallel_strategy,
mq_timeout=mq_timeout,
heartbeat_interval=heartbeat_interval,
)
......@@ -612,6 +616,14 @@ class LMCacheMPConnector(KVConnectorBase_V1):
This prevents overwrites of paged KV buffer before saving done.
"""
# In MLA scenario, only the first rank of the pipeline group
# needs to save the KV cache.
if (
self.worker_adapter.use_mla
and not self.worker_adapter.is_first_rank_of_pp_group
):
return
metadata = self._get_connector_metadata()
assert isinstance(metadata, LMCacheMPConnectorMetadata)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment