"vscode:/vscode.git/clone" did not exist on "0333b7a3b7ffdfbf48db6290468198a238e31a8b"
Commit 861c5098 authored by GuanLuo, committed by GitHub
parent eb022ec9
...@@ -181,36 +181,45 @@ index 1ca9e49d..b1591c0c 100644
     # Reuse the cached content hash
diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py
index c5b3b04f..c72001f7 100644
--- a/vllm/core/block_manager.py
+++ b/vllm/core/block_manager.py
@@ -10,7 +10,10 @@ from vllm.core.block.interfaces import Block
 from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker,
                                                   LastAccessBlocksTracker)
 from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec
+from vllm.core.event_manager import KVCacheEventManager
 from vllm.core.interfaces import AllocStatus, BlockSpaceManager
+from vllm.envs import (VLLM_KV_CAPI_PATH, VLLM_KV_COMPONENT, VLLM_KV_NAMESPACE,
+                       VLLM_WORKER_ID)
 from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
 from vllm.utils import Device
@@ -60,6 +63,7 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
     def __init__(
         self,
+        model_name: str,
         block_size: int,
         num_gpu_blocks: int,
         num_cpu_blocks: int,
@@ -91,11 +95,28 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
         self.watermark_blocks = int(watermark * num_gpu_blocks)
+        kv_event_manager_params = [
+            VLLM_WORKER_ID, VLLM_KV_CAPI_PATH, VLLM_KV_NAMESPACE,
+            VLLM_KV_COMPONENT
+        ]
+        set_kv_event_manager_params = len(
+            [param for param in kv_event_manager_params if param is not None])
+
+        if set_kv_event_manager_params == len(kv_event_manager_params):
+            self.event_manager = KVCacheEventManager(
+                namespace=VLLM_KV_NAMESPACE,
+                component=VLLM_KV_COMPONENT,
+                worker_id=VLLM_WORKER_ID,
+                lib_path=VLLM_KV_CAPI_PATH)
+        else:
+            self.event_manager = None
+
...@@ -225,84 +234,96 @@ index c5b3b04f..8a483aa2 100644
         self.block_tables: Dict[SeqId, BlockTable] = {}
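A note on the gating above: the event manager is only constructed when all four environment-derived values are present; if any is unset, `event_manager` stays `None` and publishing is skipped. A minimal standalone sketch of the same all-or-nothing check (variable values here are illustrative, not part of the patch):

```python
from typing import Optional

# Hypothetical stand-ins for vllm.envs values; in the patch these come from
# VLLM_WORKER_ID, VLLM_KV_CAPI_PATH, VLLM_KV_NAMESPACE and VLLM_KV_COMPONENT.
worker_id: Optional[int] = 0
capi_path: Optional[str] = "/opt/libtriton_llm_capi.so"
namespace: Optional[str] = "triton-init"
component: Optional[str] = "vllm"

params = [worker_id, capi_path, namespace, component]
# Count how many are set; the manager is created only if all of them are.
if sum(p is not None for p in params) == len(params):
    print("all KV event parameters set -> create KVCacheEventManager")
else:
    print("missing parameters -> event publishing disabled")
```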
diff --git a/vllm/core/event_manager.py b/vllm/core/event_manager.py
new file mode 100644
index 00000000..350453cd
--- /dev/null
+++ b/vllm/core/event_manager.py
@@ -0,0 +1,102 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import ctypes
+import logging
+import uuid
+from ctypes import c_char_p, c_size_t, c_uint32, c_void_p, c_int64
+from typing import Optional
+
+from vllm.core.block.prefix_caching_block import PrefixCachingBlock, PrefixHash
+
+logger = logging.getLogger(__name__)
+
+
+class TritonResult:
+    OK = 0
+    ERR = 1
+
+
+class KVCacheEventManager:
+
+    def __init__(self, namespace: str, component: str, worker_id: int,
+                 lib_path: str):
+        self.lib = None
+
+        try:
+            self.lib = ctypes.CDLL(lib_path)
+            self.lib.triton_llm_init.argtypes = [c_char_p, c_char_p, c_int64]
+            self.lib.triton_llm_init.restype = c_uint32
+
+            result = self.lib.triton_llm_init(namespace.encode(),
+                                              component.encode(), worker_id)
+            if result == TritonResult.OK:
+                logger.info(
+                    "KVCacheEventManager initialized successfully. Ready to publish KV Cache Events"
+                )
+            else:
+                logger.info("KVCacheEventManager initialization failed!")
+
+        except Exception as e:
+            print(f"Failed to load {lib_path}")
+            raise e
+
+        self.lib.triton_kv_event_publish_stored.argtypes = [
+            ctypes.c_uint64,  # event_id
+            ctypes.POINTER(ctypes.c_uint32),  # token_ids
+            ctypes.POINTER(ctypes.c_size_t),  # num_block_tokens
+            ctypes.POINTER(ctypes.c_uint64),  # block_ids
+            ctypes.c_size_t,  # num_blocks
+            ctypes.POINTER(ctypes.c_uint64),  # parent_hash
+            ctypes.c_uint64,  # lora_id
+        ]
+        self.lib.triton_kv_event_publish_stored.restype = ctypes.c_uint32  # triton_llm_result_t
+
+        self.lib.triton_kv_event_publish_removed.argtypes = [
+            ctypes.c_uint64,  # event_id
+            ctypes.POINTER(ctypes.c_uint64),  # block_ids
+            ctypes.c_size_t,  # num_blocks
+        ]
+        self.lib.triton_kv_event_publish_removed.restype = ctypes.c_uint32  # triton_llm_result_t
+
+        self.event_id_counter = 0
+
+    def enqueue_stored_event(self, parent: Optional[PrefixCachingBlock],
+                             block: PrefixCachingBlock):
+        token_ids_arr = (ctypes.c_uint32 *
+                         len(block.token_ids))(*block.token_ids)
+        num_block_tokens = (ctypes.c_size_t * 1)(len(block.token_ids))
+        block_hash = (ctypes.c_uint64 * 1)(block.content_hash)
+        parent_hash = ((ctypes.c_uint64 * 1)(parent.content_hash)
+                       if parent is not None else None)
+
+        # Publish the event
+        result = self.lib.triton_kv_event_publish_stored(
+            self.event_id_counter,  # uint64_t event_id
+            token_ids_arr,  # const uint32_t *token_ids
+            num_block_tokens,  # const uintptr_t *num_block_tokens
+            block_hash,  # const uint64_t *block_ids
+            1,  # uintptr_t num_blocks
+            parent_hash,  # const uint64_t *parent_hash
+            0,  # uint64_t lora_id
+        )
+
+        if result == TritonResult.OK:
+            logger.debug(f"Store - Published KV Event: {block.content_hash}")
+        else:
+            logger.debug(
+                f"Store - Failed to Publish KV Event: {block.content_hash}")
+
+        self.event_id_counter += 1
+
...@@ -310,15 +331,15 @@ index 00000000..4aa90a4a
+        result = self.lib.triton_kv_event_publish_removed(
+            self.event_id_counter,
+            (ctypes.c_uint64 * 1)(block_hash),
+            1,
+        )
+
+        if result == TritonResult.OK:
+            logger.debug(f"Remove - Published KV Event: {block_hash}")
+        else:
+            logger.debug(f"Remove - Failed to Publish KV Event: {block_hash}")
+
+        self.event_id_counter += 1
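The manager marshals arguments with ctypes array constructors: `(ctypes.c_uint32 * n)(*tokens)` builds a C-compatible buffer for a `const uint32_t *` parameter, and the nullable `parent_hash` is either a one-element `c_uint64` array or `None` (which ctypes passes as NULL). A self-contained sketch of that marshalling, with no shared library required:

```python
import ctypes

tokens = [3, 5, 8, 13]

# (c_uint32 * n) creates an array *type*; calling it with values instantiates
# a C-compatible buffer that can be passed where `const uint32_t *` is expected.
token_ids_arr = (ctypes.c_uint32 * len(tokens))(*tokens)
num_block_tokens = (ctypes.c_size_t * 1)(len(tokens))

# A nullable pointer argument: a one-element array, or None for NULL.
parent_hash = (ctypes.c_uint64 * 1)(0xDEADBEEF)
no_parent = None

print(list(token_ids_arr))  # [3, 5, 8, 13]
print(num_block_tokens[0])  # 4
print(hex(parent_hash[0]))  # 0xdeadbeef
print(no_parent)            # None -> passed to C as NULL
```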
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index f507847a..6af77646 100644
--- a/vllm/core/scheduler.py
...@@ -356,7 +377,7 @@ index f507847a..6af77646 100644
             num_gpu_blocks=num_gpu_blocks,
             num_cpu_blocks=num_cpu_blocks,
diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py
index fe480533..b768e03c 100644
--- a/vllm/distributed/kv_transfer/kv_connector/factory.py
+++ b/vllm/distributed/kv_transfer/kv_connector/factory.py
@@ -27,13 +27,13 @@ class KVConnectorFactory:
...@@ -375,16 +396,6 @@ index fe480533..61a357d0 100644
     # Register various connectors here.
@@ -48,3 +48,8 @@ KVConnectorFactory.register_connector(
"MooncakeConnector",
"vllm.distributed.kv_transfer.kv_connector.simple_connector",
"SimpleConnector")
+
+KVConnectorFactory.register_connector(
+ "TritonNcclConnector",
+ "vllm.distributed.kv_transfer.kv_connector.triton_connector",
+ "TritonConnector")
\ No newline at end of file
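The `register_connector` call above relies on the factory's lazy-import registration: connectors are recorded as a module path plus class name and only imported on first use. A rough standalone sketch of that pattern (simplified relative to the real `KVConnectorFactory`):

```python
import importlib
from typing import Callable, Dict, Type


class ConnectorFactory:
    # name -> zero-arg callable returning the connector class
    _registry: Dict[str, Callable[[], Type]] = {}

    @classmethod
    def register_connector(cls, name: str, module_path: str,
                           class_name: str) -> None:
        # Store a loader instead of importing eagerly, so registering a
        # connector does not pull in its (possibly heavy) dependencies.
        def loader() -> Type:
            module = importlib.import_module(module_path)
            return getattr(module, class_name)

        cls._registry[name] = loader

    @classmethod
    def create_connector(cls, name: str, *args, **kwargs):
        if name not in cls._registry:
            raise ValueError(f"Unsupported connector type: {name}")
        connector_cls = cls._registry[name]()
        return connector_cls(*args, **kwargs)


# Mirrors the registration shown in the diff above.
ConnectorFactory.register_connector(
    "TritonNcclConnector",
    "vllm.distributed.kv_transfer.kv_connector.triton_connector",
    "TritonConnector")
```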
diff --git a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py
index 2033e976..e33919c1 100644
--- a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py
...@@ -724,363 +735,6 @@ index 2033e976..e33919c1 100644
+        self.config.kv_consumers_pipeline_parallel_size = kv_config_enhanced["kv_consumers_pipeline_parallel_size"]
+        self.config.kv_producers_parallel_size = kv_config_enhanced["kv_producers_parallel_size"]
\ No newline at end of file
diff --git a/vllm/distributed/kv_transfer/kv_connector/triton_connector.py b/vllm/distributed/kv_transfer/kv_connector/triton_connector.py
new file mode 100644
index 00000000..cb3b3660
--- /dev/null
+++ b/vllm/distributed/kv_transfer/kv_connector/triton_connector.py
@@ -0,0 +1,350 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Triton KV Cache Connector for Distributed Machine Learning Inference
+
+The TritonConnector transfers KV caches between a prefill vLLM worker (KV cache
+producer) and a decode vLLM worker (KV cache consumer) using PyNcclPipe and a
+ZMQ-based TritonNcclDataPlane.
+
+The logic can be extended to support other pipes and lookup buffers.
+"""
+import re
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+
+import torch
+
+from vllm import _custom_ops as ops
+from vllm.config import VllmConfig, KVTransferConfig
+from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
+from vllm.distributed.utils import StatelessProcessGroup
+from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import (
+ SimpleBuffer)
+from vllm.logger import init_logger
+from vllm.sequence import IntermediateTensors
+
+if TYPE_CHECKING:
+ from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
+
+logger = init_logger(__name__)
+
+
+class TritonConnector(KVConnectorBase):
+
+ def __init__(
+ self,
+ rank: int,
+ local_rank: int,
+ config: VllmConfig,
+ world_group,
+ ):
+
+ self.config = config.kv_transfer_config
+ self.tp_size = config.parallel_config.tensor_parallel_size
+ self.rank = rank
+
+ if self.config.kv_connector != "TritonNcclConnector":
+ raise NotImplementedError("Only TritonNcclConnector is supported by the TritonConnector class")
+
+ from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import (
+ PyNcclPipe)
+ from vllm.distributed.kv_transfer.kv_pipe.triton_nccl_pipe import (
+ TritonNcclDataPlane)
+
+ logger.info(
+ "Initializing TritonNcclConnector under kv_transfer_config %s",
+ self.config)
+
+ self.lookup_buffer_size = self.config.kv_buffer_size
+
+ self.producer_data_pipe: PyNcclPipe
+ self.consumer_data_pipe: PyNcclPipe
+ self.producer_signal_pipe: PyNcclPipe
+ self.consumer_signal_pipe: PyNcclPipe
+
+ self._broadcast_and_enhance_kv_config(rank, config, world_group)
+
+ self.kv_group_rank = self._get_kv_group_rank(self.config.kv_rank, rank, self.config)
+ self.tp_size = config.parallel_config.tensor_parallel_size
+
+ # 2 pipes for every rank in the world
+ if self.config.is_kv_producer:
+ port_offset_base = rank + 1
+ else:
+ port_offset_base = rank // self.config.tensor_parallel_multiplier + 1
+
+
+ self.local_kv_rank = rank % self.config.tensor_parallel_multiplier
+ self.global_kv_rank = self._get_global_kv_rank(self.config.kv_rank, rank, self.config)
+
+ self.data_pipe = PyNcclPipe(
+ kv_group_rank=self.kv_group_rank,
+ local_rank=local_rank,
+ config=self.config,
+ port_offset=port_offset_base,
+ )
+
+ self.data_plane = TritonNcclDataPlane(
+ data_pipe=self.data_pipe,
+ port=self._get_data_plane_port(self.global_kv_rank),
+ )
+
+ def send_kv_caches_and_hidden_states(
+ self,
+ model_executable: torch.nn.Module,
+ model_input: "ModelInputForGPUWithSamplingMetadata",
+ kv_caches: List[torch.Tensor],
+ hidden_or_intermediate_states: Union[torch.Tensor,
+ IntermediateTensors],
+ ) -> None:
+
+ input_tokens_tensor = model_input.input_tokens
+ seq_lens = model_input.attn_metadata.seq_lens
+ slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
+ start_layer = model_executable.model.start_layer
+ end_layer = model_executable.model.end_layer
+ request_ids = list(model_input.request_ids_to_seq_ids.keys())
+
+ model_config = model_executable.model.config
+ is_deepseek = "deepseek" in model_config.architectures[0].lower()
+ if not is_deepseek:
+ num_heads = int(model_config.num_key_value_heads / self.tp_size)
+ hidden_size = model_config.hidden_size
+ num_attention_heads = model_config.num_attention_heads
+ head_size = int(hidden_size / num_attention_heads)
+ else:
+ num_heads = int(model_config.num_key_value_heads / self.tp_size)
+ hidden_size = model_config.hidden_size
+ num_attention_heads = model_config.num_attention_heads
+ head_size = int(4.5 * hidden_size / num_attention_heads)
+
+ # query_lens contains new KV caches that are added to vLLM.
+ # so we will send them to decode instance
+ # FIXME(Kuntai): This assume that all requests are prefill.
+ for idx, slen in enumerate(seq_lens):
+ start_pos = sum(seq_lens[:idx])
+ end_pos = start_pos + slen
+ current_tokens = input_tokens_tensor[start_pos:end_pos]
+ current_request_id = request_ids[idx]
+ decode_hostname, decode_kv_rank = self.parse_request_id(current_request_id)
+ decode_first_global_rank = self._get_global_kv_rank(decode_kv_rank, self.rank * self.config.tensor_parallel_multiplier, self.config)
+
+ for target_rank in range(self.config.tensor_parallel_multiplier):
+
+ keys, values = [], []
+
+ for layer_id in range(start_layer, end_layer):
+ kv_cache = kv_caches[layer_id - start_layer]
+
+ current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
+
+ num_heads_per_rank = num_heads // self.config.tensor_parallel_multiplier
+ head_start = target_rank * num_heads_per_rank
+ head_end = head_start + num_heads_per_rank
+
+ if not is_deepseek:
+ key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
+ value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
+ keys.append(key_cache[current_slot_mapping, head_start:head_end].unsqueeze(0))
+ values.append(value_cache[current_slot_mapping, head_start:head_end].unsqueeze(0))
+ else:
+ key_cache = kv_cache
+ keys.append(key_cache[current_slot_mapping].unsqueeze(0))
+ values.append(torch.empty(0))
+
+ keys = torch.cat(keys, dim=0)
+ values = torch.cat(values, dim=0)
+
+ decode_global_rank = decode_first_global_rank + target_rank
+ decode_port = self._get_data_plane_port(decode_global_rank)
+ partial_hidden_or_intermediate_states = hidden_or_intermediate_states[start_pos:end_pos]
+ self._send(decode_hostname, decode_port, current_request_id, keys, values,
+ partial_hidden_or_intermediate_states)
+
+ logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank())
+
+ def recv_kv_caches_and_hidden_states(
+ self, model_executable: torch.nn.Module,
+ model_input: "ModelInputForGPUWithSamplingMetadata",
+ kv_caches: List[torch.Tensor]
+ ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
+ "ModelInputForGPUWithSamplingMetadata"]:
+
+ # When bypass_model_exec is set to False, it means that at least for one
+ # request its corresponding KV cache or hidden state is missing.
+ # In this case we need to do prefilling to recompute missing KV cache
+ # and hidden states.
+ bypass_model_exec = True
+
+ input_tokens_tensor = model_input.input_tokens
+ seq_lens = model_input.attn_metadata.seq_lens
+ slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
+ request_ids = list(model_input.request_ids_to_seq_ids.keys())
+
+ hidden_or_intermediate_states_for_one_req = []
+
+ input_tokens_list = []
+ start_pos_list = []
+
+ model_config = model_executable.model.config
+ is_deepseek = "deepseek" in model_config.architectures[0].lower()
+
+ # enumerate different requests
+ # FIXME(Kuntai): This impl assumes that all requests are prefill.
+ for idx, slen in enumerate(seq_lens):
+
+ start_pos = sum(seq_lens[:idx])
+ end_pos = start_pos + slen
+ current_tokens = input_tokens_tensor[start_pos:end_pos]
+ current_request_id = request_ids[idx]
+ num_tokens = slen
+
+ # collecting data for rebuilding the input
+ input_tokens_list.append(current_tokens)
+ start_pos_list.append(start_pos)
+
+ ret = self._recv(current_request_id)
+ keys: torch.Tensor = ret[0]
+ values: torch.Tensor = ret[1]
+ hidden: torch.Tensor = ret[2]
+
+ # put received KV caches into paged memory
+ for i in range(model_executable.model.start_layer,
+ model_executable.model.end_layer):
+
+ kv_cache = kv_caches[i - model_executable.model.start_layer]
+ layer = model_executable.model.layers[i]
+
+ if not is_deepseek:
+ key_cache, value_cache = kv_cache[0], kv_cache[1]
+ ops.reshape_and_cache_flash(
+ keys[i - model_executable.model.start_layer].to(
+ key_cache.device),
+ values[i - model_executable.model.start_layer].to(
+ value_cache.device),
+ key_cache,
+ value_cache,
+ slot_mapping[start_pos:end_pos],
+ layer.self_attn.attn.kv_cache_dtype,
+ layer.self_attn.attn._k_scale,
+ layer.self_attn.attn._v_scale,
+ )
+ else:
+ key_cache = kv_cache
+                    copy_from = keys[i - model_executable.model.start_layer].to(
+ key_cache.device)
+ kv_cache[slot_mapping[start_pos:end_pos]] = copy_from
+
+ hidden_or_intermediate_states_for_one_req.append(hidden)
+
+ if not bypass_model_exec:
+ # Some of the KV cache is not retrieved
+ # Here we will fall back to normal model forwarding
+ # But optionally you can adjust model_input so that you only do
+ # prefilling on those tokens that are missing KV caches.
+ logger.debug(
+ "[rank%d]: Failed to receive all KVs and hidden "
+ "states, redo model forwarding.", torch.distributed.get_rank())
+ hidden_or_intermediate_states = None
+
+ else:
+ logger.debug(
+ "[rank%d]: Successfully received all KVs and hidden "
+ "states, skip model forwarding.", torch.distributed.get_rank())
+ hidden_or_intermediate_states = torch.cat(
+ hidden_or_intermediate_states_for_one_req, dim=0)
+
+ return hidden_or_intermediate_states, bypass_model_exec, model_input
+
+ def close(self):
+ self.data_pipe.close()
+ # self.data_plane.close()
+
+ @staticmethod
+ def parse_request_id(request_id: str) -> Tuple[str, int]:
+ # Regular expression to match the string hostname and integer decode_kv_rank
+ pattern = r"___decode_hostname_(.*)___decode_kv_rank_(\d+)"
+
+ # Use re.search to find the pattern in the request_id
+ match = re.search(pattern, request_id)
+ if match:
+ # Extract the ranks
+ decode_hostname = match.group(1)
+ decode_rank = int(match.group(2))
+
+ return decode_hostname, decode_rank
+ raise ValueError(f"Request id {request_id} does not contain hostname and decode_kv_rank")
+
+ def _send(self, hostname: str, port: int, request_id: str, keys: torch.Tensor, values: torch.Tensor, hidden: torch.Tensor):
+ remote_address = f"{hostname}:{port}"
+ self.data_plane.send_tensor(keys, f"{request_id}_keys", remote_address)
+ self.data_plane.send_tensor(values, f"{request_id}_values", remote_address)
+ self.data_plane.send_tensor(hidden, f"{request_id}_hidden", remote_address)
+
+ def _recv(self, request_id: str) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ keys = self.data_plane.recv_tensor(f"{request_id}_keys")
+ values = self.data_plane.recv_tensor(f"{request_id}_values")
+ hidden = self.data_plane.recv_tensor(f"{request_id}_hidden")
+ return keys, values, hidden
+
+ def _get_kv_group_rank(self, kv_rank: int, rank: int, config: KVTransferConfig) -> int:
+ if kv_rank < config.kv_producers_parallel_size:
+ return kv_rank
+
+ kv_consumer_rank = kv_rank - config.kv_producers_parallel_size
+ return config.kv_producers_parallel_size + kv_consumer_rank * config.tensor_parallel_multiplier + rank % config.tensor_parallel_multiplier
+
+
+ def _get_global_kv_rank(self, kv_rank: int, rank: int, config: KVTransferConfig) -> int:
+ if kv_rank <= config.kv_producers_parallel_size:
+ return kv_rank * config.kv_producers_tensor_parallel_size + rank
+
+ kv_consumer_rank = kv_rank - config.kv_producers_parallel_size
+ return config.kv_producers_parallel_size * config.kv_producers_tensor_parallel_size + kv_consumer_rank * config.kv_consumers_tensor_parallel_size + rank
+
+
+ def _get_data_plane_port(self, global_kv_rank: int) -> int:
+ return self.config.kv_port + self.config.kv_producers_tensor_parallel_size + 1 + global_kv_rank
+
+ def _broadcast_and_enhance_kv_config(self, rank: int, config: VllmConfig, world_group):
+ if rank == 0:
+ config_group = StatelessProcessGroup.create(
+ host=self.config.kv_ip,
+ port=self.config.kv_port,
+ rank=self.config.kv_rank,
+ world_size=self.config.kv_parallel_size,
+ )
+ parallel_configs = config_group.all_gather_obj({
+ "kv_role": self.config.kv_role,
+ "tensor_parallel_size": config.parallel_config.tensor_parallel_size,
+ "pipeline_parallel_size": config.parallel_config.pipeline_parallel_size,
+ })
+ logger.debug("parallel_configs: %s", parallel_configs)
+ kv_config_enhanced = {
+ "kv_producers_tensor_parallel_size": None,
+ "kv_consumers_tensor_parallel_size": None,
+ "kv_producers_pipeline_parallel_size": None,
+ "kv_consumers_pipeline_parallel_size": None,
+ "kv_producers_parallel_size": 0,
+ }
+ for parallel_config in parallel_configs:
+ kv_role = parallel_config["kv_role"]
+ assert parallel_config["pipeline_parallel_size"] == 1, f"Only pipeline parallel size 1 is supported for kv transfer instances"
+
+ if kv_role == "kv_producer":
+ kv_config_enhanced["kv_producers_parallel_size"] += 1
+ if kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] is None:
+ kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] = parallel_config["tensor_parallel_size"]
+ kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] = parallel_config["pipeline_parallel_size"]
+ else:
+ assert kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] == parallel_config["tensor_parallel_size"], f"All kv {kv_role}s should have the same tensor parallel size"
+ assert kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] == parallel_config["pipeline_parallel_size"], f"All kv {kv_role}s should have the same pipeline parallel size"
+ world_group.broadcast_object(kv_config_enhanced)
+ else:
+ kv_config_enhanced = world_group.broadcast_object()
+ logger.info("kv_config_enhanced: %s", kv_config_enhanced)
+
+ self.config.kv_producers_tensor_parallel_size = kv_config_enhanced["kv_producers_tensor_parallel_size"]
+ self.config.kv_consumers_tensor_parallel_size = kv_config_enhanced["kv_consumers_tensor_parallel_size"]
+ self.config.kv_producers_pipeline_parallel_size = kv_config_enhanced["kv_producers_pipeline_parallel_size"]
+ self.config.kv_consumers_pipeline_parallel_size = kv_config_enhanced["kv_consumers_pipeline_parallel_size"]
+ self.config.kv_producers_parallel_size = kv_config_enhanced["kv_producers_parallel_size"]
\ No newline at end of file
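`parse_request_id` above assumes the decode hostname and KV rank are appended to each request id as `___decode_hostname_<host>___decode_kv_rank_<rank>`. A quick round-trip check of that convention (the id prefix and host are made up for illustration):

```python
import re


def parse_request_id(request_id: str):
    # Same pattern as TritonConnector.parse_request_id above.
    pattern = r"___decode_hostname_(.*)___decode_kv_rank_(\d+)"
    match = re.search(pattern, request_id)
    if match:
        return match.group(1), int(match.group(2))
    raise ValueError(
        f"Request id {request_id} does not contain hostname and decode_kv_rank")


# Build an id the way a prefill/decode proxy presumably would, then parse it.
rid = "cmpl-42" + "___decode_hostname_10.0.0.7___decode_kv_rank_3"
assert parse_request_id(rid) == ("10.0.0.7", 3)
print(parse_request_id(rid))
```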
diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py
index 5e1b6235..b4506877 100644
--- a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py
...@@ -1361,7 +1015,7 @@ index 40589fb3..da2829cf 100644
     Returns:
diff --git a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py
index 7aa53d07..db10f8a0 100644
--- a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py
+++ b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py
@@ -45,33 +45,33 @@ class PyNcclPipe(KVPipeBase):
...@@ -1496,16 +1150,13 @@ index 7aa53d07..f5dd50b7 100644
         if self.transport_thread is None:
             self.transport_thread = ThreadPoolExecutor(max_workers=1)
@@ -242,19 +244,23 @@ class PyNcclPipe(KVPipeBase):
         with self.buffer_size_lock:
             self.buffer_size += tensor_size
-        self.transport_thread.submit(self.send_tensor_wrapper, tensor,
-                                     tensor_size)
+        future = self.transport_thread.submit(self.send_tensor_wrapper, tensor,
+                                              tensor_size,
+                                              target_rank)
+        return future

-    def recv_tensor(self) -> Optional[torch.Tensor]:
+    def recv_tensor(self, src_rank: int) -> Optional[torch.Tensor]:
...@@ -1524,162 +1175,8 @@ index 7aa53d07..f5dd50b7 100644
-        future = self.transport_thread.submit(self._recv_impl)
+        future = self.transport_thread.submit(self._recv_impl, src_rank)

         try:
             tensor = future.result()
def close(self):
"""
diff --git a/vllm/distributed/kv_transfer/kv_pipe/triton_nccl_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/triton_nccl_pipe.py
new file mode 100644
index 00000000..8a356504
--- /dev/null
+++ b/vllm/distributed/kv_transfer/kv_pipe/triton_nccl_pipe.py
@@ -0,0 +1,124 @@
+import logging
+import threading
+import typing
+import zmq
+import socket
+import time
+import torch
+
+from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe
+
+
+logger = logging.getLogger(__name__)
+
+
+class TritonNcclDataPlane:
+ def __init__(
+ self,
+ data_pipe: PyNcclPipe,
+ hostname: str = "",
+ port: int = 0,
+ ) -> None:
+
+ self.data_pipe = data_pipe
+ if not hostname:
+ hostname = socket.gethostname()
+ if port == 0:
+ raise ValueError("Port cannot be 0")
+ self._hostname = hostname
+ self._port = port
+ self.store = {}
+ self.context = zmq.Context()
+ self.rep_socket = self.context.socket(zmq.REP)
+ logger.info(f"Rank {self.rank} binding to {self._hostname}:{self._port}")
+ self.rep_socket.bind(f"tcp://{self._hostname}:{self._port}")
+ self._listener_thread = threading.Thread(target=self.listen_for_requests, daemon=True)
+ self._listener_thread.start()
+ self.req_sockets = {}
+ logger.info(f"Rank {self.rank} connected to the server")
+
+ @property
+ def rank(self):
+ return self.data_pipe.kv_group_rank
+
+ def send_tensor(
+ self,
+ tensor: torch.Tensor,
+ tensor_id: str,
+ remote_address: typing.Optional[str] = None,
+ ):
+ logger.debug(f"Rank {self.rank} sending tensor {tensor_id} to {remote_address}")
+ return self._send_tensor(tensor, tensor_id, remote_address)
+
+ def recv_tensor(
+ self,
+ tensor_id: str,
+ remote_address: typing.Optional[str] = None,
+ ) -> torch.Tensor:
+ ret = self._recv_tensor(tensor_id, remote_address)
+ return ret
+
+ def _send_tensor(
+ self,
+ tensor: torch.Tensor,
+ tensor_id: str,
+ remote_address: typing.Optional[str] = None,
+ ):
+ logger.debug(f"Rank {self.rank} storing tensor with id {tensor_id} of shape {tensor.shape} and dtype {tensor.dtype}")
+ if remote_address is None:
+ self.store[tensor_id] = tensor
+ else:
+ # tensor_shape = "_".join(str(dim) for dim in tensor.shape)
+ # tensor_dtype = str(tensor.dtype)
+ if remote_address not in self.req_sockets:
+ self.req_sockets[remote_address] = self.context.socket(zmq.REQ)
+ self.req_sockets[remote_address].connect(f"tcp://{remote_address}")
+
+ req_socket = self.req_sockets[remote_address]
+ # req_socket.connect(f"tcp://{remote_address}")
+ req_socket.send_string(f"PUT {self.rank} {tensor_id}")
+ dst_rank = req_socket.recv_string()
+ logger.debug(f"Rank {self.rank} sending tensor {tensor_id} to rank {dst_rank}")
+ self.data_pipe.send_tensor(tensor, int(dst_rank))
+
+ def _recv_tensor(
+ self,
+ tensor_id: str,
+ remote_address: typing.Optional[str] = None,
+ ) -> torch.Tensor:
+ logger.debug(f"Rank {self.rank} receiving tensor")
+ if remote_address is not None:
+ raise NotImplementedError("Getting tensor from remote rank not implemented")
+ if tensor_id in self.store:
+ logger.debug(f"Popping tensor {tensor_id} from store")
+ future = self.store.pop(tensor_id)
+            tensor = future.result()  # TODO (ptarasiewicz): serve other requests instead of waiting
+ logger.debug(f"Rank {self.rank} received tensor")
+ return tensor
+
+ logger.debug(f"Rank {self.rank} waiting for tensor {tensor_id}")
+ time.sleep(0.001)
+ return self._recv_tensor(tensor_id, remote_address)
+ # raise NotImplementedError("Tensor not found in store")
+
+ def _receive_tensor(
+ self,
+ tensor_id: str,
+ rank: int,
+ ):
+ future = self.data_pipe.recv_tensor(rank)
+ logger.debug(f"Rank {self.rank} storing tensor {tensor_id} in store")
+ self.store[tensor_id] = future
+
+ def listen_for_requests(self):
+ while True:
+ cmd, rank, tensor_id = self.rep_socket.recv_string().split()
+ logger.debug(f"Rank {self.rank} received request for tensor {tensor_id}")
+ self.rep_socket.send_string(f"{self.rank}")
+ if cmd == "GET":
+ raise NotImplementedError("Getting tensor from remote rank not implemented")
+ elif cmd == "PUT":
+ rank = int(rank)
+ # shape = [int(dim) for dim in shape.split("_")]
+ # dtype = getattr(torch, dtype)
+ self._receive_tensor(tensor_id, rank)
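`TritonNcclDataPlane` pairs the NCCL pipe with a small ZeroMQ REQ/REP control channel: the sender announces `PUT <rank> <tensor_id>`, the receiver replies with its own rank, and only then does the bulk tensor move over NCCL. A stripped-down sketch of just that handshake (port and payloads are illustrative):

```python
import threading

import zmq

context = zmq.Context()


def receiver() -> None:
    rep = context.socket(zmq.REP)
    rep.bind("tcp://127.0.0.1:5599")
    cmd, rank, tensor_id = rep.recv_string().split()
    print(f"receiver: got {cmd} for {tensor_id} from rank {rank}")
    rep.send_string("7")  # reply with the receiver's rank
    rep.close()


t = threading.Thread(target=receiver)
t.start()

req = context.socket(zmq.REQ)
req.connect("tcp://127.0.0.1:5599")
req.send_string("PUT 3 req0_keys")  # announce the transfer
dst_rank = req.recv_string()        # learn which rank to send the tensor to
print(f"sender: NCCL send would target rank {dst_rank}")

t.join()
req.close()
context.term()
```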
diff --git a/vllm/distributed/kv_transfer/kv_transfer_agent.py b/vllm/distributed/kv_transfer/kv_transfer_agent.py
index 1e80e0bd..cd90206f 100644
--- a/vllm/distributed/kv_transfer/kv_transfer_agent.py
...@@ -1728,20 +1225,255 @@ index d82d9ad9..542ccfe8 100644
             self.parallel_config.pipeline_parallel_size,
             self.async_callbacks[v_id]
             if self.model_config.use_async_output_proc else None)
diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py
index 3cf1850e..38acca0e 100644
--- a/vllm/engine/multiprocessing/__init__.py
+++ b/vllm/engine/multiprocessing/__init__.py
@@ -21,6 +21,7 @@ IPC_INPUT_EXT = "_input_socket"
IPC_OUTPUT_EXT = "_output_socket"
IPC_HEALTH_EXT = "_health_socket"
IPC_DATA_EXT = "_data_socket"
+IPC_METRICS_EXT = "_metrics_socket"
class MQEngineDeadError(RuntimeError):
@@ -157,3 +158,10 @@ def ENGINE_DEAD_ERROR(
return MQEngineDeadError(
"Engine loop is not running. Inspect the stacktrace to "
f"find the original error: {repr(error)}.")
+
+@dataclass
+class KvMetrics:
+ request_active_slots: int
+ request_total_slots: int
+ kv_active_blocks: int
+ kv_total_blocks: int
diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py
index 85b5f31e..6a7ea3ae 100644
--- a/vllm/engine/multiprocessing/client.py
+++ b/vllm/engine/multiprocessing/client.py
@@ -25,14 +25,15 @@ from vllm.engine.async_llm_engine import (
build_guided_decoding_logits_processor_async)
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT, IPC_INPUT_EXT,
- IPC_OUTPUT_EXT, RPC_REQUEST_T,
+ IPC_OUTPUT_EXT, IPC_METRICS_EXT,
+ RPC_REQUEST_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCAdapterLoadedResponse, RPCError,
RPCLoadAdapterRequest,
RPCProcessRequest,
RPCResetPrefixCacheRequest,
RPCStartupRequest, RPCStartupResponse,
- RPCUProfileRequest)
+ RPCUProfileRequest, KvMetrics)
from vllm.engine.protocol import EngineClient
# yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT
@@ -115,6 +116,10 @@ class MQLLMEngineClient(EngineClient):
self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL)
self.heartbeat_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}")
+ # Metrics.
+ self.metrics_socket: Socket = self.context.socket(zmq.constants.PULL)
+ self.metrics_socket.connect(f"{ipc_path}{IPC_METRICS_EXT}")
+
# IPC path for the data socket.
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"
@@ -129,6 +134,12 @@ class MQLLMEngineClient(EngineClient):
# Loop to check health of the LLMEngine periodically.
# Started after the MQLLMEngine is ready.
self.health_loop: Optional[asyncio.Task] = None
+
+ # Loop to check metrics of the LLMEngine periodically.
+ # Started after the MQLLMEngine is ready.
+ self.metrics_loop: Optional[asyncio.Task] = None
+ self.metrics_publisher = None
+
self._engine_process = psutil.Process(engine_pid)
@staticmethod
@@ -180,6 +191,46 @@ class MQLLMEngineClient(EngineClient):
except Exception as e:
self._set_errored(e)
+ async def run_metrics_loop(self, timeout: int):
+ """Background loop that continually checks to ensure the engine process
+ is still alive.
+ """
+ try:
+ while True:
+ # Check if the engine process is running:
+ if not self._engine_process.is_running() or (
+ self._engine_process.status() == psutil.STATUS_ZOMBIE):
+ # NB: is_running() returns True for zombies
+ self._set_errored(
+ RuntimeError(
+ f"Engine process (pid {self._engine_process.pid}) "
+ "died."))
+ break
+
+ if await self.metrics_socket.poll(timeout=timeout):
+                # Metrics received - check the message
+ message: Frame = await self.metrics_socket.recv(copy=False)
+ kv_metrics = pickle.loads(message.buffer)
+ if self.metrics_publisher is not None:
+ if isinstance(kv_metrics, KvMetrics):
+ self.metrics_publisher.publish(kv_metrics.request_active_slots,
+ kv_metrics.request_total_slots,
+ kv_metrics.kv_active_blocks,
+ kv_metrics.kv_total_blocks)
+
+ logger.debug("Metircs successful.")
+
+ except asyncio.CancelledError:
+ logger.debug("Shutting down MQLLMEngineClient check metrics loop.")
+
+ except psutil.NoSuchProcess:
+ self._set_errored(
+ RuntimeError(
+ f"Engine process (pid {self._engine_process.pid}) died."))
+
+ except Exception as e:
+ self._set_errored(e)
+
async def run_output_handler_loop(self):
"""Get RequestOutputs from Engine and stream to Request Queues"""
@@ -284,6 +335,12 @@ class MQLLMEngineClient(EngineClient):
if self.health_loop is None:
self.health_loop = asyncio.create_task(
self.run_heartbeat_loop(timeout=VLLM_RPC_TIMEOUT))
+
+ # Start metrics_loop.
+ if self.metrics_loop is None:
+ self.metrics_loop = asyncio.create_task(
+ self.run_metrics_loop(timeout=VLLM_RPC_TIMEOUT))
+
def close(self):
"""Destroy the ZeroMQ Context."""
@@ -293,6 +350,8 @@ class MQLLMEngineClient(EngineClient):
# Cancel background tasks.
if self.health_loop is not None:
self.health_loop.cancel()
+ if self.metrics_loop is not None:
+ self.metrics_loop.cancel()
if self.output_loop is not None:
self.output_loop.cancel()
@@ -705,3 +764,6 @@ class MQLLMEngineClient(EngineClient):
# Raise on error, otherwise happily return None
if isinstance(request_output, BaseException):
raise request_output
+
+ def set_metrics_publisher(self, metrics_publisher):
+ self.metrics_publisher = metrics_publisher
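The metrics path mirrors the existing health socket: the engine pickles a `KvMetrics` dataclass and pushes it over a dedicated PUSH/PULL IPC socket pair, and the client unpickles it in `run_metrics_loop`. A compact sketch of that wire format (the socket path here is made up):

```python
import pickle
from dataclasses import dataclass

import zmq


@dataclass
class KvMetrics:  # same shape as the dataclass added in __init__.py above
    request_active_slots: int
    request_total_slots: int
    kv_active_blocks: int
    kv_total_blocks: int


ctx = zmq.Context()
pull = ctx.socket(zmq.PULL)
pull.bind("ipc:///tmp/demo_metrics_socket")
push = ctx.socket(zmq.PUSH)
push.connect("ipc:///tmp/demo_metrics_socket")

# Engine side: serialize and send, as KvStatLogger._send_kv_metrics does.
push.send_multipart((pickle.dumps(KvMetrics(2, 8, 100, 4096)),), copy=False)

# Client side: receive and deserialize, as run_metrics_loop does.
metrics = pickle.loads(pull.recv())
print(metrics)

push.close()
pull.close()
ctx.term()
```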
diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py
index a0dd7958..dc6ea25d 100644
--- a/vllm/engine/multiprocessing/engine.py
+++ b/vllm/engine/multiprocessing/engine.py
@@ -14,24 +14,56 @@ from vllm.engine.llm_engine import LLMEngine
# yapf: disable
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT, IPC_INPUT_EXT,
- IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T,
+ IPC_OUTPUT_EXT, IPC_METRICS_EXT,
+ REQUEST_OUTPUTS_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCAdapterLoadedResponse, RPCError,
RPCLoadAdapterRequest,
RPCProcessRequest,
RPCResetPrefixCacheRequest,
RPCStartupRequest, RPCStartupResponse,
- RPCUProfileRequest)
+ RPCUProfileRequest, KvMetrics)
# yapf: enable
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.usage.usage_lib import UsageContext
+from vllm.engine.metrics_types import StatLoggerBase, Stats, SupportsMetricsInfo
+from dataclasses import dataclass, field
logger = init_logger(__name__)
POLLING_TIMEOUT_MS = 10000
HEALTHY_RESPONSE = (pickle.dumps(VLLM_RPC_SUCCESS_STR), )
+class KvStatLogger(StatLoggerBase):
+ def __init__(
+ self,
+ max_num_seqs: int,
+ num_total_gpu_blocks: int,
+ metrics_socket
+ ):
+        # Must query the initialized scheduler for the configured maximums
+ self.request_total_slots = max_num_seqs
+ self.kv_total_blocks = num_total_gpu_blocks
+ self.metrics_socket = metrics_socket
+
+ # KV metrics
+ self._send_kv_metrics(0, 0)
+
+ def log(self, stats: Stats) -> None:
+ self._send_kv_metrics(
+ stats.num_running_sys,
+ int(stats.gpu_cache_usage_sys * self.kv_total_blocks)
+ )
+
+ def info(self, type: str, obj: SupportsMetricsInfo) -> None:
+ pass
+
+ def _send_kv_metrics(self, active_slots, active_kv_blocks):
+ if not self.metrics_socket.closed:
+ metrics_bytes = pickle.dumps(KvMetrics(active_slots, self.request_total_slots, active_kv_blocks, self.kv_total_blocks))
+ self.metrics_socket.send_multipart((metrics_bytes, ), copy=False)
+
class MQLLMEngine:
"""A multiprocessing wrapper for :class:`LLMEngine`.
@@ -94,12 +126,24 @@ class MQLLMEngine:
self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH)
self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}")
+ # Send metrics back to client.
+ self.metrics_socket = self.ctx.socket(zmq.constants.PUSH)
+ self.metrics_socket.bind(f"{ipc_path}{IPC_METRICS_EXT}")
+
# IPC path for the data socket.
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"
# Error state.
self._errored_with: Optional[BaseException] = None
+ # Attach logger for continuous metrics publishing
+ self.stat_logger = KvStatLogger(
+ self.engine.scheduler_config.max_num_seqs,
+ self.engine.cache_config.num_gpu_blocks,
+ self.metrics_socket
+ )
+ self.engine.add_logger("kv_metrics", self.stat_logger)
+
@property
def dead_error(self) -> BaseException:
if self._errored_with is not None:
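`KvStatLogger.log` converts vLLM's fractional cache gauge into a block count: `gpu_cache_usage_sys` is a 0-1 utilization, so multiplying by the total number of GPU blocks and truncating yields the active block count that gets published. In miniature:

```python
# Suppose the engine reports 4096 total GPU KV blocks and 31.25% cache usage.
kv_total_blocks = 4096
gpu_cache_usage_sys = 0.3125  # fraction, as carried in vLLM's Stats

kv_active_blocks = int(gpu_cache_usage_sys * kv_total_blocks)
print(kv_active_blocks)  # 1280
```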
diff --git a/vllm/envs.py b/vllm/envs.py
index 745b068b..0ae63d9b 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -87,6 +87,10 @@ if TYPE_CHECKING:
     VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: bool = False
     VLLM_RAY_PER_WORKER_GPUS: float = 1.0
     VLLM_RAY_BUNDLE_INDICES: str = ""
+    VLLM_KV_CAPI_PATH: Optional[str] = None
+    VLLM_KV_NAMESPACE: Optional[str] = None
+    VLLM_KV_COMPONENT: Optional[str] = None
+    VLLM_WORKER_ID: Optional[int] = None
 def get_default_cache_root():
@@ -572,6 +576,21 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     # models the alignment is already naturally aligned to 256 bytes.
     "VLLM_CUDA_MEM_ALIGN_KV_CACHE":
     lambda: bool(int(os.getenv("VLLM_CUDA_MEM_ALIGN_KV_CACHE", "1"))),
...@@ -1750,9 +1482,16 @@ index 745b068b..438142e3 100644
+    "VLLM_KV_CAPI_PATH":
+    lambda: os.environ.get("VLLM_KV_CAPI_PATH", None),
+
+    # Identifiers to publish KV related information
+    "VLLM_KV_NAMESPACE":
+    lambda: os.environ.get("VLLM_KV_NAMESPACE", None),
+    "VLLM_KV_COMPONENT":
+    lambda: os.environ.get("VLLM_KV_COMPONENT", None),
+
+    # Worker ID used for identifying workers in distributed settings
+    "VLLM_WORKER_ID":
+    lambda: int(os.getenv("VLLM_WORKER_ID", "0"))
+    if "VLLM_WORKER_ID" in os.environ else None,
}
# end-env-vars-definition
...
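As elsewhere in `vllm/envs.py`, each entry is a zero-argument lambda so the variable is read and converted lazily on attribute access; the `VLLM_WORKER_ID` entry additionally distinguishes "unset" (returns `None`) from "set" (int-converted, so an id of `"0"` still counts as set). A rough sketch of the pattern:

```python
import os
from typing import Any, Callable, Dict

environment_variables: Dict[str, Callable[[], Any]] = {
    "VLLM_KV_CAPI_PATH":
    lambda: os.environ.get("VLLM_KV_CAPI_PATH", None),
    # None when unset; int-converted when present, so a worker id of "0"
    # is still treated as set.
    "VLLM_WORKER_ID":
    lambda: int(os.getenv("VLLM_WORKER_ID", "0"))
    if "VLLM_WORKER_ID" in os.environ else None,
}


def __getattr__(name: str):  # module-level hook, as vllm/envs.py defines one
    if name in environment_variables:
        return environment_variables[name]()
    raise AttributeError(f"module has no attribute {name!r}")


os.environ["VLLM_WORKER_ID"] = "42"
print(environment_variables["VLLM_WORKER_ID"]())  # 42
```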
...@@ -180,7 +180,7 @@ kv-router-run.sh <number_of_workers> <routing_strategy> Optional[<model_name>]
Example:
```bash
# Launch 8 workers with prefix routing strategy and use deepseek-ai/DeepSeek-R1-Distill-Llama-8B as the model
bash /workspace/examples/python_rs/llm/vllm/scripts/kv-router-run.sh 8 prefix deepseek-ai/DeepSeek-R1-Distill-Llama-8B

# List tmux sessions
tmux ls
...@@ -252,7 +252,7 @@ llmctl http add chat-models deepseek-ai/DeepSeek-R1-Distill-Llama-8B triton-init
```
```bash
curl localhost:8080/v1/chat/completions -H "Content-Type: application/json" -d '{
    "model": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
    "messages": [
    {
...
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import uvloop
from common.protocol import Request, Response
from vllm.logger import logger as vllm_logger
from triton_distributed.llm import KvRouter
from triton_distributed.runtime import (
DistributedRuntime,
triton_endpoint,
triton_worker,
)
class Router:
"""
Request handler for the generate endpoint
"""
def __init__(
self,
router,
workers_client,
):
self.router = router
self.workers_client = workers_client
@triton_endpoint(Request, Response)
async def generate(self, request):
lora_id = 0
worker_id = None
tokens = [3] * 64
try:
worker_id = await self.router.schedule(tokens, lora_id)
        # [NOTE][TODO] Now that the scheduler may return more error messages,
        # we catch all exceptions and log them. We should catch specific
        # router exceptions once we have dedicated types.
except Exception as e:
vllm_logger.info(f"got exception of type {type(e)}: {e}")
worker_id = None
vllm_logger.exception(f"Error during worker selection: {e}")
vllm_logger.info(f"Scheduling to worker_id: {worker_id}")
if worker_id is None:
vllm_logger.info("randomly select worker")
engine_generator = await self.workers_client.random(
request.model_dump_json()
)
else:
vllm_logger.info(f"directly select worker: {worker_id}")
engine_generator = await self.workers_client.direct(
request.model_dump_json(), worker_id
)
async for resp in engine_generator:
resp = resp.data() if hasattr(resp, "data") else resp
yield resp
@triton_endpoint(Request, Response)
async def mock_generate(self, request):
print(f"Received request: {request}")
yield "Hello, World!"
ROUTE_SELF = True
@triton_worker()
async def worker(runtime: DistributedRuntime):
workers_client = (
await runtime.namespace("triton-init")
.component("vllm")
.endpoint("generate")
.client()
)
vllm_logger.info(
f"Have number of workers ({len(workers_client.endpoint_ids())}) are ready:\n"
+ "\n".join(f"id: {id}" for id in workers_client.endpoint_ids())
)
    # [TODO] The collect endpoint implementation expects services to provide
    # ForwardPassMetrics as part of stats handling, and it will panic
    # otherwise. This needs to be fixed so that endpoints that do not provide
    # metrics are simply ignored; until then, we make sure that services of
    # the same namespace::component are created via KvMetricsPublisher when it
    # is also used to create endpoints.
kv_listener = runtime.namespace("triton-init").component("vllm")
await kv_listener.create_service()
router = KvRouter(runtime, kv_listener)
# i.e. below will cause panic
# endpoint = kv_listener.endpoint("generate")
# await endpoint.serve_endpoint(
# Router(router, workers_client).mock_generate
# )
router_component = runtime.namespace("triton-init").component("frontend")
await router_component.create_service()
endpoint = router_component.endpoint("generate")
await endpoint.serve_endpoint(Router(router, workers_client).generate)
if __name__ == "__main__":
uvloop.install()
asyncio.run(worker())
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import ctypes
from ctypes import c_char_p, c_int64, c_uint32
import uvloop
from common.protocol import Request, Response
from vllm.logger import logger as vllm_logger
from triton_distributed.llm import KvMetricsPublisher
from triton_distributed.runtime import (
DistributedRuntime,
triton_endpoint,
triton_worker,
)
class TritonResult:
OK = 0
ERR = 1
class MockEngine:
"""
Request handler for the generate endpoint
"""
def __init__(self, metrics_publisher, worker_id):
self.worker_id = worker_id
# KV events
self.lib = ctypes.CDLL("/opt/triton/llm_binding/lib/libtriton_llm_capi.so")
self.lib.triton_llm_init.argtypes = [c_char_p, c_char_p, c_int64]
self.lib.triton_llm_init.restype = c_uint32
result = self.lib.triton_llm_init(
"triton-init".encode(), "vllm".encode(), worker_id
)
if result == TritonResult.OK:
vllm_logger.info(
"KVCacheEventManager initialized successfully. Ready to publish KV Cache Events"
)
else:
vllm_logger.info("KVCacheEventManager initialization failed!")
self.lib.triton_kv_event_publish_stored.argtypes = [
ctypes.c_uint64, # event_id
ctypes.POINTER(ctypes.c_uint32), # token_ids
ctypes.POINTER(ctypes.c_size_t), # num_block_tokens
ctypes.POINTER(ctypes.c_uint64), # block_ids
ctypes.c_size_t, # num_blocks
ctypes.POINTER(ctypes.c_uint64), # parent_hash
ctypes.c_uint64, # lora_id
]
self.lib.triton_kv_event_publish_stored.restype = (
ctypes.c_uint32
) # triton_llm_result_t
self.lib.triton_kv_event_publish_removed.argtypes = [
ctypes.c_uint64, # event_id
ctypes.POINTER(ctypes.c_uint64), # block_ids
ctypes.c_size_t, # num_blocks
]
self.lib.triton_kv_event_publish_removed.restype = (
ctypes.c_uint32
) # triton_llm_result_t
# KV metrics
self.metrics_publisher = metrics_publisher
self.request_active_slots = 0
self.request_total_slots = 4
self.kv_active_block = 0
self.kv_total_blocks = 4
        # [NOTE] The component must have proper metrics reported
        # to be properly selected by the router
self.metrics_publisher.publish(
self.request_active_slots,
self.request_total_slots,
self.kv_active_block,
self.kv_total_blocks,
)
self.event_id_counter = 0
self.tokens = [3] * 64
@triton_endpoint(Request, Response)
async def generate(self, request):
print(f"Received request: {request}")
self.request_active_slots = min(
self.request_active_slots + 1, self.request_total_slots
)
self.kv_active_block = min(self.kv_active_block + 1, self.kv_total_blocks)
self.metrics_publisher.publish(
self.request_active_slots,
self.request_total_slots,
self.kv_active_block,
self.kv_total_blocks,
)
self.store_event()
yield "Hello, World!"
def store_event(self):
parent_hash = (
(ctypes.c_uint64 * 1)(self.event_id_counter)
if self.event_id_counter > 0
else None
)
result = self.lib.triton_kv_event_publish_stored(
self.event_id_counter, # uint64_t event_id
(ctypes.c_uint32 * len(self.tokens))(
*self.tokens
), # const uint32_t *token_ids
(ctypes.c_size_t * 1)(
len(self.tokens)
), # const uintptr_t *num_block_tokens
(ctypes.c_uint64 * 1)(self.event_id_counter), # const uint64_t *block_ids
1, # uintptr_t num_blocks
parent_hash, # const uint64_t *parent_hash
0, # uint64_t lora_id
)
self.event_id_counter += 1
if result == TritonResult.OK:
vllm_logger.debug(f"Store - Published KV Event: {self.event_id_counter}")
else:
vllm_logger.debug(
f"Store - Failed to Publish KV Event: {self.event_id_counter}"
)
async def cooldown(self):
while True:
await asyncio.sleep(5)
self.request_active_slots = max(0, self.request_active_slots - 1)
self.kv_active_block = max(0, self.kv_active_block - 1)
self.metrics_publisher.publish(
self.request_active_slots,
self.request_total_slots,
self.kv_active_block,
self.kv_total_blocks,
)
@triton_worker()
async def worker(runtime: DistributedRuntime):
"""
Instantiate a `backend` component and serve the `generate` endpoint
A `Component` can serve multiple endpoints
"""
component = runtime.namespace("triton-init").component("vllm")
metrics_publisher = KvMetricsPublisher()
await metrics_publisher.create_service(component)
endpoint = component.endpoint("generate")
engine = MockEngine(metrics_publisher, endpoint.lease_id())
await asyncio.gather(
engine.cooldown(),
endpoint.serve_endpoint(engine.generate),
)
if __name__ == "__main__":
uvloop.install()
asyncio.run(worker())
...@@ -31,8 +31,8 @@ from vllm.logger import logger as vllm_logger
from vllm.outputs import RequestOutput
from vllm.transformers_utils.tokenizer import AnyTokenizer

from triton_distributed._core import Client
from triton_distributed.runtime import (
    Client,
    DistributedRuntime,
    triton_endpoint,
    triton_worker,
...@@ -126,7 +126,7 @@ class Processor(ProcessMixIn):
            sampling_params=sampling_params,
            request_id=request_id,
        ).model_dump_json(),
        int(worker_id),
    )
    output = self.generate_responses(engine_generator)
...
...@@ -58,20 +58,21 @@ class Router:
     @triton_endpoint(Tokens, WorkerId)
     async def generate(self, request) -> AsyncIterator[WorkerId]:
         lora_id = 0
-        worker_id = ""
+        worker_id = None
         if self.routing_strategy == RoutingStrategy.PREFIX:
             try:
                 worker_id = await self.router.schedule(request.tokens, lora_id)
+            # [NOTE][TODO] The scheduler may now return more error messages,
+            # so we catch all exceptions and log them. Switch to catching
+            # specific router exceptions once dedicated types exist.
             except Exception as e:
-                vllm_logger.info(f"{e}")
-                if "No worker found" in str(e):
-                    worker_id = ""
-                else:
-                    vllm_logger.exception(f"Error during worker selection: {e}")
+                worker_id = None
+                vllm_logger.exception(f"Error during worker selection: {e}")
             vllm_logger.info(f"Scheduling to worker_id: {worker_id}")
-            yield worker_id
+            yield str(worker_id)
         else:
             # TODO: Do we implement round_robin and random here?
...@@ -113,8 +114,7 @@ async def worker(runtime: DistributedRuntime, args: Namespace):
         + "\n".join(f"id: {id}" for id in workers_client.endpoint_ids())
     )
-    # TODO Router is a fixed namespace separate from the others
-    kv_listener = runtime.namespace("router").component(args.model_name)
+    kv_listener = runtime.namespace("triton-init").component("vllm")
     await kv_listener.create_service()
     router_component = runtime.namespace("triton-init").component("router")
...
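With this change the router hands back a numeric lease id instead of a UUID string, and `None` signals that no worker was selected. A minimal sketch of the calling pattern, assuming only the `schedule(token_ids, lora_id) -> int` interface typed later in this diff; the fallback behavior here is illustrative, not the shipped code:

    async def pick_worker(router, tokens: list[int]) -> int | None:
        try:
            # Returns the lease id (i64) of the worker with the best KV overlap.
            return await router.schedule(tokens, 0)  # lora_id = 0
        except Exception:
            # Scheduler errors degrade to "no selection"; the caller then
            # applies its default routing strategy.
            return None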
...@@ -15,7 +15,6 @@
 import asyncio
 import os
-import uuid
 from typing import AsyncIterator
 import uvloop
...@@ -26,6 +25,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.logger import logger as vllm_logger
 from vllm.sampling_params import RequestOutputKind
+from triton_distributed.llm import KvMetricsPublisher
 from triton_distributed.runtime import (
     DistributedRuntime,
     triton_endpoint,
...@@ -40,10 +40,18 @@ class VllmEngine(BaseVllmEngine):
     vLLM Inference Engine
     """
-    def __init__(self, engine_args: AsyncEngineArgs):
+    def __init__(
+        self, engine_args: AsyncEngineArgs, metrics_publisher: KvMetricsPublisher
+    ):
+        self.metrics_publisher = metrics_publisher
         self.engine_args = engine_args
         super().__init__(engine_args)
+    async def initialize(self):
+        await super().initialize()
+        assert self.engine_client is not None, "engine_client was not initialized"
+        self.engine_client.set_metrics_publisher(self.metrics_publisher)
     @triton_endpoint(vLLMGenerateRequest, MyRequestOutput)
     async def generate(self, request) -> AsyncIterator:
         assert (
...@@ -74,21 +82,32 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
     """
     Serve the triton-init.vllm.generate endpoint.
     """
+    metrics_publisher = KvMetricsPublisher()
     worker_component = runtime.namespace("triton-init").component("vllm")
-    await worker_component.create_service()
+    await metrics_publisher.create_service(worker_component)
     worker_endpoint = worker_component.endpoint("generate")
-    # KV Publisher and Aggregator requires a UUID (str)
-    # KV Router requires a lease_id (int)
-    # This allows us to please both, until they are unified
-    # If VLLM_WORKER_ID is not set, KV Routing will fail
-    VLLM_WORKER_ID = uuid.UUID(int=worker_endpoint.lease_id())
+    VLLM_WORKER_ID = worker_endpoint.lease_id()
     os.environ["VLLM_WORKER_ID"] = str(VLLM_WORKER_ID)
     vllm_logger.info(f"Generate endpoint ID: {VLLM_WORKER_ID}")
+    VLLM_KV_NAMESPACE = "triton-init"
+    os.environ["VLLM_KV_NAMESPACE"] = str(VLLM_KV_NAMESPACE)
+    VLLM_KV_COMPONENT = "vllm"
+    os.environ["VLLM_KV_COMPONENT"] = str(VLLM_KV_COMPONENT)
-    vllm_engine = VllmEngine(engine_args)
+    vllm_engine = VllmEngine(engine_args, metrics_publisher)
     await vllm_engine.initialize()
+    # Initially send dummy metrics to kick-start;
+    # vLLM will not update stats until a forward pass is triggered
+    metrics_publisher.publish(
+        0,
+        1024,
+        0,
+        1024,
+    )
     await worker_endpoint.serve_endpoint(vllm_engine.generate)
...
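The worker now passes its identity to the vLLM process entirely through environment variables (`VLLM_WORKER_ID`, `VLLM_KV_NAMESPACE`, `VLLM_KV_COMPONENT`, set above). A minimal sketch of the consuming side, assuming an all-or-nothing gate where publishing is enabled only when every variable is set; the helper name is illustrative:

    import os

    def read_kv_router_env():
        # All three values must be present for KV event publishing to be enabled.
        keys = ("VLLM_WORKER_ID", "VLLM_KV_NAMESPACE", "VLLM_KV_COMPONENT")
        values = {k: os.environ.get(k) for k in keys}
        if any(v is None for v in values.values()):
            return None  # feature stays off
        values["VLLM_WORKER_ID"] = int(values["VLLM_WORKER_ID"])  # lease id is an i64
        return values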
...@@ -22,13 +22,13 @@ use tracing as log;
 use uuid::Uuid;
 use triton_distributed_llm::kv_router::{
-    indexer::compute_block_hash_for_seq, protocols::*, publisher::KvPublisher,
+    indexer::compute_block_hash_for_seq, protocols::*, publisher::KvEventPublisher,
 };
 use triton_distributed_runtime::{DistributedRuntime, Worker};
 static WK: OnceCell<Worker> = OnceCell::new();
 static DRT: AsyncOnceCell<DistributedRuntime> = AsyncOnceCell::new();
 // [FIXME] shouldn't the publisher be an instance passed between API calls?
-static KV_PUB: OnceCell<KvPublisher> = OnceCell::new();
+static KV_PUB: OnceCell<KvEventPublisher> = OnceCell::new();
 fn initialize_tracing() {
     // Sets up RUST_LOG environment variable for logging while KV Publishing
...@@ -49,11 +49,12 @@ pub enum TritonLlmResult {
 }
 /// # Safety
-/// the model_name_c_str and worker_id_c_str are passed as pointers to C strings
+/// the namespace_c_str and component_c_str are passed as pointers to C strings
 #[no_mangle]
 pub unsafe extern "C" fn triton_llm_init(
-    model_name_c_str: *const c_char,
-    worker_id_c_str: *const c_char,
+    namespace_c_str: *const c_char,
+    component_c_str: *const c_char,
+    worker_id: i64,
 ) -> TritonLlmResult {
     initialize_tracing();
     let wk = match WK.get_or_try_init(Worker::from_settings) {
...@@ -78,7 +79,7 @@ pub unsafe extern "C" fn triton_llm_init(
             }
         }
     });
-    let model_name = match unsafe { CStr::from_ptr(model_name_c_str) }.to_str() {
+    let namespace = match unsafe { CStr::from_ptr(namespace_c_str) }.to_str() {
         Ok(s) => s.to_string(),
         Err(e) => {
             eprintln!("Failed to convert C string to Rust string: {:?}", e);
...@@ -86,24 +87,17 @@
         }
     };
-    let worker_id_str = match unsafe { CStr::from_ptr(worker_id_c_str) }.to_str() {
-        Ok(s) => s,
+    let component = match unsafe { CStr::from_ptr(component_c_str) }.to_str() {
+        Ok(s) => s.to_string(),
         Err(e) => {
             eprintln!("Failed to convert C string to Rust string: {:?}", e);
             return TritonLlmResult::ERR;
         }
     };
-    let worker_id_uuid = match Uuid::parse_str(worker_id_str) {
-        Ok(uuid) => uuid,
-        Err(e) => {
-            eprintln!("Failed to parse worker_id as UUID: {:?}", e);
-            return TritonLlmResult::ERR;
-        }
-    };
     match result {
         Ok(_) => match KV_PUB
-            .get_or_try_init(move || triton_create_kv_publisher(model_name, worker_id_uuid))
+            .get_or_try_init(move || triton_create_kv_publisher(namespace, component, worker_id))
         {
             Ok(_) => TritonLlmResult::OK,
             Err(e) => {
...@@ -143,17 +137,18 @@ pub extern "C" fn triton_llm_load_publisher_create() -> TritonLlmResult {
 // c++ executor api
 fn triton_create_kv_publisher(
-    model_name: String,
-    worker_id: Uuid,
-) -> Result<KvPublisher, anyhow::Error> {
-    log::info!("Creating KV Publisher for model: {}", model_name);
+    namespace: String,
+    component: String,
+    worker_id: i64,
+) -> Result<KvEventPublisher, anyhow::Error> {
+    log::info!("Creating KV Publisher for component: {}", component);
     match DRT
         .get()
         .ok_or(anyhow::Error::msg("Could not get Distributed Runtime"))
     {
         Ok(drt) => {
-            let backend = drt.namespace("router")?.component(model_name)?;
-            KvPublisher::new(drt.clone(), backend, worker_id)
+            let backend = drt.namespace(namespace)?.component(component)?;
+            KvEventPublisher::new(drt.clone(), backend, worker_id)
         }
         Err(e) => Err(e),
     }
...
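On the consuming side, the new `triton_llm_init` signature (two C strings plus an `i64` worker id) can be exercised from Python via ctypes. A sketch under stated assumptions: the shared-library path, the numeric value of `TritonLlmResult::OK`, and the sample worker id are all placeholders:

    import ctypes

    lib = ctypes.CDLL("libtriton_llm_capi.so")  # placeholder library path
    lib.triton_llm_init.argtypes = [ctypes.c_char_p, ctypes.c_char_p, ctypes.c_int64]
    lib.triton_llm_init.restype = ctypes.c_int  # TritonLlmResult; repr assumed

    worker_id = 0x694D9981343D353C  # placeholder lease id
    if lib.triton_llm_init(b"triton-init", b"vllm", worker_id) != 0:  # OK assumed == 0
        raise RuntimeError("triton_llm_init failed")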
...@@ -64,6 +64,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_class::<Client>()?;
     m.add_class::<AsyncResponseStream>()?;
     m.add_class::<llm::kv::KvRouter>()?;
+    m.add_class::<llm::kv::KvMetricsPublisher>()?;
     engine::add_to_module(m)?;
...
...@@ -23,6 +23,7 @@ pub(crate) struct KvRouter {
 #[pymethods]
 impl KvRouter {
     #[new]
+    // [FIXME] 'drt' can be obtained from 'component'
     fn new(drt: DistributedRuntime, component: Component) -> PyResult<Self> {
         let runtime = pyo3_async_runtimes::tokio::get_runtime();
         runtime.block_on(async {
...@@ -44,11 +45,64 @@ impl KvRouter {
     ) -> PyResult<Bound<'p, PyAny>> {
         let router = self.inner.clone();
         pyo3_async_runtimes::tokio::future_into_py(py, async move {
-            let uuid = router
+            let worker_id = router
                 .schedule(&token_ids, lora_id)
                 .await
                 .map_err(to_pyerr)?;
-            Ok(uuid.to_string())
+            Ok(worker_id)
         })
     }
 }
+#[pyclass]
+pub(crate) struct KvMetricsPublisher {
+    inner: Arc<llm_rs::kv_router::publisher::KvMetricsPublisher>,
+}
+#[pymethods]
+impl KvMetricsPublisher {
+    #[new]
+    fn new() -> PyResult<Self> {
+        let inner = llm_rs::kv_router::publisher::KvMetricsPublisher::new().map_err(to_pyerr)?;
+        Ok(Self {
+            inner: inner.into(),
+        })
+    }
+    fn create_service<'p>(
+        &self,
+        py: Python<'p>,
+        component: Component,
+    ) -> PyResult<Bound<'p, PyAny>> {
+        let rs_publisher = self.inner.clone();
+        let rs_component = component.inner.clone();
+        pyo3_async_runtimes::tokio::future_into_py(py, async move {
+            let _ = rs_publisher
+                .create_service(rs_component)
+                .await
+                .map_err(to_pyerr)?;
+            Ok(())
+        })
+    }
+    fn publish<'p>(
+        &self,
+        py: Python<'p>,
+        request_active_slots: u64,
+        request_total_slots: u64,
+        kv_active_blocks: u64,
+        kv_total_blocks: u64,
+    ) -> PyResult<()> {
+        self.inner
+            .publish(
+                llm_rs::kv_router::protocols::ForwardPassMetrics {
+                    request_active_slots,
+                    request_total_slots,
+                    kv_active_blocks,
+                    kv_total_blocks,
+                }
+                .into(),
+            )
+            .map_err(to_pyerr)
+    }
+}
...@@ -128,7 +128,7 @@ class Client:
 class KvRouter:
     """
-    The runtime object for a distributed NOVA applications
+    A router will determine which worker should handle a given request.
     """
     ...
...@@ -138,9 +138,36 @@ class KvRouter:
         Create a `KvRouter` object that is associated with the `component`
         """
-    def schedule(self, token_ids: List[int], lora_id: int) -> str:
+    def schedule(self, token_ids: List[int], lora_id: int) -> int:
         """
         Return the worker id that should handle the given token ids;
         an exception will be raised if there is no worker available.
         """
         ...
+class KvMetricsPublisher:
+    """
+    A metrics publisher will provide KV metrics to the router.
+    """
+    ...
+    def __init__(self) -> None:
+        """
+        Create a `KvMetricsPublisher` object
+        """
+    def create_service(self, component: Component) -> None:
+        """
+        Similar to Component.create_service, but only a service created through
+        this method will interact with the KV router of the same component.
+        """
+    def publish(self, request_active_slots: int,
+                request_total_slots: int,
+                kv_active_blocks: int,
+                kv_total_blocks: int) -> None:
+        """
+        Update the KV metrics being reported.
+        """
+    ...
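Putting the new stubs together, a minimal end-to-end sketch of metrics publishing from a worker; the namespace/component names and the decorator usage mirror the examples earlier in this diff:

    from triton_distributed.llm import KvMetricsPublisher
    from triton_distributed.runtime import DistributedRuntime, triton_worker

    @triton_worker()
    async def worker(runtime: DistributedRuntime):
        component = runtime.namespace("triton-init").component("vllm")
        publisher = KvMetricsPublisher()
        # Register the stats handler so the KV router can poll this worker.
        await publisher.create_service(component)
        # request_active_slots, request_total_slots, kv_active_blocks, kv_total_blocks
        publisher.publish(0, 1024, 0, 1024)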
...@@ -13,4 +13,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from triton_distributed._core import KvMetricsPublisher as KvMetricsPublisher
 from triton_distributed._core import KvRouter as KvRouter
...@@ -20,7 +20,8 @@ from typing import Any, AsyncGenerator, Callable, Type
 from pydantic import BaseModel, ValidationError
-from triton_distributed._core import DistributedRuntime
+from triton_distributed._core import Client as Client
+from triton_distributed._core import DistributedRuntime as DistributedRuntime
 def triton_worker():
...
...@@ -23,14 +23,11 @@ use triton_distributed_runtime::{component::Component, DistributedRuntime};
 pub mod indexer;
 pub mod protocols;
 pub mod publisher;
-// [WIP] enable service_builder() through worker for metrics reporting
-// pub mod worker;
 mod scheduler;
 mod scoring;
 use crate::kv_router::{
     indexer::{KvIndexer, KvIndexerInterface, RouterEvent},
-    protocols::KV_BLOCK_SIZE,
     scheduler::{Endpoint, KvScheduler, Service},
     scoring::ProcessedEndpoints,
 };
...@@ -113,7 +110,7 @@ impl KvRouter {
     }
     // [TODO] indexer needs to take 'lora_id' as parameter
-    pub async fn schedule(&self, token_ids: &Vec<u32>, _lora_id: u64) -> Result<String> {
+    pub async fn schedule(&self, token_ids: &Vec<u32>, _lora_id: u64) -> Result<i64> {
         // Extracted from KvRouter::generate(): only the decision-making part;
         // routing is done by the caller
         let isl_tokens = token_ids.len();
...@@ -122,25 +119,8 @@ impl KvRouter {
             .find_matches_for_request(token_ids.as_slice())
             .await?;
         log::debug!("KV router overlap_scores: {:?}", overlap_scores);
-        // [FIXME] Python binding results in "endpoint subscriber shutdown" error,
-        // need to investigate whether it happens in pure rust as well and then
-        // root cause it. Before that, not doing intelligent scheduling for rapid
-        // development..
-        // [FIXME] also need to fix that scheduler returns worker subject which is not
-        // the same as worker id (uuid). Seems like it adds additional annotation on top of uuid.
-        // Need to double check
-        // 'worker_subject' should be the same as worker id used for direct routing
-        // let worker_subject = self.scheduler.schedule(overlap_scores, isl_tokens).await?;
-        let mut selected_worker_subject = Option::<String>::None;
-        for (worker_subject, overlap_score) in &overlap_scores.scores {
-            if ((*overlap_score as usize * KV_BLOCK_SIZE) as f64 / isl_tokens as f64) >= 0.5 {
-                selected_worker_subject = Some(worker_subject.to_string());
-            }
-        }
-        match selected_worker_subject {
-            None => Err(anyhow::anyhow!("No worker found")),
-            Some(worker_subject) => Ok(worker_subject),
-        }
+        let worker_id = self.scheduler.schedule(overlap_scores, isl_tokens).await?;
+        Ok(worker_id)
     }
 }
...@@ -167,7 +147,7 @@ async fn collect_endpoints(
     .unwrap();
     // [FIXME] Endpoint is parsed from the nats stats handler, which may not include
-    // the 'data' field if the service hasn't registered the handler.
+    // the 'data' field if the service hasn't registered the handler. Need to be tolerant of this.
     // Another option is to make sure the router is configured so that
     // it listens to the right subject (where the publisher reports stats).
     let services: Vec<Service> = values
...
...@@ -79,7 +79,7 @@ pub enum KvRouterError {
 }
 /// Identifier of an LLM worker which emits events to the router.
-pub type WorkerId = uuid::Uuid;
+pub type WorkerId = i64;
 /// A shared reference to a [`RadixBlock`].
 type SharedRadixBlock = Rc<RefCell<RadixBlock>>;
...
...@@ -13,20 +13,20 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-use crate::kv_router::{indexer::RouterEvent, protocols::KvCacheEvent, KV_EVENT_SUBJECT};
+use crate::kv_router::{indexer::RouterEvent, protocols::*, KV_EVENT_SUBJECT};
+use std::sync::Arc;
 use tokio::sync::mpsc;
 use tracing as log;
 use triton_distributed_runtime::{component::Component, DistributedRuntime, Result};
-use uuid::Uuid;
-pub struct KvPublisher {
+pub struct KvEventPublisher {
     tx: mpsc::UnboundedSender<KvCacheEvent>,
 }
-impl KvPublisher {
-    pub fn new(drt: DistributedRuntime, backend: Component, worker_id: Uuid) -> Result<Self> {
+impl KvEventPublisher {
+    pub fn new(drt: DistributedRuntime, backend: Component, worker_id: i64) -> Result<Self> {
         let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>();
-        let p = KvPublisher { tx };
+        let p = KvEventPublisher { tx };
         start_publish_task(drt, backend, worker_id, rx);
         Ok(p)
...@@ -41,12 +41,10 @@ impl KvPublisher {
 fn start_publish_task(
     drt: DistributedRuntime,
     backend: Component,
-    worker_id: Uuid,
+    worker_id: i64,
     mut rx: mpsc::UnboundedReceiver<KvCacheEvent>,
 ) {
     let client = drt.nats_client().client().clone();
-    // [FIXME] service name is for metrics polling?
-    // let service_name = backend.service_name();
     let kv_subject = backend.event_subject(KV_EVENT_SUBJECT);
     log::info!("Publishing KV Events to subject: {}", kv_subject);
...@@ -61,3 +59,37 @@
         }
     });
 }
+pub struct KvMetricsPublisher {
+    tx: tokio::sync::watch::Sender<Arc<ForwardPassMetrics>>,
+    rx: tokio::sync::watch::Receiver<Arc<ForwardPassMetrics>>,
+}
+impl KvMetricsPublisher {
+    pub fn new() -> Result<Self> {
+        let (tx, rx) = tokio::sync::watch::channel(Arc::new(ForwardPassMetrics::default()));
+        Ok(KvMetricsPublisher { tx, rx })
+    }
+    pub fn publish(
+        &self,
+        metrics: Arc<ForwardPassMetrics>,
+    ) -> Result<(), tokio::sync::watch::error::SendError<Arc<ForwardPassMetrics>>> {
+        log::debug!("Publish metrics: {:?}", metrics);
+        self.tx.send(metrics)
+    }
+    pub async fn create_service(&self, component: Component) -> Result<()> {
+        let mut metrics_rx = self.rx.clone();
+        let _ = component
+            .service_builder()
+            .stats_handler(Some(Box::new(move |name, stats| {
+                log::debug!("[IN worker?] Stats for service {}: {:?}", name, stats);
+                let metrics = metrics_rx.borrow_and_update().clone();
+                serde_json::to_value(&*metrics).unwrap()
+            })))
+            .create()
+            .await?;
+        Ok(())
+    }
+}
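KvMetricsPublisher keeps only the latest ForwardPassMetrics in a tokio watch channel; each NATS stats poll reads the most recent snapshot rather than a queue of updates. The same last-write-wins idea in a few lines of Python, as an illustrative analogue rather than bindings code:

    from dataclasses import dataclass, asdict

    @dataclass
    class ForwardPassMetrics:
        request_active_slots: int = 0
        request_total_slots: int = 0
        kv_active_blocks: int = 0
        kv_total_blocks: int = 0

    class LatestMetrics:
        def __init__(self):
            self._snapshot = ForwardPassMetrics()

        def publish(self, snapshot: ForwardPassMetrics) -> None:
            self._snapshot = snapshot  # overwrite; older values are never queued

        def stats_handler(self) -> dict:
            return asdict(self._snapshot)  # what a poll would serialize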
...@@ -17,8 +17,6 @@ use serde::{Deserialize, Serialize};
 use std::borrow::BorrowMut;
 use std::cmp::min;
-use uuid::Uuid;
 use crate::kv_router::indexer::OverlapScores;
 pub use crate::kv_router::protocols::{ForwardPassMetrics, KV_BLOCK_SIZE};
 use crate::kv_router::scoring::ProcessedEndpoints;
...@@ -44,16 +42,17 @@ pub struct Endpoint {
 }
 impl Endpoint {
-    pub fn worker_id(&self) -> Uuid {
-        Uuid::parse_str(
+    pub fn worker_id(&self) -> i64 {
+        i64::from_str_radix(
             self.subject
-                .split(".")
+                .split("-")
                 .last()
                 .expect("invalid subject")
                 .to_string()
                 .as_str(),
+            16,
         )
-        .expect("invalid uuid")
+        .expect("invalid worker id")
     }
 }
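The new worker_id() recovers the lease id from the NATS subject by taking the text after the last '-' and parsing it as hexadecimal. A one-line Python equivalent; the sample subject layout is hypothetical:

    def worker_id_from_subject(subject: str) -> int:
        # e.g. "triton-init.vllm.generate-694d9981343d353c" (hypothetical layout)
        return int(subject.split("-")[-1], 16)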
...@@ -69,11 +68,11 @@ pub struct Service {
 pub struct SchedulingRequest {
     isl_tokens: usize,
     overlap: OverlapScores,
-    resp_tx: tokio::sync::oneshot::Sender<String>,
+    resp_tx: tokio::sync::oneshot::Sender<i64>,
 }
 impl SchedulingRequest {
-    pub fn respond(self, worker_id: String) {
+    pub fn respond(self, worker_id: i64) {
         if self.resp_tx.send(worker_id).is_err() {
             tracing::trace!("failed to send response to requestor");
         }
...@@ -174,7 +173,7 @@ impl KvScheduler {
         &self,
         overlap: OverlapScores,
         isl_tokens: usize,
-    ) -> Result<String, KvSchedulerError> {
+    ) -> Result<i64, KvSchedulerError> {
         let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
         let request = SchedulingRequest {
             isl_tokens,
...@@ -199,7 +198,7 @@ impl KvScheduler {
 pub fn select_worker(
     workers: &mut ProcessedEndpoints,
     request: &SchedulingRequest,
-) -> Result<String, KvSchedulerError> {
+) -> Result<i64, KvSchedulerError> {
     // balance mode prioritizes balancing load across workers
     let balance_threshold: f64 = 0.1;
     let balance_mode = workers.load_std > balance_threshold * workers.load_avg;
...@@ -227,6 +226,7 @@ pub fn select_worker(
         let kv_load_ratio = w.data.kv_active_blocks as f64 / w.data.kv_total_blocks as f64;
         let load_deviation = kv_load_ratio - workers.load_avg;
+        // [FIXME] multiple endpoints of the same worker cause an out-of-bounds error
         let worker_id = workers.worker_ids[i];
         let overlap_score = request.overlap.scores.get(&worker_id).map_or(0, |x| *x);
         let overlap_score = overlap_score as usize * KV_BLOCK_SIZE;
...@@ -267,10 +267,10 @@ pub fn select_worker(
         Some(i) => {
             tracing::info!(
                 "selected worker: {}; cost: {}",
-                workers.endpoints[i].subject,
+                workers.endpoints[i].worker_id(),
                 best_cost
             );
-            Ok(workers.endpoints[i].subject.clone())
+            Ok(workers.endpoints[i].worker_id())
         }
         None => {
             tracing::debug!("all workers busy");
...
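select_worker() weighs two visible signals per endpoint: the prefix-cache overlap (overlap blocks scaled by KV_BLOCK_SIZE into tokens) and the KV load deviation from the fleet average; balance mode kicks in when the load spread exceeds 10% of the mean. The hunks above do not show how these are combined into best_cost, so the final line below is an assumption, and the constant and names are illustrative:

    KV_BLOCK_SIZE = 16  # assumed; the real constant lives in kv_router::protocols

    def score(kv_active_blocks: int, kv_total_blocks: int, load_avg: float,
              overlap_blocks: int, isl_tokens: int) -> float:
        kv_load_ratio = kv_active_blocks / kv_total_blocks
        load_deviation = kv_load_ratio - load_avg        # > 0 means busier than average
        overlap_tokens = overlap_blocks * KV_BLOCK_SIZE  # cached prefix, in tokens
        # Hypothetical combination: reward overlap, penalize above-average load.
        return overlap_tokens / isl_tokens - load_deviation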
...@@ -18,12 +18,11 @@
 use std::collections::HashSet;
 use crate::kv_router::scheduler::Endpoint;
-use uuid::Uuid;
 #[derive(Debug, Default)]
 pub struct ProcessedEndpoints {
     pub endpoints: Vec<Endpoint>,
-    pub worker_ids: Vec<Uuid>,
+    pub worker_ids: Vec<i64>,
     pub load_avg: f64,
     pub load_std: f64,
 }
...@@ -43,8 +42,8 @@ impl ProcessedEndpoints {
         / load_values.len() as f64;
     let load_std = variance.sqrt();
-    let worker_ids: HashSet<Uuid> = endpoints.iter().map(|x| x.worker_id()).collect();
-    let worker_ids: Vec<Uuid> = worker_ids.into_iter().collect();
+    let worker_ids: HashSet<i64> = endpoints.iter().map(|x| x.worker_id()).collect();
+    let worker_ids: Vec<i64> = worker_ids.into_iter().collect();
     ProcessedEndpoints {
         endpoints,
...
...@@ -34,7 +34,6 @@ use std::time::SystemTime;
 use super::TokenIdType;
-pub mod kv_routing;
 pub mod llm_backend;
 pub mod postprocessor;
 pub mod preprocessor;
...
...@@ -74,9 +74,6 @@ addopts = [
     "--mypy",
     "--ignore-glob=*model.py",
     # FIXME: Get relative/generic blob paths to work here
-    # Ignore rust<->python bindings until python package is built/installed in environment
-    "--ignore-glob=/workspace/python-wheel/python/triton_distributed_rs/*.py",
-    "--ignore-glob=/workspace/python-wheel/python/triton_distributed_rs/*.pyi",
 ]
 xfail_strict = true
 log_cli_level = "INFO"
...