diff --git a/vllm/config.py b/vllm/config.py index 9ba49757..7e871521 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2629,7 +2629,7 @@ class KVTransferConfig(BaseModel): kv_buffer_size: float = 1e9 # Whether this vLLM instance produces, consumes KV cache, or both. Choices - # are 'kv_producer', 'kv_consumer', and 'both'. + # are 'kv_producer', 'kv_consumer', and 'kv_both'. kv_role: Optional[str] = None # The rank of this vLLM instance in the KV cache transfer. Typical value: @@ -2647,6 +2647,14 @@ class KVTransferConfig(BaseModel): # The KV connector port, used to build distributed connection kv_port: int = 14579 + + # This does not need to be set by the user. It is set by the connector. + kv_producers_parallel_size: Optional[int] = None + kv_producers_tensor_parallel_size: Optional[int] = None + kv_producers_pipeline_parallel_size: Optional[int] = None + kv_consumers_tensor_parallel_size: Optional[int] = None + kv_consumers_pipeline_parallel_size: Optional[int] = None + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, @@ -2685,6 +2693,7 @@ class KVTransferConfig(BaseModel): "is set, supported roles are `kv_producer`, " "`kv_consumer`, and `kv_both`") + @property def is_kv_transfer_instance(self) -> bool: return self.kv_connector is not None and \ @@ -2706,6 +2715,18 @@ class KVTransferConfig(BaseModel): return self.kv_connector is not None and \ self.kv_role in ["kv_consumer", "kv_both"] + @property + def tensor_parallel_multiplier(self) -> int: + return self.kv_consumers_tensor_parallel_size // self.kv_producers_tensor_parallel_size + + @property + def kv_consumers_parallel_size(self) -> int: + return self.kv_parallel_size - self.kv_producers_parallel_size + + @property + def kv_world_size(self) -> int: + return self.kv_producers_parallel_size + self.kv_consumers_parallel_size * self.tensor_parallel_multiplier + class CompilationLevel: # constants for the levels of the compilation process diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py index 359b5b26..d52ee050 100644 --- a/vllm/core/block/cpu_gpu_block_allocator.py +++ b/vllm/core/block/cpu_gpu_block_allocator.py @@ -6,6 +6,7 @@ from vllm.core.block.interfaces import (Block, BlockAllocator, BlockId, DeviceAwareBlockAllocator) from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator from vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator +from vllm.core.event_manager import KVCacheEventManager from vllm.platforms import current_platform from vllm.utils import Device @@ -28,6 +29,7 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): num_gpu_blocks: int, num_cpu_blocks: int, block_size: int, + event_manager: Optional[KVCacheEventManager] = None, ) -> DeviceAwareBlockAllocator: """Creates a CpuGpuBlockAllocator instance with the specified configuration. @@ -64,6 +66,7 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): cpu_block_ids = block_ids[num_gpu_blocks:] if allocator_type == "naive": + assert event_manager is None, "Event API not supported with naive allocator." gpu_allocator: BlockAllocator = NaiveBlockAllocator( create_block=NaiveBlock, # type: ignore num_blocks=num_gpu_blocks, @@ -82,12 +85,14 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): num_blocks=num_gpu_blocks, block_size=block_size, block_ids=gpu_block_ids, + event_manager=event_manager, ) cpu_allocator = PrefixCachingBlockAllocator( num_blocks=num_cpu_blocks, block_size=block_size, block_ids=cpu_block_ids, + event_manager=event_manager, ) else: raise ValueError(f"Unknown allocator type {allocator_type=}") @@ -95,10 +100,12 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): return CpuGpuBlockAllocator( cpu_block_allocator=cpu_allocator, gpu_block_allocator=gpu_allocator, + event_manager=event_manager, ) def __init__(self, cpu_block_allocator: BlockAllocator, - gpu_block_allocator: BlockAllocator): + gpu_block_allocator: BlockAllocator, + event_manager: Optional[KVCacheEventManager] = None,): assert not ( cpu_block_allocator.all_block_ids & gpu_block_allocator.all_block_ids @@ -108,6 +115,7 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): Device.CPU: cpu_block_allocator, Device.GPU: gpu_block_allocator, } + self.event_manager = event_manager self._swap_mapping: Dict[int, int] = {} self._null_block: Optional[Block] = None diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index 1ca9e49d..b1591c0c 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -4,7 +4,7 @@ import sys from bisect import bisect_left from os.path import commonprefix from typing import (Callable, Dict, FrozenSet, Iterable, List, Optional, Set, - Tuple) + Tuple, TYPE_CHECKING) from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker, get_all_blocks_recursively) @@ -23,6 +23,9 @@ PrefixHash = int # then we know this block hasn't been accessed yet. _DEFAULT_LAST_ACCESSED_TIME = -1 +if TYPE_CHECKING: + from vllm.core.event_manager import KVCacheEventManager + logger = init_logger(__name__) @@ -80,6 +83,7 @@ class PrefixCachingBlockAllocator(BlockAllocator): block_size: int, block_ids: Optional[Iterable[int]] = None, eviction_policy: EvictionPolicy = EvictionPolicy.LRU, + event_manager: Optional["KVCacheEventManager"] = None, ): if block_ids is None: block_ids = range(num_blocks) @@ -131,6 +135,9 @@ class PrefixCachingBlockAllocator(BlockAllocator): self.metric_data = CacheMetricData() + self.event_manager = event_manager + + # Implements Block.Factory. def _create_block( self, prev_block: Optional[Block], @@ -337,6 +344,9 @@ class PrefixCachingBlockAllocator(BlockAllocator): assert self._refcounter.get(_block_id) == 0 assert _block_id == block_id + if self.event_manager: + self.event_manager.enqueue_removed_event(content_hash_to_evict) + self._cached_blocks.pop(content_hash_to_evict) self._refcounter.incr(block_id) @@ -513,6 +523,10 @@ class PrefixCachingBlockAllocator(BlockAllocator): # Mark this block as touched so that it can be marked as # computed after the entire batch of sequences are scheduled. self._touched_blocks.add(block.block_id) + + if self.event_manager: + self.event_manager.enqueue_stored_event(block.prev_block, block) + return block.block_id # Reuse the cached content hash diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index c5b3b04f..c72001f7 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -10,7 +10,10 @@ from vllm.core.block.interfaces import Block from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker, LastAccessBlocksTracker) from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec +from vllm.core.event_manager import KVCacheEventManager from vllm.core.interfaces import AllocStatus, BlockSpaceManager +from vllm.envs import (VLLM_KV_CAPI_PATH, VLLM_KV_COMPONENT, VLLM_KV_NAMESPACE, + VLLM_WORKER_ID) from vllm.sequence import Sequence, SequenceGroup, SequenceStatus from vllm.utils import Device @@ -60,6 +63,7 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager): def __init__( self, + model_name: str, block_size: int, num_gpu_blocks: int, num_cpu_blocks: int, @@ -91,11 +95,28 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager): self.watermark_blocks = int(watermark * num_gpu_blocks) + kv_event_manager_params = [ + VLLM_WORKER_ID, VLLM_KV_CAPI_PATH, VLLM_KV_NAMESPACE, + VLLM_KV_COMPONENT + ] + set_kv_event_manager_params = len( + [param for param in kv_event_manager_params if param is not None]) + + if set_kv_event_manager_params == len(kv_event_manager_params): + self.event_manager = KVCacheEventManager( + namespace=VLLM_KV_NAMESPACE, + component=VLLM_KV_COMPONENT, + worker_id=VLLM_WORKER_ID, + lib_path=VLLM_KV_CAPI_PATH) + else: + self.event_manager = None + self.block_allocator = CpuGpuBlockAllocator.create( allocator_type="prefix_caching" if enable_caching else "naive", num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks, block_size=block_size, + event_manager=self.event_manager, ) self.block_tables: Dict[SeqId, BlockTable] = {} diff --git a/vllm/core/event_manager.py b/vllm/core/event_manager.py new file mode 100644 index 00000000..350453cd --- /dev/null +++ b/vllm/core/event_manager.py @@ -0,0 +1,102 @@ +# SPDX-License-Identifier: Apache-2.0 +import ctypes +import logging +import uuid +from ctypes import c_char_p, c_size_t, c_uint32, c_void_p, c_int64 +from typing import Optional + +from vllm.core.block.prefix_caching_block import PrefixCachingBlock, PrefixHash + +logger = logging.getLogger(__name__) + + +class TritonResult: + OK = 0 + ERR = 1 + + +class KVCacheEventManager: + + def __init__(self, namespace: str, component: str, worker_id: int, + lib_path: str): + self.lib = None + + try: + self.lib = ctypes.CDLL(lib_path) + self.lib.triton_llm_init.argtypes = [c_char_p, c_char_p, c_int64] + self.lib.triton_llm_init.restype = c_uint32 + + result = self.lib.triton_llm_init(namespace.encode(), + component.encode(), worker_id) + if result == TritonResult.OK: + logger.info( + "KVCacheEventManager initialized successfully. Ready to publish KV Cache Events" + ) + else: + logger.info("KVCacheEventManager initialization failed!") + + except Exception as e: + print(f"Failed to load {lib_path}") + raise e + + self.lib.triton_kv_event_publish_stored.argtypes = [ + ctypes.c_uint64, # event_id + ctypes.POINTER(ctypes.c_uint32), # token_ids + ctypes.POINTER(ctypes.c_size_t), # num_block_tokens + ctypes.POINTER(ctypes.c_uint64), # block_ids + ctypes.c_size_t, # num_blocks + ctypes.POINTER(ctypes.c_uint64), # parent_hash + ctypes.c_uint64, # lora_id + ] + self.lib.triton_kv_event_publish_stored.restype = ctypes.c_uint32 # triton_llm_result_t + + self.lib.triton_kv_event_publish_removed.argtypes = [ + ctypes.c_uint64, # event_id + ctypes.POINTER(ctypes.c_uint64), # block_ids + ctypes.c_size_t, # num_blocks + ] + self.lib.triton_kv_event_publish_removed.restype = ctypes.c_uint32 # triton_llm_result_t + + self.event_id_counter = 0 + + def enqueue_stored_event(self, parent: Optional[PrefixCachingBlock], + block: PrefixCachingBlock): + token_ids_arr = (ctypes.c_uint32 * + len(block.token_ids))(*block.token_ids) + num_block_tokens = (ctypes.c_size_t * 1)(len(block.token_ids)) + block_hash = (ctypes.c_uint64 * 1)(block.content_hash) + parent_hash = ((ctypes.c_uint64 * 1)(parent.content_hash) + if parent is not None else None) + + # Publish the event + result = self.lib.triton_kv_event_publish_stored( + self.event_id_counter, # uint64_t event_id + token_ids_arr, # const uint32_t *token_ids + num_block_tokens, # const uintptr_t *num_block_tokens + block_hash, # const uint64_t *block_ids + 1, # uintptr_t num_blocks + parent_hash, # const uint64_t *parent_hash + 0, # uint64_t lora_id + ) + + if result == TritonResult.OK: + logger.debug(f"Store - Published KV Event: {block.content_hash}") + else: + logger.debug( + f"Store - Failed to Publish KV Event: {block.content_hash}") + + self.event_id_counter += 1 + + def enqueue_removed_event(self, block_hash: PrefixHash): + result = self.lib.triton_kv_event_publish_removed( + self.event_id_counter, + (ctypes.c_uint64 * 1)(block_hash), + 1, + ) + + if result == TritonResult.OK: + logger.debug(f"Remove - Published KV Event: {block_hash}") + else: + logger.debug(f"Remove - Failed to Publish KV Event: {block_hash}") + + self.event_id_counter += 1 diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index f507847a..6af77646 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -10,7 +10,7 @@ from typing import Callable, Deque, Dict, Iterable, List, Optional from typing import Sequence as GenericSequence from typing import Set, Tuple, Union -from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig +from vllm.config import ModelConfig, CacheConfig, LoRAConfig, SchedulerConfig from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -325,12 +325,14 @@ class Scheduler: def __init__( self, + model_config: ModelConfig, scheduler_config: SchedulerConfig, cache_config: CacheConfig, lora_config: Optional[LoRAConfig], pipeline_parallel_size: int = 1, output_proc_callback: Optional[Callable] = None, ) -> None: + self.model_config = model_config self.scheduler_config = scheduler_config self.cache_config = cache_config # Note for LoRA scheduling: the current policy is extremely @@ -356,6 +358,7 @@ class Scheduler: # Create the block space manager. self.block_manager = BlockSpaceManagerImpl( + model_name=self.model_config.served_model_name, block_size=self.cache_config.block_size, num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks, diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index fe480533..61a357d0 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -27,13 +27,13 @@ class KVConnectorFactory: @classmethod def create_connector(cls, rank: int, local_rank: int, - config: "VllmConfig") -> KVConnectorBase: + config: "VllmConfig", world_group) -> KVConnectorBase: connector_name = config.kv_transfer_config.kv_connector if connector_name not in cls._registry: raise ValueError(f"Unsupported connector type: {connector_name}") connector_cls = cls._registry[connector_name]() - return connector_cls(rank, local_rank, config) + return connector_cls(rank, local_rank, config, world_group) # Register various connectors here. @@ -48,3 +48,8 @@ KVConnectorFactory.register_connector( "MooncakeConnector", "vllm.distributed.kv_transfer.kv_connector.simple_connector", "SimpleConnector") + +KVConnectorFactory.register_connector( + "TritonNcclConnector", + "vllm.distributed.kv_transfer.kv_connector.triton_connector", + "TritonConnector") \ No newline at end of file diff --git a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py index 2033e976..e33919c1 100644 --- a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py @@ -8,13 +8,15 @@ MooncakePipe. But the logic can be extended to support other pipe and lookup buffer. """ +import re from typing import TYPE_CHECKING, List, Optional, Tuple, Union import torch from vllm import _custom_ops as ops -from vllm.config import VllmConfig +from vllm.config import VllmConfig, KVTransferConfig from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase +from vllm.distributed.utils import StatelessProcessGroup from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import ( SimpleBuffer) from vllm.logger import init_logger @@ -33,6 +35,7 @@ class SimpleConnector(KVConnectorBase): rank: int, local_rank: int, config: VllmConfig, + world_group, ): self.config = config.kv_transfer_config @@ -71,20 +74,31 @@ class SimpleConnector(KVConnectorBase): self.producer_signal_pipe: Union[PyNcclPipe, MooncakePipe] self.consumer_signal_pipe: Union[PyNcclPipe, MooncakePipe] + self._broadcast_and_enhance_kv_config(rank, config, world_group) + + self.kv_group_rank = self._get_kv_group_rank(self.config.kv_rank, rank, self.config) + self.tp_size = config.parallel_config.tensor_parallel_size + # 2 pipes for every rank in the world - port_offset_base = 2 * rank + if self.config.is_kv_producer: + port_offset_base = 2 * rank + 1 + else: + port_offset_base = 2 * (rank // self.config.tensor_parallel_multiplier) + 1 + self.local_kv_rank = rank % self.config.tensor_parallel_multiplier # In disaggregated prefill, the prefill vLLM only uses send pipe # and the decode vLLM only uses recv pipe if self.config.is_kv_producer: if self.config.kv_connector == "PyNcclConnector": self.producer_data_pipe = PyNcclPipe( + kv_group_rank=self.kv_group_rank, local_rank=local_rank, config=self.config, port_offset=port_offset_base, ) self.producer_signal_pipe = PyNcclPipe( + kv_group_rank=self.kv_group_rank, local_rank=local_rank, config=self.config, port_offset=port_offset_base + 1, @@ -108,11 +122,13 @@ class SimpleConnector(KVConnectorBase): # its recv pipe to the send pipe of KV producder if self.config.kv_connector == "PyNcclConnector": self.consumer_data_pipe = PyNcclPipe( + kv_group_rank=self.kv_group_rank, local_rank=local_rank, config=self.config, port_offset=port_offset_base, ) self.consumer_signal_pipe = PyNcclPipe( + kv_group_rank=self.kv_group_rank, local_rank=local_rank, config=self.config, port_offset=port_offset_base + 1, @@ -131,21 +147,25 @@ class SimpleConnector(KVConnectorBase): self.config.kv_buffer_size, ) - def select(self, input_tokens: Optional[torch.Tensor], + def select(self, source_rank: int, input_tokens: Optional[torch.Tensor], roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]: + logger.info("Selecting KV caches and hidden states for source rank %d", source_rank) + assert self.consumer_buffer is not None, "Please initialize the "\ "consumer buffer before calling select." - return self.consumer_buffer.drop_select(input_tokens, roi) + return self.consumer_buffer.drop_select(source_rank, self.local_kv_rank, input_tokens, roi) - def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, + def insert(self, kv_group_rank: int, target_rank: int, input_tokens: torch.Tensor, roi: torch.Tensor, key: torch.Tensor, value: torch.Tensor, hidden: torch.Tensor) -> None: + logger.info("Inserting KV caches and hidden states for kv_group_rank %d, target rank %d", kv_group_rank, target_rank) + assert self.producer_buffer is not None, "Please initialize the "\ "producer buffer before calling insert." - self.producer_buffer.insert(input_tokens, roi, key, value, hidden) + self.producer_buffer.insert(kv_group_rank, target_rank, input_tokens, roi, key, value, hidden) def send_kv_caches_and_hidden_states( self, @@ -161,12 +181,20 @@ class SimpleConnector(KVConnectorBase): slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten() start_layer = model_executable.model.start_layer end_layer = model_executable.model.end_layer + request_ids = list(model_input.request_ids_to_seq_ids.keys()) model_config = model_executable.model.config - num_heads = int(model_config.num_key_value_heads / self.tp_size) - hidden_size = model_config.hidden_size - num_attention_heads = model_config.num_attention_heads - head_size = int(hidden_size / num_attention_heads) + is_deepseek = "deepseek" in model_config.architectures[0].lower() + if not is_deepseek: + num_heads = int(model_config.num_key_value_heads / self.tp_size) + hidden_size = model_config.hidden_size + num_attention_heads = model_config.num_attention_heads + head_size = int(hidden_size / num_attention_heads) + else: + num_heads = int(model_config.num_key_value_heads / self.tp_size) + hidden_size = model_config.hidden_size + num_attention_heads = model_config.num_attention_heads + head_size = int(4.5 * hidden_size / num_attention_heads) # query_lens contains new KV caches that are added to vLLM. # so we will send them to decode instance @@ -175,27 +203,40 @@ class SimpleConnector(KVConnectorBase): start_pos = sum(seq_lens[:idx]) end_pos = start_pos + slen current_tokens = input_tokens_tensor[start_pos:end_pos] + current_request_id = request_ids[idx] + _, decode_kv_rank = self.parse_request_id(current_request_id) + starting_kv_group_rank = self._get_kv_group_rank(decode_kv_rank, 0, self.config) + + for target_rank in range(self.config.tensor_parallel_multiplier): - keys, values = [], [] + keys, values = [], [] - for layer_id in range(start_layer, end_layer): - kv_cache = kv_caches[layer_id - start_layer] + for layer_id in range(start_layer, end_layer): + kv_cache = kv_caches[layer_id - start_layer] - key_cache = kv_cache[0].reshape(-1, num_heads, head_size) - value_cache = kv_cache[1].reshape(-1, num_heads, head_size) + current_slot_mapping = slot_mapping_flat[start_pos:end_pos] - current_slot_mapping = slot_mapping_flat[start_pos:end_pos] + num_heads_per_rank = num_heads // self.config.tensor_parallel_multiplier + head_start = target_rank * num_heads_per_rank + head_end = head_start + num_heads_per_rank - keys.append(key_cache[current_slot_mapping].unsqueeze(0)) - values.append(value_cache[current_slot_mapping].unsqueeze(0)) + if not is_deepseek: + key_cache = kv_cache[0].reshape(-1, num_heads, head_size) + value_cache = kv_cache[1].reshape(-1, num_heads, head_size) + keys.append(key_cache[current_slot_mapping, head_start:head_end].unsqueeze(0)) + values.append(value_cache[current_slot_mapping, head_start:head_end].unsqueeze(0)) + else: + key_cache = kv_cache + keys.append(key_cache[current_slot_mapping].unsqueeze(0)) + values.append(torch.empty(0)) - keys = torch.cat(keys, dim=0) - values = torch.cat(values, dim=0) + keys = torch.cat(keys, dim=0) + values = torch.cat(values, dim=0) - self.insert(current_tokens, - torch.ones_like(current_tokens, - dtype=bool), keys, values, - hidden_or_intermediate_states[start_pos:end_pos]) + self.insert(starting_kv_group_rank, target_rank, current_tokens, + torch.ones_like(current_tokens, + dtype=bool), keys, values, + hidden_or_intermediate_states[start_pos:end_pos]) logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank()) @@ -215,6 +256,7 @@ class SimpleConnector(KVConnectorBase): input_tokens_tensor = model_input.input_tokens seq_lens = model_input.attn_metadata.seq_lens slot_mapping = model_input.attn_metadata.slot_mapping.flatten() + request_ids = list(model_input.request_ids_to_seq_ids.keys()) hidden_or_intermediate_states_for_one_req = [] @@ -222,6 +264,9 @@ class SimpleConnector(KVConnectorBase): num_computed_tokens_list = [] start_pos_list = [] + model_config = model_executable.model.config + is_deepseek = "deepseek" in model_config.architectures[0].lower() + # enumerate different requests # FIXME(Kuntai): This impl assumes that all requests are prefill. for idx, slen in enumerate(seq_lens): @@ -229,13 +274,15 @@ class SimpleConnector(KVConnectorBase): start_pos = sum(seq_lens[:idx]) end_pos = start_pos + slen current_tokens = input_tokens_tensor[start_pos:end_pos] + current_request_id = request_ids[idx] + prefill_rank, _ = self.parse_request_id(current_request_id) num_tokens = slen # collecting data for rebuilding the input input_tokens_list.append(current_tokens) start_pos_list.append(start_pos) - ret = self.select(current_tokens, + ret = self.select(prefill_rank, current_tokens, torch.ones_like(current_tokens, dtype=bool)) if ret[0] is None: # didn't find any match. @@ -267,19 +314,25 @@ class SimpleConnector(KVConnectorBase): kv_cache = kv_caches[i - model_executable.model.start_layer] layer = model_executable.model.layers[i] - key_cache, value_cache = kv_cache[0], kv_cache[1] - ops.reshape_and_cache_flash( - keys[i - model_executable.model.start_layer].to( - key_cache.device), - values[i - model_executable.model.start_layer].to( - value_cache.device), - key_cache, - value_cache, - slot_mapping[start_pos:end_pos], - layer.self_attn.attn.kv_cache_dtype, - layer.self_attn.attn._k_scale, - layer.self_attn.attn._v_scale, - ) + if not is_deepseek: + key_cache, value_cache = kv_cache[0], kv_cache[1] + ops.reshape_and_cache_flash( + keys[i - model_executable.model.start_layer].to( + key_cache.device), + values[i - model_executable.model.start_layer].to( + value_cache.device), + key_cache, + value_cache, + slot_mapping[start_pos:end_pos], + layer.self_attn.attn.kv_cache_dtype, + layer.self_attn.attn._k_scale, + layer.self_attn.attn._v_scale, + ) + else: + key_cache = kv_cache + copy_from =keys[i - model_executable.model.start_layer].to( + key_cache.device) + kv_cache[slot_mapping[start_pos:end_pos]] = copy_from hidden_or_intermediate_states_for_one_req.append(hidden) @@ -312,3 +365,77 @@ class SimpleConnector(KVConnectorBase): # MooncakePipe reuses data_pipe for signal_pipe, so we only have to # close the data_pipe. pass + + @staticmethod + def parse_request_id(request_id): + # Regular expression to match the ranks + pattern = r"___prefill_kv_rank_(\d+)___decode_kv_rank_(\d+)" + + # Use re.search to find the pattern in the request_id + match = re.search(pattern, request_id) + + if match: + # Extract the ranks + prefill_rank = int(match.group(1)) + decode_rank = int(match.group(2)) + + return prefill_rank, decode_rank + else: + return None, None + + + + def _get_kv_group_rank(self, kv_rank: int, rank: int, config: KVTransferConfig) -> int: + if kv_rank < config.kv_producers_parallel_size: + return kv_rank + + kv_consumer_rank = kv_rank - config.kv_producers_parallel_size + return config.kv_producers_parallel_size + kv_consumer_rank * config.tensor_parallel_multiplier + rank % config.tensor_parallel_multiplier + + def _broadcast_and_enhance_kv_config(self, rank: int, config: VllmConfig, world_group): + if rank == 0: + if self.config.kv_connector == "PyNcclConnector": + config_group = StatelessProcessGroup.create( + host=self.config.kv_ip, + port=self.config.kv_port, + rank=self.config.kv_rank, + world_size=self.config.kv_parallel_size, + ) + parallel_configs = config_group.all_gather_obj({ + "kv_role": self.config.kv_role, + "tensor_parallel_size": config.parallel_config.tensor_parallel_size, + "pipeline_parallel_size": config.parallel_config.pipeline_parallel_size, + }) + logger.debug("parallel_configs: %s", parallel_configs) + kv_config_enhanced = { + "kv_producers_tensor_parallel_size": None, + "kv_consumers_tensor_parallel_size": None, + "kv_producers_pipeline_parallel_size": None, + "kv_consumers_pipeline_parallel_size": None, + "kv_producers_parallel_size": 0, + } + for parallel_config in parallel_configs: + kv_role = parallel_config["kv_role"] + assert parallel_config["pipeline_parallel_size"] == 1, f"Only pipeline parallel size 1 is supported for kv transfer instances" + + if kv_role == "kv_producer": + kv_config_enhanced["kv_producers_parallel_size"] += 1 + if kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] is None: + kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] = parallel_config["tensor_parallel_size"] + kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] = parallel_config["pipeline_parallel_size"] + else: + assert kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] == parallel_config["tensor_parallel_size"], f"All kv {kv_role}s should have the same tensor parallel size" + assert kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] == parallel_config["pipeline_parallel_size"], f"All kv {kv_role}s should have the same pipeline parallel size" + world_group.broadcast_object(kv_config_enhanced) + + else: + raise NotImplementedError("MooncakeConnector is not supported in Triton Distributed vllm patch") + else: + kv_config_enhanced = world_group.broadcast_object() + logger.info("kv_config_enhanced: %s", kv_config_enhanced) + + self.config.kv_producers_tensor_parallel_size = kv_config_enhanced["kv_producers_tensor_parallel_size"] + self.config.kv_consumers_tensor_parallel_size = kv_config_enhanced["kv_consumers_tensor_parallel_size"] + self.config.kv_producers_pipeline_parallel_size = kv_config_enhanced["kv_producers_pipeline_parallel_size"] + self.config.kv_consumers_pipeline_parallel_size = kv_config_enhanced["kv_consumers_pipeline_parallel_size"] + self.config.kv_producers_parallel_size = kv_config_enhanced["kv_producers_parallel_size"] \ No newline at end of file diff --git a/vllm/distributed/kv_transfer/kv_connector/triton_connector.py b/vllm/distributed/kv_transfer/kv_connector/triton_connector.py new file mode 100644 index 00000000..cb3b3660 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/triton_connector.py @@ -0,0 +1,350 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Simple KV Cache Connector for Distributed Machine Learning Inference + +The SimpleConnector transfers KV caches between prefill vLLM worker (KV cache +producer) and decode vLLM worker (KV cache consumer) using PyNcclPipe or +MooncakePipe. + +But the logic can be extended to support other pipe and lookup buffer. +""" +import re +from typing import TYPE_CHECKING, List, Optional, Tuple, Union + +import torch + +from vllm import _custom_ops as ops +from vllm.config import VllmConfig, KVTransferConfig +from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase +from vllm.distributed.utils import StatelessProcessGroup +from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import ( + SimpleBuffer) +from vllm.logger import init_logger +from vllm.sequence import IntermediateTensors + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + +logger = init_logger(__name__) + + +class TritonConnector(KVConnectorBase): + + def __init__( + self, + rank: int, + local_rank: int, + config: VllmConfig, + world_group, + ): + + self.config = config.kv_transfer_config + self.tp_size = config.parallel_config.tensor_parallel_size + self.rank = rank + + if self.config.kv_connector != "TritonNcclConnector": + raise NotImplementedError("Only TritonNcclConnector is supported by the TritonConnector class") + + from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import ( + PyNcclPipe) + from vllm.distributed.kv_transfer.kv_pipe.triton_nccl_pipe import ( + TritonNcclDataPlane) + + logger.info( + "Initializing TritonNcclConnector under kv_transfer_config %s", + self.config) + + self.lookup_buffer_size = self.config.kv_buffer_size + + self.producer_data_pipe: PyNcclPipe + self.consumer_data_pipe: PyNcclPipe + self.producer_signal_pipe: PyNcclPipe + self.consumer_signal_pipe: PyNcclPipe + + self._broadcast_and_enhance_kv_config(rank, config, world_group) + + self.kv_group_rank = self._get_kv_group_rank(self.config.kv_rank, rank, self.config) + self.tp_size = config.parallel_config.tensor_parallel_size + + # 2 pipes for every rank in the world + if self.config.is_kv_producer: + port_offset_base = rank + 1 + else: + port_offset_base = rank // self.config.tensor_parallel_multiplier + 1 + + + self.local_kv_rank = rank % self.config.tensor_parallel_multiplier + self.global_kv_rank = self._get_global_kv_rank(self.config.kv_rank, rank, self.config) + + self.data_pipe = PyNcclPipe( + kv_group_rank=self.kv_group_rank, + local_rank=local_rank, + config=self.config, + port_offset=port_offset_base, + ) + + self.data_plane = TritonNcclDataPlane( + data_pipe=self.data_pipe, + port=self._get_data_plane_port(self.global_kv_rank), + ) + + def send_kv_caches_and_hidden_states( + self, + model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor], + hidden_or_intermediate_states: Union[torch.Tensor, + IntermediateTensors], + ) -> None: + + input_tokens_tensor = model_input.input_tokens + seq_lens = model_input.attn_metadata.seq_lens + slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten() + start_layer = model_executable.model.start_layer + end_layer = model_executable.model.end_layer + request_ids = list(model_input.request_ids_to_seq_ids.keys()) + + model_config = model_executable.model.config + is_deepseek = "deepseek" in model_config.architectures[0].lower() + if not is_deepseek: + num_heads = int(model_config.num_key_value_heads / self.tp_size) + hidden_size = model_config.hidden_size + num_attention_heads = model_config.num_attention_heads + head_size = int(hidden_size / num_attention_heads) + else: + num_heads = int(model_config.num_key_value_heads / self.tp_size) + hidden_size = model_config.hidden_size + num_attention_heads = model_config.num_attention_heads + head_size = int(4.5 * hidden_size / num_attention_heads) + + # query_lens contains new KV caches that are added to vLLM. + # so we will send them to decode instance + # FIXME(Kuntai): This assume that all requests are prefill. + for idx, slen in enumerate(seq_lens): + start_pos = sum(seq_lens[:idx]) + end_pos = start_pos + slen + current_tokens = input_tokens_tensor[start_pos:end_pos] + current_request_id = request_ids[idx] + decode_hostname, decode_kv_rank = self.parse_request_id(current_request_id) + decode_first_global_rank = self._get_global_kv_rank(decode_kv_rank, self.rank * self.config.tensor_parallel_multiplier, self.config) + + for target_rank in range(self.config.tensor_parallel_multiplier): + + keys, values = [], [] + + for layer_id in range(start_layer, end_layer): + kv_cache = kv_caches[layer_id - start_layer] + + current_slot_mapping = slot_mapping_flat[start_pos:end_pos] + + num_heads_per_rank = num_heads // self.config.tensor_parallel_multiplier + head_start = target_rank * num_heads_per_rank + head_end = head_start + num_heads_per_rank + + if not is_deepseek: + key_cache = kv_cache[0].reshape(-1, num_heads, head_size) + value_cache = kv_cache[1].reshape(-1, num_heads, head_size) + keys.append(key_cache[current_slot_mapping, head_start:head_end].unsqueeze(0)) + values.append(value_cache[current_slot_mapping, head_start:head_end].unsqueeze(0)) + else: + key_cache = kv_cache + keys.append(key_cache[current_slot_mapping].unsqueeze(0)) + values.append(torch.empty(0)) + + keys = torch.cat(keys, dim=0) + values = torch.cat(values, dim=0) + + decode_global_rank = decode_first_global_rank + target_rank + decode_port = self._get_data_plane_port(decode_global_rank) + partial_hidden_or_intermediate_states = hidden_or_intermediate_states[start_pos:end_pos] + self._send(decode_hostname, decode_port, current_request_id, keys, values, + partial_hidden_or_intermediate_states) + + logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank()) + + def recv_kv_caches_and_hidden_states( + self, model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor] + ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool, + "ModelInputForGPUWithSamplingMetadata"]: + + # When bypass_model_exec is set to False, it means that at least for one + # request its corresponding KV cache or hidden state is missing. + # In this case we need to do prefilling to recompute missing KV cache + # and hidden states. + bypass_model_exec = True + + input_tokens_tensor = model_input.input_tokens + seq_lens = model_input.attn_metadata.seq_lens + slot_mapping = model_input.attn_metadata.slot_mapping.flatten() + request_ids = list(model_input.request_ids_to_seq_ids.keys()) + + hidden_or_intermediate_states_for_one_req = [] + + input_tokens_list = [] + start_pos_list = [] + + model_config = model_executable.model.config + is_deepseek = "deepseek" in model_config.architectures[0].lower() + + # enumerate different requests + # FIXME(Kuntai): This impl assumes that all requests are prefill. + for idx, slen in enumerate(seq_lens): + + start_pos = sum(seq_lens[:idx]) + end_pos = start_pos + slen + current_tokens = input_tokens_tensor[start_pos:end_pos] + current_request_id = request_ids[idx] + num_tokens = slen + + # collecting data for rebuilding the input + input_tokens_list.append(current_tokens) + start_pos_list.append(start_pos) + + ret = self._recv(current_request_id) + keys: torch.Tensor = ret[0] + values: torch.Tensor = ret[1] + hidden: torch.Tensor = ret[2] + + # put received KV caches into paged memory + for i in range(model_executable.model.start_layer, + model_executable.model.end_layer): + + kv_cache = kv_caches[i - model_executable.model.start_layer] + layer = model_executable.model.layers[i] + + if not is_deepseek: + key_cache, value_cache = kv_cache[0], kv_cache[1] + ops.reshape_and_cache_flash( + keys[i - model_executable.model.start_layer].to( + key_cache.device), + values[i - model_executable.model.start_layer].to( + value_cache.device), + key_cache, + value_cache, + slot_mapping[start_pos:end_pos], + layer.self_attn.attn.kv_cache_dtype, + layer.self_attn.attn._k_scale, + layer.self_attn.attn._v_scale, + ) + else: + key_cache = kv_cache + copy_from =keys[i - model_executable.model.start_layer].to( + key_cache.device) + kv_cache[slot_mapping[start_pos:end_pos]] = copy_from + + hidden_or_intermediate_states_for_one_req.append(hidden) + + if not bypass_model_exec: + # Some of the KV cache is not retrieved + # Here we will fall back to normal model forwarding + # But optionally you can adjust model_input so that you only do + # prefilling on those tokens that are missing KV caches. + logger.debug( + "[rank%d]: Failed to receive all KVs and hidden " + "states, redo model forwarding.", torch.distributed.get_rank()) + hidden_or_intermediate_states = None + + else: + logger.debug( + "[rank%d]: Successfully received all KVs and hidden " + "states, skip model forwarding.", torch.distributed.get_rank()) + hidden_or_intermediate_states = torch.cat( + hidden_or_intermediate_states_for_one_req, dim=0) + + return hidden_or_intermediate_states, bypass_model_exec, model_input + + def close(self): + self.data_pipe.close() + # self.data_plane.close() + + @staticmethod + def parse_request_id(request_id: str) -> Tuple[str, int]: + # Regular expression to match the string hostname and integer decode_kv_rank + pattern = r"___decode_hostname_(.*)___decode_kv_rank_(\d+)" + + # Use re.search to find the pattern in the request_id + match = re.search(pattern, request_id) + if match: + # Extract the ranks + decode_hostname = match.group(1) + decode_rank = int(match.group(2)) + + return decode_hostname, decode_rank + raise ValueError(f"Request id {request_id} does not contain hostname and decode_kv_rank") + + def _send(self, hostname: str, port: int, request_id: str, keys: torch.Tensor, values: torch.Tensor, hidden: torch.Tensor): + remote_address = f"{hostname}:{port}" + self.data_plane.send_tensor(keys, f"{request_id}_keys", remote_address) + self.data_plane.send_tensor(values, f"{request_id}_values", remote_address) + self.data_plane.send_tensor(hidden, f"{request_id}_hidden", remote_address) + + def _recv(self, request_id: str) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + keys = self.data_plane.recv_tensor(f"{request_id}_keys") + values = self.data_plane.recv_tensor(f"{request_id}_values") + hidden = self.data_plane.recv_tensor(f"{request_id}_hidden") + return keys, values, hidden + + def _get_kv_group_rank(self, kv_rank: int, rank: int, config: KVTransferConfig) -> int: + if kv_rank < config.kv_producers_parallel_size: + return kv_rank + + kv_consumer_rank = kv_rank - config.kv_producers_parallel_size + return config.kv_producers_parallel_size + kv_consumer_rank * config.tensor_parallel_multiplier + rank % config.tensor_parallel_multiplier + + + def _get_global_kv_rank(self, kv_rank: int, rank: int, config: KVTransferConfig) -> int: + if kv_rank <= config.kv_producers_parallel_size: + return kv_rank * config.kv_producers_tensor_parallel_size + rank + + kv_consumer_rank = kv_rank - config.kv_producers_parallel_size + return config.kv_producers_parallel_size * config.kv_producers_tensor_parallel_size + kv_consumer_rank * config.kv_consumers_tensor_parallel_size + rank + + + def _get_data_plane_port(self, global_kv_rank: int) -> int: + return self.config.kv_port + self.config.kv_producers_tensor_parallel_size + 1 + global_kv_rank + + def _broadcast_and_enhance_kv_config(self, rank: int, config: VllmConfig, world_group): + if rank == 0: + config_group = StatelessProcessGroup.create( + host=self.config.kv_ip, + port=self.config.kv_port, + rank=self.config.kv_rank, + world_size=self.config.kv_parallel_size, + ) + parallel_configs = config_group.all_gather_obj({ + "kv_role": self.config.kv_role, + "tensor_parallel_size": config.parallel_config.tensor_parallel_size, + "pipeline_parallel_size": config.parallel_config.pipeline_parallel_size, + }) + logger.debug("parallel_configs: %s", parallel_configs) + kv_config_enhanced = { + "kv_producers_tensor_parallel_size": None, + "kv_consumers_tensor_parallel_size": None, + "kv_producers_pipeline_parallel_size": None, + "kv_consumers_pipeline_parallel_size": None, + "kv_producers_parallel_size": 0, + } + for parallel_config in parallel_configs: + kv_role = parallel_config["kv_role"] + assert parallel_config["pipeline_parallel_size"] == 1, f"Only pipeline parallel size 1 is supported for kv transfer instances" + + if kv_role == "kv_producer": + kv_config_enhanced["kv_producers_parallel_size"] += 1 + if kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] is None: + kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] = parallel_config["tensor_parallel_size"] + kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] = parallel_config["pipeline_parallel_size"] + else: + assert kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] == parallel_config["tensor_parallel_size"], f"All kv {kv_role}s should have the same tensor parallel size" + assert kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] == parallel_config["pipeline_parallel_size"], f"All kv {kv_role}s should have the same pipeline parallel size" + world_group.broadcast_object(kv_config_enhanced) + else: + kv_config_enhanced = world_group.broadcast_object() + logger.info("kv_config_enhanced: %s", kv_config_enhanced) + + self.config.kv_producers_tensor_parallel_size = kv_config_enhanced["kv_producers_tensor_parallel_size"] + self.config.kv_consumers_tensor_parallel_size = kv_config_enhanced["kv_consumers_tensor_parallel_size"] + self.config.kv_producers_pipeline_parallel_size = kv_config_enhanced["kv_producers_pipeline_parallel_size"] + self.config.kv_consumers_pipeline_parallel_size = kv_config_enhanced["kv_consumers_pipeline_parallel_size"] + self.config.kv_producers_parallel_size = kv_config_enhanced["kv_producers_parallel_size"] \ No newline at end of file diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py index 5e1b6235..b4506877 100644 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py @@ -12,7 +12,8 @@ import threading import time from collections import deque -from typing import Deque, List, Optional, Union +from concurrent.futures import ThreadPoolExecutor +from typing import Deque, List, Optional, Union, Dict import torch @@ -46,7 +47,7 @@ class SimpleBuffer(KVLookupBufferBase): self.buffer_lock = threading.Lock() self.signal_pipe = signal_pipe self.data_pipe = data_pipe - self.request_handling_thread: Optional[threading.Thread] = None + self.request_handling_thread: Optional[ThreadPoolExecutor] = None self.normal_signal = torch.tensor([0], device="cpu") self.end_signal = None @@ -57,10 +58,16 @@ class SimpleBuffer(KVLookupBufferBase): # tokens_roi_sender: tokens and roi of the producer (in the buffer) # tokens_roi_recver: tokens and roi of the consumer (query) - tokens_sender = tokens_roi_sender[0] - tokens_recver = tokens_roi_recver[0] - roi_sender = tokens_roi_sender[1] - roi_recver = tokens_roi_recver[1] + target_rank_sender = tokens_roi_sender[0] + target_rank_recver = tokens_roi_recver[0] + + if target_rank_sender.item() != target_rank_recver.item(): + return 0 + + tokens_sender = tokens_roi_sender[1] + tokens_recver = tokens_roi_recver[1] + roi_sender = tokens_roi_sender[2] + roi_recver = tokens_roi_recver[2] if tokens_recver is None: # consumer sends an empty request @@ -80,14 +87,14 @@ class SimpleBuffer(KVLookupBufferBase): return 0 - def _send_tensor_and_dec_size(self, - tensor: Optional[torch.Tensor]) -> None: + def _send_tensor_and_dec_size(self, tensor: Optional[torch.Tensor], + target_rank: int) -> None: assert tensor is not None, "Use self.data_pipe.send(None) instead" self.buffer_size -= tensor.element_size() * tensor.numel() if tensor.dtype == torch.bool: tensor = tensor.float() - self.data_pipe.send_tensor(tensor) + self.data_pipe.send_tensor(tensor, target_rank) def _get_element_size(self, data: Optional[Union[List, torch.Tensor]]): @@ -100,7 +107,7 @@ class SimpleBuffer(KVLookupBufferBase): raise AssertionError(f"Unknown data type {type(data)}") - def _add_to_buffer(self, input_tokens: torch.Tensor, roi: torch.Tensor, + def _add_to_buffer(self, target_rank: int, input_tokens: torch.Tensor, roi: torch.Tensor, key: torch.Tensor, value: torch.Tensor, hidden: torch.Tensor): @@ -115,7 +122,7 @@ class SimpleBuffer(KVLookupBufferBase): if isinstance(hidden, torch.Tensor): hidden = hidden.clone() - buffer_item = [input_tokens, roi, key, value, hidden] + buffer_item = [torch.tensor(target_rank), input_tokens, roi, key, value, hidden] with self.buffer_lock: for data in buffer_item: @@ -125,53 +132,54 @@ class SimpleBuffer(KVLookupBufferBase): def _is_end_signal(self, signal): return signal is None - def drop_select_handler(self): + def drop_select_handler(self, rank: int): try: - while True: - signal = self.signal_pipe.recv_tensor() - if self._is_end_signal(signal): - logger.info("Received end signal!") - break - - input_tokens = self.data_pipe.recv_tensor() - - roi = self.data_pipe.recv_tensor() - assert roi is not None, "Please provide the roi when sending "\ - "drop-select request" - roi = (roi > 0.5) - tokens_roi_recver = [input_tokens, roi] - - matched_length = 0 - - # perform input tokens and roi matching - # FIXME: this matching is O(n), ideally it should be O(1) - # but this buffer size won't (and shouldn't) be too large so - # the fix is not urgent. - with self.buffer_lock: - - for _ in range(len(self.buffer)): - - temp_length = self._matches(self.buffer[0], - tokens_roi_recver) - if temp_length > 0: - matched_length = temp_length - break - # rotate the element we just accessed to the end - self.buffer.rotate(-1) - - if matched_length > 0: - # need to clone the tensor - # in case the tensor is freed before sending finishes - matched_item = self.buffer.popleft() - for tensor in matched_item: - self._send_tensor_and_dec_size(tensor) - - else: - # no match, just send None - for _ in range(5): - self.data_pipe.send_tensor(None) + signal = self.signal_pipe.recv_tensor(rank) + if self._is_end_signal(signal): + logger.info("Received end signal!") + return + target_kv_rank = self.data_pipe.recv_tensor(rank) + # assert target_rank.item() == rank, "Target rank does not match"\ + # "the rank of the drop-select handler" + input_tokens = self.data_pipe.recv_tensor(rank) + roi = self.data_pipe.recv_tensor(rank) + assert roi is not None, "Please provide the roi when sending "\ + "drop-select request" + roi = (roi > 0.5) + tokens_roi_recver = [target_kv_rank, input_tokens, roi] + + matched_length = 0 + + # perform input tokens and roi matching + # FIXME: this matching is O(n), ideally it should be O(1) + # but this buffer size won't (and shouldn't) be too large so + # the fix is not urgent. + with self.buffer_lock: + + for _ in range(len(self.buffer)): + + temp_length = self._matches(self.buffer[0], + tokens_roi_recver) + if temp_length > 0: + matched_length = temp_length + break + # rotate the element we just accessed to the end + self.buffer.rotate(-1) + + if matched_length > 0: + # need to clone the tensor + # in case the tensor is freed before sending finishes + matched_item = self.buffer.popleft() + target_rank = matched_item[0].item() + for tensor in matched_item[1:]: + self._send_tensor_and_dec_size(tensor, rank) + + else: + # no match, just send None + for _ in range(5): + self.data_pipe.send_tensor(None, rank) except RuntimeError as e: if 'Connection closed by peer' not in str(e): @@ -180,10 +188,10 @@ class SimpleBuffer(KVLookupBufferBase): logger.debug("Closing drop_select_handler") def drop_select( - self, input_tokens: Optional[torch.Tensor], + self, rank: int, kv_rank: int, input_tokens: Optional[torch.Tensor], roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]: - assert self.request_handling_thread is None, \ + assert not self.request_handling_thread, \ "drop_select should be called by the KV cache consumer "\ "(e.g. the decode vLLM instance)" @@ -192,26 +200,28 @@ class SimpleBuffer(KVLookupBufferBase): if isinstance(roi, torch.Tensor): roi = roi.clone().float() - self.signal_pipe.send_tensor(self.normal_signal) - self.data_pipe.send_tensor(input_tokens) - self.data_pipe.send_tensor(roi) + self.signal_pipe.send_tensor(self.normal_signal, rank) + + self.data_pipe.send_tensor(torch.tensor(kv_rank), rank) + self.data_pipe.send_tensor(input_tokens, rank) + self.data_pipe.send_tensor(roi, rank) - input_tokens = self.data_pipe.recv_tensor() - roi = self.data_pipe.recv_tensor() + input_tokens = self.data_pipe.recv_tensor(rank) + roi = self.data_pipe.recv_tensor(rank) if roi is not None: # convert from float tensor to bool tensor # as PyNccl does not support sending bool tensor roi = (roi > 0.5) - key = self.data_pipe.recv_tensor() - value = self.data_pipe.recv_tensor() - hidden = self.data_pipe.recv_tensor() + key = self.data_pipe.recv_tensor(rank) + value = self.data_pipe.recv_tensor(rank) + hidden = self.data_pipe.recv_tensor(rank) return [input_tokens, roi, key, value, hidden] def full_handler(self): time.sleep(0.001) - def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, + def insert(self, kv_group_rank: int, target_rank: int, input_tokens: torch.Tensor, roi: torch.Tensor, key: torch.Tensor, value: torch.Tensor, hidden: torch.Tensor) -> None: @@ -222,20 +232,19 @@ class SimpleBuffer(KVLookupBufferBase): while self.buffer_size > self.buffer_size_threshold: self.full_handler() - self._add_to_buffer(input_tokens, roi, key, value, hidden) + self._add_to_buffer(target_rank, input_tokens, roi, key, value, hidden) # when calling the insert, the current process is a sender # need to launch the request handler and start listening to request. + target_rank_global = target_rank + kv_group_rank if self.request_handling_thread is None: - self.request_handling_thread = threading.Thread( - target=self.drop_select_handler) - self.request_handling_thread.start() + self.request_handling_thread = ThreadPoolExecutor(max_workers=1) + self.request_handling_thread.submit(self.drop_select_handler, target_rank_global) def close(self): - if hasattr(self, "request_handling_thread" - ) and self.request_handling_thread is not None: - self.request_handling_thread.join() + if hasattr(self, "request_handling_thread") and self.request_handling_thread: + self.request_handling_thread.shutdown() else: # TODO: have a explicit close signal and have a explicit way to diff --git a/vllm/distributed/kv_transfer/kv_pipe/base.py b/vllm/distributed/kv_transfer/kv_pipe/base.py index 40589fb3..da2829cf 100644 --- a/vllm/distributed/kv_transfer/kv_pipe/base.py +++ b/vllm/distributed/kv_transfer/kv_pipe/base.py @@ -23,7 +23,7 @@ class KVPipeBase(ABC): """ @abstractmethod - def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: + def send_tensor(self, tensor: Optional[torch.Tensor], target_rank: int = 0) -> None: """Send a tensor, or None, via the pipe. Need to support sending None -- important for error handling. @@ -41,7 +41,7 @@ class KVPipeBase(ABC): raise NotImplementedError @abstractmethod - def recv_tensor(self) -> Optional[torch.Tensor]: + def recv_tensor(self, src_rank: int) -> Optional[torch.Tensor]: """Receive a tensor (can be None) from the pipeline. Returns: diff --git a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py index 7aa53d07..f5dd50b7 100644 --- a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py +++ b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py @@ -45,33 +45,33 @@ class PyNcclPipe(KVPipeBase): METADATA_DTYPE = torch.int64 def __init__(self, + kv_group_rank: int, local_rank: int, config: KVTransferConfig, device: Optional[str] = None, port_offset: int = 0): self.config = config self.local_rank = local_rank - self.kv_rank = self.config.kv_rank + self.kv_group_rank = kv_group_rank self.kv_parallel_size = self.config.kv_parallel_size + self.kv_world_size = self.config.kv_world_size if device is None: self.device = self._select_device(self.config.kv_buffer_device) else: self.device = self._select_device(device) # build distributed connection and send/recv implementation + logger.info("Creating process group for kv transfer with rank %d and world size %d, ip: %s, port: %d", self.kv_group_rank, self.kv_world_size, self.config.kv_ip, self.config.kv_port + port_offset) self.group = StatelessProcessGroup.create( host=self.config.kv_ip, port=self.config.kv_port + port_offset, - rank=self.kv_rank, - world_size=self.kv_parallel_size, + rank=self.kv_group_rank, + world_size=self.kv_world_size, ) # add a barrier to make sure the connection is initiated properly self.group.barrier() impl = self._get_device_send_recv_impl(self.group) self.device_send_func, self.device_recv_func = impl - # set target rank - self.target_rank_for_send = (self.kv_rank + 1) % self.kv_parallel_size - self.target_rank_for_recv = (self.kv_rank - 1) % self.kv_parallel_size # transportation-related variables self.transport_thread: Optional[ThreadPoolExecutor] = None @@ -145,16 +145,16 @@ class PyNcclPipe(KVPipeBase): dtype=metadata["dtype"], device=self.device) - def _send_metadata(self, metadata: Metadata): + def _send_metadata(self, metadata: Metadata, target_rank: int): """ Send the metadata dictionary to the target rank. Parameters: - metadata: A dictionary with keys "dtype" and "shape". """ - self.group.send_obj(metadata, self.target_rank_for_send) + self.group.send_obj(metadata, target_rank) - def _recv_metadata(self) -> Metadata: + def _recv_metadata(self, src_rank: int) -> Metadata: """ Receive the metadata dictionary from the target rank. @@ -162,9 +162,9 @@ class PyNcclPipe(KVPipeBase): - metadata: A dictionary with keys "dtype" and "shape" describing the tensor. """ - return self.group.recv_obj(self.target_rank_for_recv) + return self.group.recv_obj(src_rank) - def _send_impl(self, tensor: Optional[torch.Tensor]) -> None: + def _send_impl(self, tensor: Optional[torch.Tensor], target_rank: int) -> None: """ The actual implementation of sending the tensor and its metadata to the target rank. @@ -174,12 +174,12 @@ class PyNcclPipe(KVPipeBase): being sent. """ metadata = self._make_metadata(tensor) - self._send_metadata(metadata) + self._send_metadata(metadata, target_rank) if tensor is not None: self.device_send_func(tensor.to(self.device), - self.target_rank_for_send) + target_rank) - def _recv_impl(self) -> Optional[torch.Tensor]: + def _recv_impl(self, src_rank: int) -> Optional[torch.Tensor]: """ The actual implementation of receiving a tensor and its metadata from the target rank. @@ -187,21 +187,22 @@ class PyNcclPipe(KVPipeBase): Returns: - buffer: The received tensor, or None if no tensor is received. """ - metadata = self._recv_metadata() + metadata = self._recv_metadata(src_rank) if metadata["dtype"] is None: return None buffer = self._prepare_recv_buffer(metadata) - self.device_recv_func(buffer, self.target_rank_for_recv) + self.device_recv_func(buffer, src_rank) return buffer def send_tensor_wrapper(self, tensor: Optional[torch.Tensor], - tensor_size: int) -> None: + tensor_size: int, + target_rank: int) -> None: """ Wrapper for _send_impl to handle exceptions and update buffer size. """ try: - self._send_impl(tensor) + self._send_impl(tensor, target_rank) with self.buffer_size_lock: self.buffer_size -= tensor_size @@ -220,7 +221,7 @@ class PyNcclPipe(KVPipeBase): logger.debug("KV cache transfer pipe is full. Waiting...") time.sleep(0.05) - def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: + def send_tensor(self, tensor: Optional[torch.Tensor], target_rank: int) -> None: """ Sends a tensor and its metadata to the destination rank in a non-blocking way. @@ -228,6 +229,7 @@ class PyNcclPipe(KVPipeBase): Parameters: - tensor: The tensor to send, or None if no tensor is being sent. """ + logger.debug("Rank %d sending tensor of shape %s dtype %s to rank %d", self.kv_group_rank, tensor.shape if tensor is not None else "None", tensor.dtype if tensor is not None else "None", target_rank) if self.transport_thread is None: self.transport_thread = ThreadPoolExecutor(max_workers=1) @@ -241,32 +243,39 @@ class PyNcclPipe(KVPipeBase): with self.buffer_size_lock: self.buffer_size += tensor_size - self.transport_thread.submit(self.send_tensor_wrapper, tensor, - tensor_size) + future = self.transport_thread.submit(self.send_tensor_wrapper, tensor, + tensor_size, + target_rank) + return future - def recv_tensor(self) -> Optional[torch.Tensor]: + def recv_tensor(self, src_rank: int) -> Optional[torch.Tensor]: """ Receives a tensor and its metadata from the source rank. Blocking call. Returns: - tensor: The received tensor, or None if no tensor is received. """ + + logger.debug("Rank %d receiving tensor from rank %d", self.kv_group_rank, src_rank) + if self.transport_thread is None: self.transport_thread = ThreadPoolExecutor(max_workers=1) - future = self.transport_thread.submit(self._recv_impl) + future = self.transport_thread.submit(self._recv_impl, src_rank) - try: - tensor = future.result() - except Exception as e: - logger.error("Encountering exception in KV receiving thread") - logger.error("%s", e) - logger.error("My device: %s", self.device) - import traceback - traceback.print_exc() - raise e + return future + + # try: + # tensor = future.result() + # except Exception as e: + # logger.error("Encountering exception in KV receiving thread") + # logger.error("%s", e) + # logger.error("My device: %s", self.device) + # import traceback + # traceback.print_exc() + # raise e - return tensor + # return tensor def close(self): """ diff --git a/vllm/distributed/kv_transfer/kv_pipe/triton_nccl_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/triton_nccl_pipe.py new file mode 100644 index 00000000..8a356504 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_pipe/triton_nccl_pipe.py @@ -0,0 +1,124 @@ +import logging +import threading +import typing +import zmq +import socket +import time +import torch + +from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe + + +logger = logging.getLogger(__name__) + + +class TritonNcclDataPlane: + def __init__( + self, + data_pipe: PyNcclPipe, + hostname: str = "", + port: int = 0, + ) -> None: + + self.data_pipe = data_pipe + if not hostname: + hostname = socket.gethostname() + if port == 0: + raise ValueError("Port cannot be 0") + self._hostname = hostname + self._port = port + self.store = {} + self.context = zmq.Context() + self.rep_socket = self.context.socket(zmq.REP) + logger.info(f"Rank {self.rank} binding to {self._hostname}:{self._port}") + self.rep_socket.bind(f"tcp://{self._hostname}:{self._port}") + self._listener_thread = threading.Thread(target=self.listen_for_requests, daemon=True) + self._listener_thread.start() + self.req_sockets = {} + logger.info(f"Rank {self.rank} connected to the server") + + @property + def rank(self): + return self.data_pipe.kv_group_rank + + def send_tensor( + self, + tensor: torch.Tensor, + tensor_id: str, + remote_address: typing.Optional[str] = None, + ): + logger.debug(f"Rank {self.rank} sending tensor {tensor_id} to {remote_address}") + return self._send_tensor(tensor, tensor_id, remote_address) + + def recv_tensor( + self, + tensor_id: str, + remote_address: typing.Optional[str] = None, + ) -> torch.Tensor: + ret = self._recv_tensor(tensor_id, remote_address) + return ret + + def _send_tensor( + self, + tensor: torch.Tensor, + tensor_id: str, + remote_address: typing.Optional[str] = None, + ): + logger.debug(f"Rank {self.rank} storing tensor with id {tensor_id} of shape {tensor.shape} and dtype {tensor.dtype}") + if remote_address is None: + self.store[tensor_id] = tensor + else: + # tensor_shape = "_".join(str(dim) for dim in tensor.shape) + # tensor_dtype = str(tensor.dtype) + if remote_address not in self.req_sockets: + self.req_sockets[remote_address] = self.context.socket(zmq.REQ) + self.req_sockets[remote_address].connect(f"tcp://{remote_address}") + + req_socket = self.req_sockets[remote_address] + # req_socket.connect(f"tcp://{remote_address}") + req_socket.send_string(f"PUT {self.rank} {tensor_id}") + dst_rank = req_socket.recv_string() + logger.debug(f"Rank {self.rank} sending tensor {tensor_id} to rank {dst_rank}") + self.data_pipe.send_tensor(tensor, int(dst_rank)) + + def _recv_tensor( + self, + tensor_id: str, + remote_address: typing.Optional[str] = None, + ) -> torch.Tensor: + logger.debug(f"Rank {self.rank} receiving tensor") + if remote_address is not None: + raise NotImplementedError("Getting tensor from remote rank not implemented") + if tensor_id in self.store: + logger.debug(f"Popping tensor {tensor_id} from store") + future = self.store.pop(tensor_id) + tensor = future.result() # TODO ptarasiewicz we should run other request instead of wait + logger.debug(f"Rank {self.rank} received tensor") + return tensor + + logger.debug(f"Rank {self.rank} waiting for tensor {tensor_id}") + time.sleep(0.001) + return self._recv_tensor(tensor_id, remote_address) + # raise NotImplementedError("Tensor not found in store") + + def _receive_tensor( + self, + tensor_id: str, + rank: int, + ): + future = self.data_pipe.recv_tensor(rank) + logger.debug(f"Rank {self.rank} storing tensor {tensor_id} in store") + self.store[tensor_id] = future + + def listen_for_requests(self): + while True: + cmd, rank, tensor_id = self.rep_socket.recv_string().split() + logger.debug(f"Rank {self.rank} received request for tensor {tensor_id}") + self.rep_socket.send_string(f"{self.rank}") + if cmd == "GET": + raise NotImplementedError("Getting tensor from remote rank not implemented") + elif cmd == "PUT": + rank = int(rank) + # shape = [int(dim) for dim in shape.split("_")] + # dtype = getattr(torch, dtype) + self._receive_tensor(tensor_id, rank) diff --git a/vllm/distributed/kv_transfer/kv_transfer_agent.py b/vllm/distributed/kv_transfer/kv_transfer_agent.py index 1e80e0bd..cd90206f 100644 --- a/vllm/distributed/kv_transfer/kv_transfer_agent.py +++ b/vllm/distributed/kv_transfer/kv_transfer_agent.py @@ -35,6 +35,7 @@ class KVTransferAgent: rank: int, local_rank: int, config: "VllmConfig", + world_group, ): self.config = config @@ -47,7 +48,7 @@ class KVTransferAgent: "TransferAgent should only be used when kv_connector is set." self.connector = KVConnectorFactory.create_connector( - rank, local_rank, config) + rank, local_rank, config, world_group) def send_kv_caches_and_hidden_states( self, diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 321902d1..b8937ef8 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -1085,7 +1085,8 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None: _KV_TRANSFER = kv_transfer.KVTransferAgent( rank=get_world_group().rank, local_rank=get_world_group().local_rank, - config=vllm_config) + config=vllm_config, + world_group=get_world_group()) def ensure_model_parallel_initialized( diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index d82d9ad9..542ccfe8 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -348,7 +348,7 @@ class LLMEngine: # GPU and CPU blocks, which are profiled in the distributed executor. self.scheduler = [ Scheduler( - self.scheduler_config, self.cache_config, self.lora_config, + self.model_config, self.scheduler_config, self.cache_config, self.lora_config, self.parallel_config.pipeline_parallel_size, self.async_callbacks[v_id] if self.model_config.use_async_output_proc else None) diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index 3cf1850e..38acca0e 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -21,6 +21,7 @@ IPC_INPUT_EXT = "_input_socket" IPC_OUTPUT_EXT = "_output_socket" IPC_HEALTH_EXT = "_health_socket" IPC_DATA_EXT = "_data_socket" +IPC_METRICS_EXT = "_metrics_socket" class MQEngineDeadError(RuntimeError): @@ -157,3 +158,10 @@ def ENGINE_DEAD_ERROR( return MQEngineDeadError( "Engine loop is not running. Inspect the stacktrace to " f"find the original error: {repr(error)}.") + +@dataclass +class KvMetrics: + request_active_slots: int + request_total_slots: int + kv_active_blocks: int + kv_total_blocks: int diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 85b5f31e..6a7ea3ae 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -25,14 +25,15 @@ from vllm.engine.async_llm_engine import ( build_guided_decoding_logits_processor_async) from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, IPC_HEALTH_EXT, IPC_INPUT_EXT, - IPC_OUTPUT_EXT, RPC_REQUEST_T, + IPC_OUTPUT_EXT, IPC_METRICS_EXT, + RPC_REQUEST_T, VLLM_RPC_SUCCESS_STR, RPCAbortRequest, RPCAdapterLoadedResponse, RPCError, RPCLoadAdapterRequest, RPCProcessRequest, RPCResetPrefixCacheRequest, RPCStartupRequest, RPCStartupResponse, - RPCUProfileRequest) + RPCUProfileRequest, KvMetrics) from vllm.engine.protocol import EngineClient # yapf: enable from vllm.envs import VLLM_RPC_TIMEOUT @@ -115,6 +116,10 @@ class MQLLMEngineClient(EngineClient): self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL) self.heartbeat_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}") + # Metrics. + self.metrics_socket: Socket = self.context.socket(zmq.constants.PULL) + self.metrics_socket.connect(f"{ipc_path}{IPC_METRICS_EXT}") + # IPC path for the data socket. self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" @@ -129,6 +134,12 @@ class MQLLMEngineClient(EngineClient): # Loop to check health of the LLMEngine periodically. # Started after the MQLLMEngine is ready. self.health_loop: Optional[asyncio.Task] = None + + # Loop to check metrics of the LLMEngine periodically. + # Started after the MQLLMEngine is ready. + self.metrics_loop: Optional[asyncio.Task] = None + self.metrics_publisher = None + self._engine_process = psutil.Process(engine_pid) @staticmethod @@ -180,6 +191,46 @@ class MQLLMEngineClient(EngineClient): except Exception as e: self._set_errored(e) + async def run_metrics_loop(self, timeout: int): + """Background loop that continually checks to ensure the engine process + is still alive. + """ + try: + while True: + # Check if the engine process is running: + if not self._engine_process.is_running() or ( + self._engine_process.status() == psutil.STATUS_ZOMBIE): + # NB: is_running() returns True for zombies + self._set_errored( + RuntimeError( + f"Engine process (pid {self._engine_process.pid}) " + "died.")) + break + + if await self.metrics_socket.poll(timeout=timeout): + # Metrics received- check the message + message: Frame = await self.metrics_socket.recv(copy=False) + kv_metrics = pickle.loads(message.buffer) + if self.metrics_publisher is not None: + if isinstance(kv_metrics, KvMetrics): + self.metrics_publisher.publish(kv_metrics.request_active_slots, + kv_metrics.request_total_slots, + kv_metrics.kv_active_blocks, + kv_metrics.kv_total_blocks) + + logger.debug("Metircs successful.") + + except asyncio.CancelledError: + logger.debug("Shutting down MQLLMEngineClient check metrics loop.") + + except psutil.NoSuchProcess: + self._set_errored( + RuntimeError( + f"Engine process (pid {self._engine_process.pid}) died.")) + + except Exception as e: + self._set_errored(e) + async def run_output_handler_loop(self): """Get RequestOutputs from Engine and stream to Request Queues""" @@ -284,6 +335,12 @@ class MQLLMEngineClient(EngineClient): if self.health_loop is None: self.health_loop = asyncio.create_task( self.run_heartbeat_loop(timeout=VLLM_RPC_TIMEOUT)) + + # Start metrics_loop. + if self.metrics_loop is None: + self.metrics_loop = asyncio.create_task( + self.run_metrics_loop(timeout=VLLM_RPC_TIMEOUT)) + def close(self): """Destroy the ZeroMQ Context.""" @@ -293,6 +350,8 @@ class MQLLMEngineClient(EngineClient): # Cancel background tasks. if self.health_loop is not None: self.health_loop.cancel() + if self.metrics_loop is not None: + self.metrics_loop.cancel() if self.output_loop is not None: self.output_loop.cancel() @@ -705,3 +764,6 @@ class MQLLMEngineClient(EngineClient): # Raise on error, otherwise happily return None if isinstance(request_output, BaseException): raise request_output + + def set_metrics_publisher(self, metrics_publisher): + self.metrics_publisher = metrics_publisher diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index a0dd7958..dc6ea25d 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -14,24 +14,56 @@ from vllm.engine.llm_engine import LLMEngine # yapf: disable from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, IPC_HEALTH_EXT, IPC_INPUT_EXT, - IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T, + IPC_OUTPUT_EXT, IPC_METRICS_EXT, + REQUEST_OUTPUTS_T, VLLM_RPC_SUCCESS_STR, RPCAbortRequest, RPCAdapterLoadedResponse, RPCError, RPCLoadAdapterRequest, RPCProcessRequest, RPCResetPrefixCacheRequest, RPCStartupRequest, RPCStartupResponse, - RPCUProfileRequest) + RPCUProfileRequest, KvMetrics) # yapf: enable from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.usage.usage_lib import UsageContext +from vllm.engine.metrics_types import StatLoggerBase, Stats, SupportsMetricsInfo +from dataclasses import dataclass, field logger = init_logger(__name__) POLLING_TIMEOUT_MS = 10000 HEALTHY_RESPONSE = (pickle.dumps(VLLM_RPC_SUCCESS_STR), ) +class KvStatLogger(StatLoggerBase): + def __init__( + self, + max_num_seqs: int, + num_total_gpu_blocks: int, + metrics_socket + ): + # Must query initialized scheduler for max infos + self.request_total_slots = max_num_seqs + self.kv_total_blocks = num_total_gpu_blocks + self.metrics_socket = metrics_socket + + # KV metrics + self._send_kv_metrics(0, 0) + + def log(self, stats: Stats) -> None: + self._send_kv_metrics( + stats.num_running_sys, + int(stats.gpu_cache_usage_sys * self.kv_total_blocks) + ) + + def info(self, type: str, obj: SupportsMetricsInfo) -> None: + pass + + def _send_kv_metrics(self, active_slots, active_kv_blocks): + if not self.metrics_socket.closed: + metrics_bytes = pickle.dumps(KvMetrics(active_slots, self.request_total_slots, active_kv_blocks, self.kv_total_blocks)) + self.metrics_socket.send_multipart((metrics_bytes, ), copy=False) + class MQLLMEngine: """A multiprocessing wrapper for :class:`LLMEngine`. @@ -94,12 +126,24 @@ class MQLLMEngine: self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH) self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}") + # Send metrics back to client. + self.metrics_socket = self.ctx.socket(zmq.constants.PUSH) + self.metrics_socket.bind(f"{ipc_path}{IPC_METRICS_EXT}") + # IPC path for the data socket. self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" # Error state. self._errored_with: Optional[BaseException] = None + # Attach logger for continuous metrics publishing + self.stat_logger = KvStatLogger( + self.engine.scheduler_config.max_num_seqs, + self.engine.cache_config.num_gpu_blocks, + self.metrics_socket + ) + self.engine.add_logger("kv_metrics", self.stat_logger) + @property def dead_error(self) -> BaseException: if self._errored_with is not None: diff --git a/vllm/envs.py b/vllm/envs.py index 745b068b..0ae63d9b 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -87,6 +87,10 @@ if TYPE_CHECKING: VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: bool = False VLLM_RAY_PER_WORKER_GPUS: float = 1.0 VLLM_RAY_BUNDLE_INDICES: str = "" + VLLM_KV_CAPI_PATH: Optional[str] = None + VLLM_KV_NAMESPACE: Optional[str] = None + VLLM_KV_COMPONENT: Optional[str] = None + VLLM_WORKER_ID: Optional[int] = None def get_default_cache_root(): @@ -572,6 +576,21 @@ environment_variables: Dict[str, Callable[[], Any]] = { # models the alignment is already naturally aligned to 256 bytes. "VLLM_CUDA_MEM_ALIGN_KV_CACHE": lambda: bool(int(os.getenv("VLLM_CUDA_MEM_ALIGN_KV_CACHE", "1"))), + + # Path to the C API Library + "VLLM_KV_CAPI_PATH": + lambda: os.environ.get("VLLM_KV_CAPI_PATH", None), + + # Identifiers to publish KV related information + "VLLM_KV_NAMESPACE": + lambda: os.environ.get("VLLM_KV_NAMESPACE", None), + "VLLM_KV_COMPONENT": + lambda: os.environ.get("VLLM_KV_COMPONENT", None), + + # Worker ID used for identifying workers in distributed settings + "VLLM_WORKER_ID": + lambda: int(os.getenv("VLLM_WORKER_ID", "0")) + if "VLLM_WORKER_ID" in os.environ else None, } # end-env-vars-definition diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 773f5abe..3eefd266 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -585,6 +585,8 @@ class DeepseekV2Model(nn.Module): cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config + self.config = config + self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size