Unverified Commit c687bf22 authored by chunxiaozheng's avatar chunxiaozheng Committed by GitHub
Browse files

[LMCache][MP] optimize save when mla enabled (#38810)


Signed-off-by: idellzheng <idellzheng@tencent.com>
Co-authored-by: Yihua Cheng <yihua98@uchicago.edu>
parent ccf90ba7
...@@ -7,6 +7,7 @@ from .multi_process_adapter import ( ...@@ -7,6 +7,7 @@ from .multi_process_adapter import (
LMCacheMPSchedulerAdapter, LMCacheMPSchedulerAdapter,
LMCacheMPWorkerAdapter, LMCacheMPWorkerAdapter,
LoadStoreOp, LoadStoreOp,
ParallelStrategy,
) )
__all__ = [ __all__ = [
...@@ -15,4 +16,5 @@ __all__ = [ ...@@ -15,4 +16,5 @@ __all__ = [
"LMCacheMPSchedulerAdapter", "LMCacheMPSchedulerAdapter",
"LMCacheMPWorkerAdapter", "LMCacheMPWorkerAdapter",
"LoadStoreOp", "LoadStoreOp",
"ParallelStrategy",
] ]
...@@ -79,6 +79,39 @@ def get_lmcache_chunk_size( ...@@ -79,6 +79,39 @@ def get_lmcache_chunk_size(
return chunk_size return chunk_size
@dataclass
class ParallelStrategy:
    """Parallelism description handed to the LMCache multi-process adapters.

    Bundles the MLA flag, the "kv" world size / worker id, the actual
    world size / worker id, and the tensor- and pipeline-parallel sizes.
    """

    use_mla: bool
    """True when MLA is in use."""

    kv_world_size: int
    """
    World size as seen by the KV cache. It can differ from
    `actual_world_size`: with MLA the effect of TP is 'excluded'. The value
    comes from `extract_world_size_and_kv_rank` in `lmcache_mp_connector.py`.
    """

    kv_worker_id: int
    """
    Worker id of the sub-process as seen by the KV cache. It can differ from
    `actual_worker_id`: with MLA the effect of TP is 'excluded'. The value
    comes from `extract_world_size_and_kv_rank` in `lmcache_mp_connector.py`.
    """

    actual_world_size: int
    """World size without any MLA adjustment."""

    actual_worker_id: int
    """Worker id of the sub-process without any MLA adjustment."""

    tp_size: int
    """Tensor parallel size."""

    pp_size: int
    """Pipeline parallel size."""
@dataclass @dataclass
class LoadStoreOp: class LoadStoreOp:
block_ids: list[int] block_ids: list[int]
...@@ -111,10 +144,8 @@ class LMCacheMPSchedulerAdapter: ...@@ -111,10 +144,8 @@ class LMCacheMPSchedulerAdapter:
server_url: str, server_url: str,
context: zmq.Context, context: zmq.Context,
model_name: str, model_name: str,
world_size: int,
kv_rank: int,
vllm_block_size: int, vllm_block_size: int,
tp_size: int = 1, parallel_strategy: ParallelStrategy,
): ):
""" """
Args: Args:
...@@ -122,11 +153,10 @@ class LMCacheMPSchedulerAdapter: ...@@ -122,11 +153,10 @@ class LMCacheMPSchedulerAdapter:
context: The ZMQ context context: The ZMQ context
model_name: The model name used for LMCache keys model_name: The model name used for LMCache keys
world_size: The world size used for LMCache keys
kv_rank: The kv rank used for LMCache keys
vllm_block_size: The block size used in vLLM vllm_block_size: The block size used in vLLM
tp_size: Tensor-parallel size for MLA parallel_strategy:
multi-reader locking (default 1). The parallel strategy, which includes `use_mla`,
`world_size`, `worker_id` and so on
""" """
self.mq_client = MessageQueueClient(server_url, context) self.mq_client = MessageQueueClient(server_url, context)
...@@ -134,9 +164,7 @@ class LMCacheMPSchedulerAdapter: ...@@ -134,9 +164,7 @@ class LMCacheMPSchedulerAdapter:
self.lookup_futures: dict[str, MessagingFuture[LookupResult]] = {} self.lookup_futures: dict[str, MessagingFuture[LookupResult]] = {}
self.model_name = model_name self.model_name = model_name
self.world_size = world_size self.parallel_strategy = parallel_strategy
self.worker_id = kv_rank
self.tp_size = tp_size
# Read chunk size from lmcache # Read chunk size from lmcache
self.chunk_size = get_lmcache_chunk_size(self.mq_client) self.chunk_size = get_lmcache_chunk_size(self.mq_client)
...@@ -145,6 +173,21 @@ class LMCacheMPSchedulerAdapter: ...@@ -145,6 +173,21 @@ class LMCacheMPSchedulerAdapter:
) )
self.blocks_in_chunk = self.chunk_size // vllm_block_size self.blocks_in_chunk = self.chunk_size // vllm_block_size
@property
def world_size(self) -> int:
    """Return `kv_world_size` from the stored parallel strategy."""
    strategy = self.parallel_strategy
    return strategy.kv_world_size
@property
def worker_id(self) -> int:
    """Return `kv_worker_id` from the stored parallel strategy."""
    strategy = self.parallel_strategy
    return strategy.kv_worker_id
@property
def tp_size(self) -> int:
    """Return `tp_size` (tensor parallel size) from the stored parallel strategy."""
    strategy = self.parallel_strategy
    return strategy.tp_size
@_lmcache_nvtx_annotate @_lmcache_nvtx_annotate
def maybe_submit_lookup_request( def maybe_submit_lookup_request(
self, self,
...@@ -308,9 +351,8 @@ class LMCacheMPWorkerAdapter: ...@@ -308,9 +351,8 @@ class LMCacheMPWorkerAdapter:
server_url: str, server_url: str,
context: zmq.Context, context: zmq.Context,
model_name: str, model_name: str,
world_size: int,
kv_rank: int,
vllm_block_size: int, vllm_block_size: int,
parallel_strategy: ParallelStrategy,
): ):
self.mq_client = MessageQueueClient(server_url, context) self.mq_client = MessageQueueClient(server_url, context)
...@@ -336,8 +378,7 @@ class LMCacheMPWorkerAdapter: ...@@ -336,8 +378,7 @@ class LMCacheMPWorkerAdapter:
self.previously_finished: set[str] = set() self.previously_finished: set[str] = set()
self.model_name = model_name self.model_name = model_name
self.world_size = world_size self.parallel_strategy = parallel_strategy
self.worker_id = kv_rank
# Read chunk size from lmcache # Read chunk size from lmcache
chunk_size = get_lmcache_chunk_size(self.mq_client) chunk_size = get_lmcache_chunk_size(self.mq_client)
...@@ -346,6 +387,29 @@ class LMCacheMPWorkerAdapter: ...@@ -346,6 +387,29 @@ class LMCacheMPWorkerAdapter:
) )
self.blocks_in_chunk = chunk_size // vllm_block_size self.blocks_in_chunk = chunk_size // vllm_block_size
@property
def world_size(self) -> int:
    """Return `kv_world_size` from the stored parallel strategy."""
    strategy = self.parallel_strategy
    return strategy.kv_world_size
@property
def worker_id(self) -> int:
    """Return `kv_worker_id` from the stored parallel strategy."""
    strategy = self.parallel_strategy
    return strategy.kv_worker_id
@property
def use_mla(self) -> bool:
    """Return the `use_mla` flag from the stored parallel strategy."""
    strategy = self.parallel_strategy
    return strategy.use_mla
@property
def is_first_rank_of_pp_group(self) -> bool:
    """Whether `actual_worker_id` is an exact multiple of `tp_size`.

    NOTE(review): if ranks are laid out with TP innermost (consecutive
    ranks share a TP group), this selects TP-rank 0 of each group rather
    than a rank of a "pipeline parallel group" in vLLM's usual sense --
    TODO confirm the intended terminology against the rank layout.
    """
    strategy = self.parallel_strategy
    offset_in_group = strategy.actual_worker_id % strategy.tp_size
    return offset_in_group == 0
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
""" """
Register the kv caches with LMCache server Register the kv caches with LMCache server
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum import enum
import inspect
from collections.abc import Iterable from collections.abc import Iterable
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Literal from typing import TYPE_CHECKING, Any, Literal
...@@ -28,6 +27,7 @@ try: ...@@ -28,6 +27,7 @@ try:
LMCacheMPSchedulerAdapter, LMCacheMPSchedulerAdapter,
LMCacheMPWorkerAdapter, LMCacheMPWorkerAdapter,
LoadStoreOp, LoadStoreOp,
ParallelStrategy,
) )
try: try:
...@@ -45,6 +45,7 @@ except ImportError: ...@@ -45,6 +45,7 @@ except ImportError:
LMCacheMPSchedulerAdapter, LMCacheMPSchedulerAdapter,
LMCacheMPWorkerAdapter, LMCacheMPWorkerAdapter,
LoadStoreOp, LoadStoreOp,
ParallelStrategy,
) )
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -64,12 +65,6 @@ if TYPE_CHECKING: ...@@ -64,12 +65,6 @@ if TYPE_CHECKING:
logger = lmcache_init_logger(__name__) logger = lmcache_init_logger(__name__)
def _adapter_accepts_tp_size() -> bool:
    """Report whether `LMCacheMPSchedulerAdapter.__init__` has a `tp_size` parameter."""
    init_params = inspect.signature(LMCacheMPSchedulerAdapter.__init__).parameters
    return "tp_size" in init_params
# Helper functions # Helper functions
def reformat_block_ids(block_ids: tuple[list[int], ...] | None) -> list[int]: def reformat_block_ids(block_ids: tuple[list[int], ...] | None) -> list[int]:
if block_ids is None: if block_ids is None:
...@@ -105,8 +100,8 @@ def extract_world_size_and_kv_rank( ...@@ -105,8 +100,8 @@ def extract_world_size_and_kv_rank(
# vLLM constructs TP groups first, and then construct other # vLLM constructs TP groups first, and then construct other
# parallel groups on top of TP groups. # parallel groups on top of TP groups.
# for example, TP=4, PP=2, # for example, TP=4, PP=2,
# TP group: [0, 1, 2, 3], [4, 5, 6, 7] # TP group: [0, 1, 2, 3], [4, 5, 6, 7]
# PP group: [0, 4], [1, 5], [2, 6], [3, 7] # PP group: [0, 4], [1, 5], [2, 6], [3, 7]
# So we can "exclude" the effect of TP by rank // tp_size. # So we can "exclude" the effect of TP by rank // tp_size.
return world_size // tp_size, rank // tp_size return world_size // tp_size, rank // tp_size
...@@ -123,24 +118,24 @@ def create_scheduler_adapter( ...@@ -123,24 +118,24 @@ def create_scheduler_adapter(
vllm_config.parallel_config.rank, vllm_config.parallel_config.rank,
vllm_config, vllm_config,
) )
tp_size = vllm_config.parallel_config.tensor_parallel_size parallel_strategy = ParallelStrategy(
mla_enabled(vllm_config.model_config),
# Pass tp_size only when the adapter accepts it so that
# a newer vllm can still work with an older LMCache.
kwargs: dict[str, Any] = {}
if _adapter_accepts_tp_size():
kwargs["tp_size"] = tp_size
return LMCacheMPSchedulerAdapter(
server_url,
zmq_context,
vllm_config.model_config.model,
world_size, world_size,
kv_rank, kv_rank,
vllm_config.cache_config.block_size, vllm_config.parallel_config.world_size,
vllm_config.parallel_config.rank,
vllm_config.parallel_config.tensor_parallel_size,
vllm_config.parallel_config.pipeline_parallel_size,
)
return LMCacheMPSchedulerAdapter(
server_url=server_url,
context=zmq_context,
model_name=vllm_config.model_config.model,
vllm_block_size=vllm_config.cache_config.block_size,
parallel_strategy=parallel_strategy,
mq_timeout=mq_timeout, mq_timeout=mq_timeout,
heartbeat_interval=heartbeat_interval, heartbeat_interval=heartbeat_interval,
**kwargs,
) )
...@@ -156,13 +151,22 @@ def create_worker_adapter( ...@@ -156,13 +151,22 @@ def create_worker_adapter(
vllm_config.parallel_config.rank, vllm_config.parallel_config.rank,
vllm_config, vllm_config,
) )
return LMCacheMPWorkerAdapter( parallel_strategy = ParallelStrategy(
server_url, mla_enabled(vllm_config.model_config),
zmq_context,
vllm_config.model_config.model,
world_size, world_size,
kv_rank, kv_rank,
vllm_config.cache_config.block_size, vllm_config.parallel_config.world_size,
vllm_config.parallel_config.rank,
vllm_config.parallel_config.tensor_parallel_size,
vllm_config.parallel_config.pipeline_parallel_size,
)
return LMCacheMPWorkerAdapter(
server_url=server_url,
context=zmq_context,
model_name=vllm_config.model_config.model,
vllm_block_size=vllm_config.cache_config.block_size,
parallel_strategy=parallel_strategy,
mq_timeout=mq_timeout, mq_timeout=mq_timeout,
heartbeat_interval=heartbeat_interval, heartbeat_interval=heartbeat_interval,
) )
...@@ -612,6 +616,14 @@ class LMCacheMPConnector(KVConnectorBase_V1): ...@@ -612,6 +616,14 @@ class LMCacheMPConnector(KVConnectorBase_V1):
This prevents overwrites of paged KV buffer before saving done. This prevents overwrites of paged KV buffer before saving done.
""" """
# In MLA scenario, only the first rank of the pipeline group
# needs to save the KV cache.
if (
self.worker_adapter.use_mla
and not self.worker_adapter.is_first_rank_of_pp_group
):
return
metadata = self._get_connector_metadata() metadata = self._get_connector_metadata()
assert isinstance(metadata, LMCacheMPConnectorMetadata) assert isinstance(metadata, LMCacheMPConnectorMetadata)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment