Commit 861c5098 authored by GuanLuo, committed by GitHub
parent eb022ec9
......@@ -181,23 +181,21 @@ index 1ca9e49d..b1591c0c 100644
# Reuse the cached content hash
diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py
index c5b3b04f..8a483aa2 100644
index c5b3b04f..c72001f7 100644
--- a/vllm/core/block_manager.py
+++ b/vllm/core/block_manager.py
@@ -9,10 +9,12 @@ from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator
from vllm.core.block.interfaces import Block
@@ -10,7 +10,10 @@ from vllm.core.block.interfaces import Block
from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker,
LastAccessBlocksTracker)
+from vllm.core.event_manager import KVCacheEventManager
from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec
+from vllm.core.event_manager import KVCacheEventManager
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
+from vllm.envs import (VLLM_KV_CAPI_PATH, VLLM_KV_COMPONENT, VLLM_KV_NAMESPACE,
+ VLLM_WORKER_ID)
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
from vllm.utils import Device
+from vllm.envs import VLLM_WORKER_ID, VLLM_KV_CAPI_PATH
SeqId = int
EncoderSeqId = str
@@ -60,6 +62,7 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
@@ -60,6 +63,7 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
def __init__(
self,
......@@ -205,12 +203,23 @@ index c5b3b04f..8a483aa2 100644
block_size: int,
num_gpu_blocks: int,
num_cpu_blocks: int,
@@ -91,11 +94,17 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
@@ -91,11 +95,28 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
self.watermark_blocks = int(watermark * num_gpu_blocks)
+ if VLLM_WORKER_ID is not None and VLLM_KV_CAPI_PATH is not None:
+ self.event_manager = KVCacheEventManager(model_name, worker_id=str(VLLM_WORKER_ID).encode(), lib_path=VLLM_KV_CAPI_PATH)
+ kv_event_manager_params = [
+ VLLM_WORKER_ID, VLLM_KV_CAPI_PATH, VLLM_KV_NAMESPACE,
+ VLLM_KV_COMPONENT
+ ]
+ set_kv_event_manager_params = len(
+ [param for param in kv_event_manager_params if param is not None])
+
+ if set_kv_event_manager_params == len(kv_event_manager_params):
+ self.event_manager = KVCacheEventManager(
+ namespace=VLLM_KV_NAMESPACE,
+ component=VLLM_KV_COMPONENT,
+ worker_id=VLLM_WORKER_ID,
+ lib_path=VLLM_KV_CAPI_PATH)
+ else:
+ self.event_manager = None
+
......@@ -225,36 +234,44 @@ index c5b3b04f..8a483aa2 100644
self.block_tables: Dict[SeqId, BlockTable] = {}
diff --git a/vllm/core/event_manager.py b/vllm/core/event_manager.py
new file mode 100644
index 00000000..4aa90a4a
index 00000000..350453cd
--- /dev/null
+++ b/vllm/core/event_manager.py
@@ -0,0 +1,89 @@
+from typing import Optional
+import logging
+from vllm.core.block.prefix_caching_block import PrefixCachingBlock, PrefixHash
+
@@ -0,0 +1,102 @@
+# SPDX-License-Identifier: Apache-2.0
+import ctypes
+from ctypes import c_char_p, c_uint32, c_void_p, c_size_t
+import logging
+import uuid
+from ctypes import c_char_p, c_size_t, c_uint32, c_void_p, c_int64
+from typing import Optional
+
+from vllm.core.block.prefix_caching_block import PrefixCachingBlock, PrefixHash
+
+logger = logging.getLogger(__name__)
+
+
+class TritonResult:
+ OK = 0
+ ERR = 1
+
+
+class KVCacheEventManager:
+ def __init__(self, model_name: str, worker_id: bytes, lib_path: str):
+
+ def __init__(self, namespace: str, component: str, worker_id: int,
+ lib_path: str):
+ self.lib = None
+
+ try:
+ self.lib = ctypes.CDLL(lib_path)
+ self.lib.triton_llm_init.argtypes = [c_char_p, c_char_p]
+ self.lib.triton_llm_init.argtypes = [c_char_p, c_char_p, c_int64]
+ self.lib.triton_llm_init.restype = c_uint32
+
+ result = self.lib.triton_llm_init(model_name.encode(), worker_id)
+ result = self.lib.triton_llm_init(namespace.encode(),
+ component.encode(), worker_id)
+ if result == TritonResult.OK:
+ logger.info("KVCacheEventManager initialized successfully. Ready to publish KV Cache Events")
+ logger.info(
+ "KVCacheEventManager initialized successfully. Ready to publish KV Cache Events"
+ )
+ else:
+ logger.info("KVCacheEventManager initialization failed!")
+
......@@ -282,11 +299,14 @@ index 00000000..4aa90a4a
+
+ self.event_id_counter = 0
+
+ def enqueue_stored_event(self, parent: Optional[PrefixCachingBlock], block: PrefixCachingBlock):
+ token_ids_arr = (ctypes.c_uint32 * len(block.token_ids))(*block.token_ids)
+ def enqueue_stored_event(self, parent: Optional[PrefixCachingBlock],
+ block: PrefixCachingBlock):
+ token_ids_arr = (ctypes.c_uint32 *
+ len(block.token_ids))(*block.token_ids)
+ num_block_tokens = (ctypes.c_size_t * 1)(len(block.token_ids))
+ block_hash = (ctypes.c_uint64 * 1)(block.content_hash)
+ parent_hash = ((ctypes.c_uint64 * 1)(parent.content_hash) if parent is not None else None)
+ parent_hash = ((ctypes.c_uint64 * 1)(parent.content_hash)
+ if parent is not None else None)
+
+ # Publish the event
+ result = self.lib.triton_kv_event_publish_stored(
......@@ -302,7 +322,8 @@ index 00000000..4aa90a4a
+ if result == TritonResult.OK:
+ logger.debug(f"Store - Published KV Event: {block.content_hash}")
+ else:
+ logger.debug(f"Store - Failed to Publish KV Event: {block.content_hash}")
+ logger.debug(
+ f"Store - Failed to Publish KV Event: {block.content_hash}")
+
+ self.event_id_counter += 1
+
......@@ -310,7 +331,8 @@ index 00000000..4aa90a4a
+ result = self.lib.triton_kv_event_publish_removed(
+ self.event_id_counter,
+ (ctypes.c_uint64 * 1)(block_hash),
+ 1,)
+ 1,
+ )
+
+ if result == TritonResult.OK:
+ logger.debug(f"Remove - Published KV Event: {block_hash}")
......@@ -318,7 +340,6 @@ index 00000000..4aa90a4a
+ logger.debug(f"Remove - Failed to Publish KV Event: {block_hash}")
+
+ self.event_id_counter += 1
\ No newline at end of file
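For orientation, a minimal usage sketch of the event manager added above. The library path, namespace, component, and worker id mirror the mock-engine example later in this commit; the `SimpleNamespace` block is a stand-in for a real `PrefixCachingBlock`:

```python
from types import SimpleNamespace

from vllm.core.event_manager import KVCacheEventManager

# Identifiers and library path are illustrative (taken from the mock-engine
# example in this commit); worker_id is typically an endpoint lease id.
manager = KVCacheEventManager(
    namespace="triton-init",
    component="vllm",
    worker_id=4242,
    lib_path="/opt/triton/llm_binding/lib/libtriton_llm_capi.so")

# Publish a "stored" event for a freshly cached block with no parent block.
block = SimpleNamespace(token_ids=[3] * 64, content_hash=0xDEADBEEF)
manager.enqueue_stored_event(parent=None, block=block)
```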
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index f507847a..6af77646 100644
--- a/vllm/core/scheduler.py
......@@ -356,7 +377,7 @@ index f507847a..6af77646 100644
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks,
diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py
index fe480533..61a357d0 100644
index fe480533..b768e03c 100644
--- a/vllm/distributed/kv_transfer/kv_connector/factory.py
+++ b/vllm/distributed/kv_transfer/kv_connector/factory.py
@@ -27,13 +27,13 @@ class KVConnectorFactory:
......@@ -375,16 +396,6 @@ index fe480533..61a357d0 100644
# Register various connectors here.
@@ -48,3 +48,8 @@ KVConnectorFactory.register_connector(
"MooncakeConnector",
"vllm.distributed.kv_transfer.kv_connector.simple_connector",
"SimpleConnector")
+
+KVConnectorFactory.register_connector(
+ "TritonNcclConnector",
+ "vllm.distributed.kv_transfer.kv_connector.triton_connector",
+ "TritonConnector")
\ No newline at end of file
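With the registration above, the connector can be selected through vLLM's KV-transfer configuration, which the factory resolves to the `TritonConnector` class. A hedged sketch, assuming the `KVTransferConfig.from_cli` helper of this vLLM generation; the role and sizes are illustrative:

```python
from vllm.config import KVTransferConfig

# Select the newly registered connector for a prefill (producer) instance.
kv_config = KVTransferConfig.from_cli(
    '{"kv_connector": "TritonNcclConnector", "kv_role": "kv_producer", '
    '"kv_rank": 0, "kv_parallel_size": 2}')
```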
diff --git a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py
index 2033e976..e33919c1 100644
--- a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py
......@@ -724,363 +735,6 @@ index 2033e976..e33919c1 100644
+ self.config.kv_consumers_pipeline_parallel_size = kv_config_enhanced["kv_consumers_pipeline_parallel_size"]
+ self.config.kv_producers_parallel_size = kv_config_enhanced["kv_producers_parallel_size"]
\ No newline at end of file
diff --git a/vllm/distributed/kv_transfer/kv_connector/triton_connector.py b/vllm/distributed/kv_transfer/kv_connector/triton_connector.py
new file mode 100644
index 00000000..cb3b3660
--- /dev/null
+++ b/vllm/distributed/kv_transfer/kv_connector/triton_connector.py
@@ -0,0 +1,350 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Triton KV Cache Connector for Distributed Machine Learning Inference
+
+The TritonConnector transfers KV caches between the prefill vLLM worker (KV
+cache producer) and the decode vLLM worker (KV cache consumer) using a
+PyNcclPipe wrapped by a TritonNcclDataPlane.
+
+The logic can be extended to support other pipes and lookup buffers.
+"""
+import re
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+
+import torch
+
+from vllm import _custom_ops as ops
+from vllm.config import VllmConfig, KVTransferConfig
+from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
+from vllm.distributed.utils import StatelessProcessGroup
+from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import (
+ SimpleBuffer)
+from vllm.logger import init_logger
+from vllm.sequence import IntermediateTensors
+
+if TYPE_CHECKING:
+ from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
+
+logger = init_logger(__name__)
+
+
+class TritonConnector(KVConnectorBase):
+
+ def __init__(
+ self,
+ rank: int,
+ local_rank: int,
+ config: VllmConfig,
+ world_group,
+ ):
+
+ self.config = config.kv_transfer_config
+ self.tp_size = config.parallel_config.tensor_parallel_size
+ self.rank = rank
+
+ if self.config.kv_connector != "TritonNcclConnector":
+ raise NotImplementedError("Only TritonNcclConnector is supported by the TritonConnector class")
+
+ from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import (
+ PyNcclPipe)
+ from vllm.distributed.kv_transfer.kv_pipe.triton_nccl_pipe import (
+ TritonNcclDataPlane)
+
+ logger.info(
+ "Initializing TritonNcclConnector under kv_transfer_config %s",
+ self.config)
+
+ self.lookup_buffer_size = self.config.kv_buffer_size
+
+ self.producer_data_pipe: PyNcclPipe
+ self.consumer_data_pipe: PyNcclPipe
+ self.producer_signal_pipe: PyNcclPipe
+ self.consumer_signal_pipe: PyNcclPipe
+
+ self._broadcast_and_enhance_kv_config(rank, config, world_group)
+
+ self.kv_group_rank = self._get_kv_group_rank(self.config.kv_rank, rank, self.config)
+ self.tp_size = config.parallel_config.tensor_parallel_size
+
+ # 2 pipes for every rank in the world
+ if self.config.is_kv_producer:
+ port_offset_base = rank + 1
+ else:
+ port_offset_base = rank // self.config.tensor_parallel_multiplier + 1
+
+ self.local_kv_rank = rank % self.config.tensor_parallel_multiplier
+ self.global_kv_rank = self._get_global_kv_rank(self.config.kv_rank, rank, self.config)
+
+ self.data_pipe = PyNcclPipe(
+ kv_group_rank=self.kv_group_rank,
+ local_rank=local_rank,
+ config=self.config,
+ port_offset=port_offset_base,
+ )
+
+ self.data_plane = TritonNcclDataPlane(
+ data_pipe=self.data_pipe,
+ port=self._get_data_plane_port(self.global_kv_rank),
+ )
+
+ def send_kv_caches_and_hidden_states(
+ self,
+ model_executable: torch.nn.Module,
+ model_input: "ModelInputForGPUWithSamplingMetadata",
+ kv_caches: List[torch.Tensor],
+ hidden_or_intermediate_states: Union[torch.Tensor,
+ IntermediateTensors],
+ ) -> None:
+
+ input_tokens_tensor = model_input.input_tokens
+ seq_lens = model_input.attn_metadata.seq_lens
+ slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
+ start_layer = model_executable.model.start_layer
+ end_layer = model_executable.model.end_layer
+ request_ids = list(model_input.request_ids_to_seq_ids.keys())
+
+ model_config = model_executable.model.config
+ is_deepseek = "deepseek" in model_config.architectures[0].lower()
+ if not is_deepseek:
+ num_heads = int(model_config.num_key_value_heads / self.tp_size)
+ hidden_size = model_config.hidden_size
+ num_attention_heads = model_config.num_attention_heads
+ head_size = int(hidden_size / num_attention_heads)
+ else:
+ num_heads = int(model_config.num_key_value_heads / self.tp_size)
+ hidden_size = model_config.hidden_size
+ num_attention_heads = model_config.num_attention_heads
+ head_size = int(4.5 * hidden_size / num_attention_heads)
+
+ # query_lens contains new KV caches that are added to vLLM.
+ # so we will send them to decode instance
+        # FIXME(Kuntai): This assumes that all requests are prefill.
+ for idx, slen in enumerate(seq_lens):
+ start_pos = sum(seq_lens[:idx])
+ end_pos = start_pos + slen
+ current_tokens = input_tokens_tensor[start_pos:end_pos]
+ current_request_id = request_ids[idx]
+ decode_hostname, decode_kv_rank = self.parse_request_id(current_request_id)
+ decode_first_global_rank = self._get_global_kv_rank(decode_kv_rank, self.rank * self.config.tensor_parallel_multiplier, self.config)
+
+ for target_rank in range(self.config.tensor_parallel_multiplier):
+
+ keys, values = [], []
+
+ for layer_id in range(start_layer, end_layer):
+ kv_cache = kv_caches[layer_id - start_layer]
+
+ current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
+
+ num_heads_per_rank = num_heads // self.config.tensor_parallel_multiplier
+ head_start = target_rank * num_heads_per_rank
+ head_end = head_start + num_heads_per_rank
+
+ if not is_deepseek:
+ key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
+ value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
+ keys.append(key_cache[current_slot_mapping, head_start:head_end].unsqueeze(0))
+ values.append(value_cache[current_slot_mapping, head_start:head_end].unsqueeze(0))
+ else:
+ key_cache = kv_cache
+ keys.append(key_cache[current_slot_mapping].unsqueeze(0))
+ values.append(torch.empty(0))
+
+ keys = torch.cat(keys, dim=0)
+ values = torch.cat(values, dim=0)
+
+ decode_global_rank = decode_first_global_rank + target_rank
+ decode_port = self._get_data_plane_port(decode_global_rank)
+ partial_hidden_or_intermediate_states = hidden_or_intermediate_states[start_pos:end_pos]
+ self._send(decode_hostname, decode_port, current_request_id, keys, values,
+ partial_hidden_or_intermediate_states)
+
+ logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank())
+
+ def recv_kv_caches_and_hidden_states(
+ self, model_executable: torch.nn.Module,
+ model_input: "ModelInputForGPUWithSamplingMetadata",
+ kv_caches: List[torch.Tensor]
+ ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
+ "ModelInputForGPUWithSamplingMetadata"]:
+
+ # When bypass_model_exec is set to False, it means that at least for one
+ # request its corresponding KV cache or hidden state is missing.
+ # In this case we need to do prefilling to recompute missing KV cache
+ # and hidden states.
+ bypass_model_exec = True
+
+ input_tokens_tensor = model_input.input_tokens
+ seq_lens = model_input.attn_metadata.seq_lens
+ slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
+ request_ids = list(model_input.request_ids_to_seq_ids.keys())
+
+ hidden_or_intermediate_states_for_one_req = []
+
+ input_tokens_list = []
+ start_pos_list = []
+
+ model_config = model_executable.model.config
+ is_deepseek = "deepseek" in model_config.architectures[0].lower()
+
+ # enumerate different requests
+ # FIXME(Kuntai): This impl assumes that all requests are prefill.
+ for idx, slen in enumerate(seq_lens):
+
+ start_pos = sum(seq_lens[:idx])
+ end_pos = start_pos + slen
+ current_tokens = input_tokens_tensor[start_pos:end_pos]
+ current_request_id = request_ids[idx]
+ num_tokens = slen
+
+ # collecting data for rebuilding the input
+ input_tokens_list.append(current_tokens)
+ start_pos_list.append(start_pos)
+
+ ret = self._recv(current_request_id)
+ keys: torch.Tensor = ret[0]
+ values: torch.Tensor = ret[1]
+ hidden: torch.Tensor = ret[2]
+
+ # put received KV caches into paged memory
+ for i in range(model_executable.model.start_layer,
+ model_executable.model.end_layer):
+
+ kv_cache = kv_caches[i - model_executable.model.start_layer]
+ layer = model_executable.model.layers[i]
+
+ if not is_deepseek:
+ key_cache, value_cache = kv_cache[0], kv_cache[1]
+ ops.reshape_and_cache_flash(
+ keys[i - model_executable.model.start_layer].to(
+ key_cache.device),
+ values[i - model_executable.model.start_layer].to(
+ value_cache.device),
+ key_cache,
+ value_cache,
+ slot_mapping[start_pos:end_pos],
+ layer.self_attn.attn.kv_cache_dtype,
+ layer.self_attn.attn._k_scale,
+ layer.self_attn.attn._v_scale,
+ )
+ else:
+ key_cache = kv_cache
+                copy_from = keys[i - model_executable.model.start_layer].to(
+ key_cache.device)
+ kv_cache[slot_mapping[start_pos:end_pos]] = copy_from
+
+ hidden_or_intermediate_states_for_one_req.append(hidden)
+
+ if not bypass_model_exec:
+ # Some of the KV cache is not retrieved
+ # Here we will fall back to normal model forwarding
+ # But optionally you can adjust model_input so that you only do
+ # prefilling on those tokens that are missing KV caches.
+ logger.debug(
+ "[rank%d]: Failed to receive all KVs and hidden "
+ "states, redo model forwarding.", torch.distributed.get_rank())
+ hidden_or_intermediate_states = None
+
+ else:
+ logger.debug(
+ "[rank%d]: Successfully received all KVs and hidden "
+ "states, skip model forwarding.", torch.distributed.get_rank())
+ hidden_or_intermediate_states = torch.cat(
+ hidden_or_intermediate_states_for_one_req, dim=0)
+
+ return hidden_or_intermediate_states, bypass_model_exec, model_input
+
+ def close(self):
+ self.data_pipe.close()
+ # self.data_plane.close()
+
+ @staticmethod
+ def parse_request_id(request_id: str) -> Tuple[str, int]:
+ # Regular expression to match the string hostname and integer decode_kv_rank
+ pattern = r"___decode_hostname_(.*)___decode_kv_rank_(\d+)"
+
+ # Use re.search to find the pattern in the request_id
+ match = re.search(pattern, request_id)
+ if match:
+ # Extract the ranks
+ decode_hostname = match.group(1)
+ decode_rank = int(match.group(2))
+
+ return decode_hostname, decode_rank
+ raise ValueError(f"Request id {request_id} does not contain hostname and decode_kv_rank")
+
+ def _send(self, hostname: str, port: int, request_id: str, keys: torch.Tensor, values: torch.Tensor, hidden: torch.Tensor):
+ remote_address = f"{hostname}:{port}"
+ self.data_plane.send_tensor(keys, f"{request_id}_keys", remote_address)
+ self.data_plane.send_tensor(values, f"{request_id}_values", remote_address)
+ self.data_plane.send_tensor(hidden, f"{request_id}_hidden", remote_address)
+
+ def _recv(self, request_id: str) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ keys = self.data_plane.recv_tensor(f"{request_id}_keys")
+ values = self.data_plane.recv_tensor(f"{request_id}_values")
+ hidden = self.data_plane.recv_tensor(f"{request_id}_hidden")
+ return keys, values, hidden
+
+ def _get_kv_group_rank(self, kv_rank: int, rank: int, config: KVTransferConfig) -> int:
+ if kv_rank < config.kv_producers_parallel_size:
+ return kv_rank
+
+ kv_consumer_rank = kv_rank - config.kv_producers_parallel_size
+ return config.kv_producers_parallel_size + kv_consumer_rank * config.tensor_parallel_multiplier + rank % config.tensor_parallel_multiplier
+
+
+ def _get_global_kv_rank(self, kv_rank: int, rank: int, config: KVTransferConfig) -> int:
+ if kv_rank <= config.kv_producers_parallel_size:
+ return kv_rank * config.kv_producers_tensor_parallel_size + rank
+
+ kv_consumer_rank = kv_rank - config.kv_producers_parallel_size
+ return config.kv_producers_parallel_size * config.kv_producers_tensor_parallel_size + kv_consumer_rank * config.kv_consumers_tensor_parallel_size + rank
+
+
+ def _get_data_plane_port(self, global_kv_rank: int) -> int:
+ return self.config.kv_port + self.config.kv_producers_tensor_parallel_size + 1 + global_kv_rank
+
+ def _broadcast_and_enhance_kv_config(self, rank: int, config: VllmConfig, world_group):
+ if rank == 0:
+ config_group = StatelessProcessGroup.create(
+ host=self.config.kv_ip,
+ port=self.config.kv_port,
+ rank=self.config.kv_rank,
+ world_size=self.config.kv_parallel_size,
+ )
+ parallel_configs = config_group.all_gather_obj({
+ "kv_role": self.config.kv_role,
+ "tensor_parallel_size": config.parallel_config.tensor_parallel_size,
+ "pipeline_parallel_size": config.parallel_config.pipeline_parallel_size,
+ })
+ logger.debug("parallel_configs: %s", parallel_configs)
+ kv_config_enhanced = {
+ "kv_producers_tensor_parallel_size": None,
+ "kv_consumers_tensor_parallel_size": None,
+ "kv_producers_pipeline_parallel_size": None,
+ "kv_consumers_pipeline_parallel_size": None,
+ "kv_producers_parallel_size": 0,
+ }
+ for parallel_config in parallel_configs:
+ kv_role = parallel_config["kv_role"]
+            assert parallel_config["pipeline_parallel_size"] == 1, "Only pipeline parallel size 1 is supported for kv transfer instances"
+
+ if kv_role == "kv_producer":
+ kv_config_enhanced["kv_producers_parallel_size"] += 1
+ if kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] is None:
+ kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] = parallel_config["tensor_parallel_size"]
+ kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] = parallel_config["pipeline_parallel_size"]
+ else:
+ assert kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] == parallel_config["tensor_parallel_size"], f"All kv {kv_role}s should have the same tensor parallel size"
+ assert kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] == parallel_config["pipeline_parallel_size"], f"All kv {kv_role}s should have the same pipeline parallel size"
+ world_group.broadcast_object(kv_config_enhanced)
+ else:
+ kv_config_enhanced = world_group.broadcast_object()
+ logger.info("kv_config_enhanced: %s", kv_config_enhanced)
+
+ self.config.kv_producers_tensor_parallel_size = kv_config_enhanced["kv_producers_tensor_parallel_size"]
+ self.config.kv_consumers_tensor_parallel_size = kv_config_enhanced["kv_consumers_tensor_parallel_size"]
+ self.config.kv_producers_pipeline_parallel_size = kv_config_enhanced["kv_producers_pipeline_parallel_size"]
+ self.config.kv_consumers_pipeline_parallel_size = kv_config_enhanced["kv_consumers_pipeline_parallel_size"]
+ self.config.kv_producers_parallel_size = kv_config_enhanced["kv_producers_parallel_size"]
\ No newline at end of file
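The connector relies on a request-id naming convention to route KV caches to decode workers. A small sketch of that convention, matching the regex in `parse_request_id()` above (the base id and values are illustrative):

```python
from vllm.distributed.kv_transfer.kv_connector.triton_connector import (
    TritonConnector)

base_id = "chatcmpl-123"  # hypothetical upstream request id
request_id = f"{base_id}___decode_hostname_decode-node-0___decode_kv_rank_1"

hostname, kv_rank = TritonConnector.parse_request_id(request_id)
assert (hostname, kv_rank) == ("decode-node-0", 1)
```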
diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py
index 5e1b6235..b4506877 100644
--- a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py
......@@ -1361,7 +1015,7 @@ index 40589fb3..da2829cf 100644
Returns:
diff --git a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py
index 7aa53d07..f5dd50b7 100644
index 7aa53d07..db10f8a0 100644
--- a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py
+++ b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py
@@ -45,33 +45,33 @@ class PyNcclPipe(KVPipeBase):
......@@ -1496,16 +1150,13 @@ index 7aa53d07..f5dd50b7 100644
if self.transport_thread is None:
self.transport_thread = ThreadPoolExecutor(max_workers=1)
@@ -241,32 +243,39 @@ class PyNcclPipe(KVPipeBase):
with self.buffer_size_lock:
@@ -242,19 +244,23 @@ class PyNcclPipe(KVPipeBase):
self.buffer_size += tensor_size
- self.transport_thread.submit(self.send_tensor_wrapper, tensor,
self.transport_thread.submit(self.send_tensor_wrapper, tensor,
- tensor_size)
+ future = self.transport_thread.submit(self.send_tensor_wrapper, tensor,
+ tensor_size,
+ target_rank)
+ return future
- def recv_tensor(self) -> Optional[torch.Tensor]:
+ def recv_tensor(self, src_rank: int) -> Optional[torch.Tensor]:
......@@ -1524,162 +1175,8 @@ index 7aa53d07..f5dd50b7 100644
- future = self.transport_thread.submit(self._recv_impl)
+ future = self.transport_thread.submit(self._recv_impl, src_rank)
- try:
- tensor = future.result()
- except Exception as e:
- logger.error("Encountering exception in KV receiving thread")
- logger.error("%s", e)
- logger.error("My device: %s", self.device)
- import traceback
- traceback.print_exc()
- raise e
+ return future
+
+ # try:
+ # tensor = future.result()
+ # except Exception as e:
+ # logger.error("Encountering exception in KV receiving thread")
+ # logger.error("%s", e)
+ # logger.error("My device: %s", self.device)
+ # import traceback
+ # traceback.print_exc()
+ # raise e
- return tensor
+ # return tensor
def close(self):
"""
diff --git a/vllm/distributed/kv_transfer/kv_pipe/triton_nccl_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/triton_nccl_pipe.py
new file mode 100644
index 00000000..8a356504
--- /dev/null
+++ b/vllm/distributed/kv_transfer/kv_pipe/triton_nccl_pipe.py
@@ -0,0 +1,124 @@
+import logging
+import threading
+import typing
+import zmq
+import socket
+import time
+import torch
+
+from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe
+
+
+logger = logging.getLogger(__name__)
+
+
+class TritonNcclDataPlane:
+ def __init__(
+ self,
+ data_pipe: PyNcclPipe,
+ hostname: str = "",
+ port: int = 0,
+ ) -> None:
+
+ self.data_pipe = data_pipe
+ if not hostname:
+ hostname = socket.gethostname()
+ if port == 0:
+ raise ValueError("Port cannot be 0")
+ self._hostname = hostname
+ self._port = port
+ self.store = {}
+ self.context = zmq.Context()
+ self.rep_socket = self.context.socket(zmq.REP)
+ logger.info(f"Rank {self.rank} binding to {self._hostname}:{self._port}")
+ self.rep_socket.bind(f"tcp://{self._hostname}:{self._port}")
+ self._listener_thread = threading.Thread(target=self.listen_for_requests, daemon=True)
+ self._listener_thread.start()
+ self.req_sockets = {}
+ logger.info(f"Rank {self.rank} connected to the server")
+
+ @property
+ def rank(self):
+ return self.data_pipe.kv_group_rank
+
+ def send_tensor(
+ self,
+ tensor: torch.Tensor,
+ tensor_id: str,
+ remote_address: typing.Optional[str] = None,
+ ):
+ logger.debug(f"Rank {self.rank} sending tensor {tensor_id} to {remote_address}")
+ return self._send_tensor(tensor, tensor_id, remote_address)
+
+ def recv_tensor(
+ self,
+ tensor_id: str,
+ remote_address: typing.Optional[str] = None,
+ ) -> torch.Tensor:
+ ret = self._recv_tensor(tensor_id, remote_address)
+ return ret
+
+ def _send_tensor(
+ self,
+ tensor: torch.Tensor,
+ tensor_id: str,
+ remote_address: typing.Optional[str] = None,
+ ):
+ logger.debug(f"Rank {self.rank} storing tensor with id {tensor_id} of shape {tensor.shape} and dtype {tensor.dtype}")
+ if remote_address is None:
+ self.store[tensor_id] = tensor
+ else:
+ # tensor_shape = "_".join(str(dim) for dim in tensor.shape)
+ # tensor_dtype = str(tensor.dtype)
+ if remote_address not in self.req_sockets:
+ self.req_sockets[remote_address] = self.context.socket(zmq.REQ)
+ self.req_sockets[remote_address].connect(f"tcp://{remote_address}")
+
+ req_socket = self.req_sockets[remote_address]
+ # req_socket.connect(f"tcp://{remote_address}")
+ req_socket.send_string(f"PUT {self.rank} {tensor_id}")
+ dst_rank = req_socket.recv_string()
+ logger.debug(f"Rank {self.rank} sending tensor {tensor_id} to rank {dst_rank}")
+ self.data_pipe.send_tensor(tensor, int(dst_rank))
+
+ def _recv_tensor(
+ self,
+ tensor_id: str,
+ remote_address: typing.Optional[str] = None,
+ ) -> torch.Tensor:
+ logger.debug(f"Rank {self.rank} receiving tensor")
+ if remote_address is not None:
+ raise NotImplementedError("Getting tensor from remote rank not implemented")
+ if tensor_id in self.store:
+ logger.debug(f"Popping tensor {tensor_id} from store")
+ future = self.store.pop(tensor_id)
+            # TODO(ptarasiewicz): serve other requests instead of waiting
+            tensor = future.result()
+ logger.debug(f"Rank {self.rank} received tensor")
+ return tensor
+
+ logger.debug(f"Rank {self.rank} waiting for tensor {tensor_id}")
+ time.sleep(0.001)
+ return self._recv_tensor(tensor_id, remote_address)
+ # raise NotImplementedError("Tensor not found in store")
+
+ def _receive_tensor(
+ self,
+ tensor_id: str,
+ rank: int,
+ ):
+ future = self.data_pipe.recv_tensor(rank)
+ logger.debug(f"Rank {self.rank} storing tensor {tensor_id} in store")
+ self.store[tensor_id] = future
+
+ def listen_for_requests(self):
+ while True:
+ cmd, rank, tensor_id = self.rep_socket.recv_string().split()
+ logger.debug(f"Rank {self.rank} received request for tensor {tensor_id}")
+ self.rep_socket.send_string(f"{self.rank}")
+ if cmd == "GET":
+ raise NotImplementedError("Getting tensor from remote rank not implemented")
+ elif cmd == "PUT":
+ rank = int(rank)
+ # shape = [int(dim) for dim in shape.split("_")]
+ # dtype = getattr(torch, dtype)
+ self._receive_tensor(tensor_id, rank)
try:
tensor = future.result()
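A hedged sketch of how the data plane above is driven, mirroring `TritonConnector._send()`/`_recv()`. Constructing the underlying `PyNcclPipe` requires a full `KVTransferConfig` and is elided here; the address and port are illustrative:

```python
import torch

from vllm.distributed.kv_transfer.kv_pipe.triton_nccl_pipe import (
    TritonNcclDataPlane)

# `pipe` is assumed to be an already-initialized PyNcclPipe.
plane = TritonNcclDataPlane(data_pipe=pipe, port=14600)

# Sender side: a ZMQ "PUT" handshake tells the receiver which NCCL rank to
# expect the tensor from, then the tensor travels over the NCCL pipe.
plane.send_tensor(torch.randn(2, 4), "req-0_keys",
                  remote_address="decode-node-0:14600")

# Receiver side: polls its local store (1 ms sleep) until the tensor arrives.
keys = plane.recv_tensor("req-0_keys")
```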
diff --git a/vllm/distributed/kv_transfer/kv_transfer_agent.py b/vllm/distributed/kv_transfer/kv_transfer_agent.py
index 1e80e0bd..cd90206f 100644
--- a/vllm/distributed/kv_transfer/kv_transfer_agent.py
......@@ -1728,20 +1225,255 @@ index d82d9ad9..542ccfe8 100644
self.parallel_config.pipeline_parallel_size,
self.async_callbacks[v_id]
if self.model_config.use_async_output_proc else None)
diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py
index 3cf1850e..38acca0e 100644
--- a/vllm/engine/multiprocessing/__init__.py
+++ b/vllm/engine/multiprocessing/__init__.py
@@ -21,6 +21,7 @@ IPC_INPUT_EXT = "_input_socket"
IPC_OUTPUT_EXT = "_output_socket"
IPC_HEALTH_EXT = "_health_socket"
IPC_DATA_EXT = "_data_socket"
+IPC_METRICS_EXT = "_metrics_socket"
class MQEngineDeadError(RuntimeError):
@@ -157,3 +158,10 @@ def ENGINE_DEAD_ERROR(
return MQEngineDeadError(
"Engine loop is not running. Inspect the stacktrace to "
f"find the original error: {repr(error)}.")
+
+@dataclass
+class KvMetrics:
+ request_active_slots: int
+ request_total_slots: int
+ kv_active_blocks: int
+ kv_total_blocks: int
diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py
index 85b5f31e..6a7ea3ae 100644
--- a/vllm/engine/multiprocessing/client.py
+++ b/vllm/engine/multiprocessing/client.py
@@ -25,14 +25,15 @@ from vllm.engine.async_llm_engine import (
build_guided_decoding_logits_processor_async)
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT, IPC_INPUT_EXT,
- IPC_OUTPUT_EXT, RPC_REQUEST_T,
+ IPC_OUTPUT_EXT, IPC_METRICS_EXT,
+ RPC_REQUEST_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCAdapterLoadedResponse, RPCError,
RPCLoadAdapterRequest,
RPCProcessRequest,
RPCResetPrefixCacheRequest,
RPCStartupRequest, RPCStartupResponse,
- RPCUProfileRequest)
+ RPCUProfileRequest, KvMetrics)
from vllm.engine.protocol import EngineClient
# yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT
@@ -115,6 +116,10 @@ class MQLLMEngineClient(EngineClient):
self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL)
self.heartbeat_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}")
+ # Metrics.
+ self.metrics_socket: Socket = self.context.socket(zmq.constants.PULL)
+ self.metrics_socket.connect(f"{ipc_path}{IPC_METRICS_EXT}")
+
# IPC path for the data socket.
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"
@@ -129,6 +134,12 @@ class MQLLMEngineClient(EngineClient):
# Loop to check health of the LLMEngine periodically.
# Started after the MQLLMEngine is ready.
self.health_loop: Optional[asyncio.Task] = None
+
+ # Loop to check metrics of the LLMEngine periodically.
+ # Started after the MQLLMEngine is ready.
+ self.metrics_loop: Optional[asyncio.Task] = None
+ self.metrics_publisher = None
+
self._engine_process = psutil.Process(engine_pid)
@staticmethod
@@ -180,6 +191,46 @@ class MQLLMEngineClient(EngineClient):
except Exception as e:
self._set_errored(e)
+ async def run_metrics_loop(self, timeout: int):
+        """Background loop that checks that the engine process is alive and
+        forwards KV metrics received from it to the registered metrics
+        publisher.
+        """
+ try:
+ while True:
+ # Check if the engine process is running:
+ if not self._engine_process.is_running() or (
+ self._engine_process.status() == psutil.STATUS_ZOMBIE):
+ # NB: is_running() returns True for zombies
+ self._set_errored(
+ RuntimeError(
+ f"Engine process (pid {self._engine_process.pid}) "
+ "died."))
+ break
+
+ if await self.metrics_socket.poll(timeout=timeout):
+                    # Metrics received - check the message
+ message: Frame = await self.metrics_socket.recv(copy=False)
+ kv_metrics = pickle.loads(message.buffer)
+ if self.metrics_publisher is not None:
+ if isinstance(kv_metrics, KvMetrics):
+ self.metrics_publisher.publish(kv_metrics.request_active_slots,
+ kv_metrics.request_total_slots,
+ kv_metrics.kv_active_blocks,
+ kv_metrics.kv_total_blocks)
+
+                    logger.debug("Metrics update successful.")
+
+ except asyncio.CancelledError:
+ logger.debug("Shutting down MQLLMEngineClient check metrics loop.")
+
+ except psutil.NoSuchProcess:
+ self._set_errored(
+ RuntimeError(
+ f"Engine process (pid {self._engine_process.pid}) died."))
+
+ except Exception as e:
+ self._set_errored(e)
+
async def run_output_handler_loop(self):
"""Get RequestOutputs from Engine and stream to Request Queues"""
@@ -284,6 +335,12 @@ class MQLLMEngineClient(EngineClient):
if self.health_loop is None:
self.health_loop = asyncio.create_task(
self.run_heartbeat_loop(timeout=VLLM_RPC_TIMEOUT))
+
+ # Start metrics_loop.
+ if self.metrics_loop is None:
+ self.metrics_loop = asyncio.create_task(
+ self.run_metrics_loop(timeout=VLLM_RPC_TIMEOUT))
+
def close(self):
"""Destroy the ZeroMQ Context."""
@@ -293,6 +350,8 @@ class MQLLMEngineClient(EngineClient):
# Cancel background tasks.
if self.health_loop is not None:
self.health_loop.cancel()
+ if self.metrics_loop is not None:
+ self.metrics_loop.cancel()
if self.output_loop is not None:
self.output_loop.cancel()
@@ -705,3 +764,6 @@ class MQLLMEngineClient(EngineClient):
# Raise on error, otherwise happily return None
if isinstance(request_output, BaseException):
raise request_output
+
+ def set_metrics_publisher(self, metrics_publisher):
+ self.metrics_publisher = metrics_publisher
diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py
index a0dd7958..dc6ea25d 100644
--- a/vllm/engine/multiprocessing/engine.py
+++ b/vllm/engine/multiprocessing/engine.py
@@ -14,24 +14,56 @@ from vllm.engine.llm_engine import LLMEngine
# yapf: disable
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT, IPC_INPUT_EXT,
- IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T,
+ IPC_OUTPUT_EXT, IPC_METRICS_EXT,
+ REQUEST_OUTPUTS_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCAdapterLoadedResponse, RPCError,
RPCLoadAdapterRequest,
RPCProcessRequest,
RPCResetPrefixCacheRequest,
RPCStartupRequest, RPCStartupResponse,
- RPCUProfileRequest)
+ RPCUProfileRequest, KvMetrics)
# yapf: enable
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.usage.usage_lib import UsageContext
+from vllm.engine.metrics_types import StatLoggerBase, Stats, SupportsMetricsInfo
+from dataclasses import dataclass, field
logger = init_logger(__name__)
POLLING_TIMEOUT_MS = 10000
HEALTHY_RESPONSE = (pickle.dumps(VLLM_RPC_SUCCESS_STR), )
+class KvStatLogger(StatLoggerBase):
+ def __init__(
+ self,
+ max_num_seqs: int,
+ num_total_gpu_blocks: int,
+ metrics_socket
+ ):
+        # Totals must be queried from the initialized scheduler and cache configs
+ self.request_total_slots = max_num_seqs
+ self.kv_total_blocks = num_total_gpu_blocks
+ self.metrics_socket = metrics_socket
+
+ # KV metrics
+ self._send_kv_metrics(0, 0)
+
+ def log(self, stats: Stats) -> None:
+ self._send_kv_metrics(
+ stats.num_running_sys,
+ int(stats.gpu_cache_usage_sys * self.kv_total_blocks)
+ )
+
+ def info(self, type: str, obj: SupportsMetricsInfo) -> None:
+ pass
+
+ def _send_kv_metrics(self, active_slots, active_kv_blocks):
+ if not self.metrics_socket.closed:
+ metrics_bytes = pickle.dumps(KvMetrics(active_slots, self.request_total_slots, active_kv_blocks, self.kv_total_blocks))
+ self.metrics_socket.send_multipart((metrics_bytes, ), copy=False)
+
class MQLLMEngine:
"""A multiprocessing wrapper for :class:`LLMEngine`.
@@ -94,12 +126,24 @@ class MQLLMEngine:
self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH)
self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}")
+ # Send metrics back to client.
+ self.metrics_socket = self.ctx.socket(zmq.constants.PUSH)
+ self.metrics_socket.bind(f"{ipc_path}{IPC_METRICS_EXT}")
+
# IPC path for the data socket.
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"
# Error state.
self._errored_with: Optional[BaseException] = None
+ # Attach logger for continuous metrics publishing
+ self.stat_logger = KvStatLogger(
+ self.engine.scheduler_config.max_num_seqs,
+ self.engine.cache_config.num_gpu_blocks,
+ self.metrics_socket
+ )
+ self.engine.add_logger("kv_metrics", self.stat_logger)
+
@property
def dead_error(self) -> BaseException:
if self._errored_with is not None:
diff --git a/vllm/envs.py b/vllm/envs.py
index 745b068b..438142e3 100644
index 745b068b..0ae63d9b 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -87,6 +87,8 @@ if TYPE_CHECKING:
@@ -87,6 +87,10 @@ if TYPE_CHECKING:
VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: bool = False
VLLM_RAY_PER_WORKER_GPUS: float = 1.0
VLLM_RAY_BUNDLE_INDICES: str = ""
+ VLLM_KV_CAPI_PATH: Optional[str] = None
+ VLLM_WORKER_ID: Optional[str] = None
+ VLLM_KV_NAMESPACE: Optional[str] = None
+ VLLM_KV_COMPONENT: Optional[str] = None
+ VLLM_WORKER_ID: Optional[int] = None
def get_default_cache_root():
@@ -572,6 +574,14 @@ environment_variables: Dict[str, Callable[[], Any]] = {
@@ -572,6 +576,21 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# models the alignment is already naturally aligned to 256 bytes.
"VLLM_CUDA_MEM_ALIGN_KV_CACHE":
lambda: bool(int(os.getenv("VLLM_CUDA_MEM_ALIGN_KV_CACHE", "1"))),
......@@ -1750,9 +1482,16 @@ index 745b068b..438142e3 100644
+ "VLLM_KV_CAPI_PATH":
+ lambda: os.environ.get("VLLM_KV_CAPI_PATH", None),
+
+ # Identifiers to publish KV related information
+ "VLLM_KV_NAMESPACE":
+ lambda: os.environ.get("VLLM_KV_NAMESPACE", None),
+ "VLLM_KV_COMPONENT":
+ lambda: os.environ.get("VLLM_KV_COMPONENT", None),
+
+ # Worker ID used for identifying workers in distributed settings
+ "VLLM_WORKER_ID":
+ lambda: os.getenv("VLLM_WORKER_ID", None),
+ lambda: int(os.getenv("VLLM_WORKER_ID", "0"))
+ if "VLLM_WORKER_ID" in os.environ else None,
}
# end-env-vars-definition
......
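Taken together with the block-manager change earlier in this commit, all four variables must be set (before vLLM is imported) for KV event publishing to activate; otherwise the event manager stays disabled. A minimal sketch with illustrative values:

```python
import os

# Values mirror the worker example later in this commit; the worker id is
# typically an endpoint lease id and is parsed to int by vllm.envs.
os.environ["VLLM_KV_CAPI_PATH"] = (
    "/opt/triton/llm_binding/lib/libtriton_llm_capi.so")
os.environ["VLLM_KV_NAMESPACE"] = "triton-init"
os.environ["VLLM_KV_COMPONENT"] = "vllm"
os.environ["VLLM_WORKER_ID"] = "4242"
```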
......@@ -180,7 +180,7 @@ kv-router-run.sh <number_of_workers> <routing_strategy> Optional[<model_name>]
Example:
```bash
# Launch 8 workers with prefix routing strategy and use deepseek-ai/DeepSeek-R1-Distill-Llama-8B as the model
bash /workspace/examples/python_rs/llm/vllm/kv-router-run.sh 8 prefix deepseek-ai/DeepSeek-R1-Distill-Llama-8B
bash /workspace/examples/python_rs/llm/vllm/scripts/kv-router-run.sh 8 prefix deepseek-ai/DeepSeek-R1-Distill-Llama-8B
# List tmux sessions
tmux ls
......@@ -252,7 +252,7 @@ llmctl http add chat-models deepseek-ai/DeepSeek-R1-Distill-Llama-8B triton-init
```
```bash
curl localhost:9992/v1/chat/completions -H "Content-Type: application/json" -d '{
curl localhost:8080/v1/chat/completions -H "Content-Type: application/json" -d '{
"model": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
"messages": [
{
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import uvloop
from common.protocol import Request, Response
from vllm.logger import logger as vllm_logger
from triton_distributed.llm import KvRouter
from triton_distributed.runtime import (
DistributedRuntime,
triton_endpoint,
triton_worker,
)
class Router:
"""
Request handler for the generate endpoint
"""
def __init__(
self,
router,
workers_client,
):
self.router = router
self.workers_client = workers_client
@triton_endpoint(Request, Response)
async def generate(self, request):
lora_id = 0
worker_id = None
tokens = [3] * 64
try:
worker_id = await self.router.schedule(tokens, lora_id)
            # [NOTE][TODO] The scheduler may now return more detailed error
            # messages; we catch and log all exceptions for now. We should
            # catch specific router exceptions once dedicated types exist.
except Exception as e:
vllm_logger.info(f"got exception of type {type(e)}: {e}")
worker_id = None
vllm_logger.exception(f"Error during worker selection: {e}")
vllm_logger.info(f"Scheduling to worker_id: {worker_id}")
if worker_id is None:
vllm_logger.info("randomly select worker")
engine_generator = await self.workers_client.random(
request.model_dump_json()
)
else:
vllm_logger.info(f"directly select worker: {worker_id}")
engine_generator = await self.workers_client.direct(
request.model_dump_json(), worker_id
)
async for resp in engine_generator:
resp = resp.data() if hasattr(resp, "data") else resp
yield resp
@triton_endpoint(Request, Response)
async def mock_generate(self, request):
print(f"Received request: {request}")
yield "Hello, World!"
ROUTE_SELF = True
@triton_worker()
async def worker(runtime: DistributedRuntime):
workers_client = (
await runtime.namespace("triton-init")
.component("vllm")
.endpoint("generate")
.client()
)
vllm_logger.info(
f"Have number of workers ({len(workers_client.endpoint_ids())}) are ready:\n"
+ "\n".join(f"id: {id}" for id in workers_client.endpoint_ids())
)
    # [TODO] The collect-endpoints implementation expects services to provide
    # ForwardPassMetrics as part of stats handling and panics otherwise.
    # This should be fixed so that endpoints which do not provide metrics are
    # simply ignored; until then, make sure that services of the same
    # namespace::component are created via KvMetricsPublisher when it is also
    # used to create endpoints.
kv_listener = runtime.namespace("triton-init").component("vllm")
await kv_listener.create_service()
router = KvRouter(runtime, kv_listener)
# i.e. below will cause panic
# endpoint = kv_listener.endpoint("generate")
# await endpoint.serve_endpoint(
# Router(router, workers_client).mock_generate
# )
router_component = runtime.namespace("triton-init").component("frontend")
await router_component.create_service()
endpoint = router_component.endpoint("generate")
await endpoint.serve_endpoint(Router(router, workers_client).generate)
if __name__ == "__main__":
uvloop.install()
asyncio.run(worker())
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import ctypes
from ctypes import c_char_p, c_int64, c_uint32
import uvloop
from common.protocol import Request, Response
from vllm.logger import logger as vllm_logger
from triton_distributed.llm import KvMetricsPublisher
from triton_distributed.runtime import (
DistributedRuntime,
triton_endpoint,
triton_worker,
)
class TritonResult:
OK = 0
ERR = 1
class MockEngine:
"""
Request handler for the generate endpoint
"""
def __init__(self, metrics_publisher, worker_id):
self.worker_id = worker_id
# KV events
self.lib = ctypes.CDLL("/opt/triton/llm_binding/lib/libtriton_llm_capi.so")
self.lib.triton_llm_init.argtypes = [c_char_p, c_char_p, c_int64]
self.lib.triton_llm_init.restype = c_uint32
result = self.lib.triton_llm_init(
"triton-init".encode(), "vllm".encode(), worker_id
)
if result == TritonResult.OK:
vllm_logger.info(
"KVCacheEventManager initialized successfully. Ready to publish KV Cache Events"
)
else:
vllm_logger.info("KVCacheEventManager initialization failed!")
self.lib.triton_kv_event_publish_stored.argtypes = [
ctypes.c_uint64, # event_id
ctypes.POINTER(ctypes.c_uint32), # token_ids
ctypes.POINTER(ctypes.c_size_t), # num_block_tokens
ctypes.POINTER(ctypes.c_uint64), # block_ids
ctypes.c_size_t, # num_blocks
ctypes.POINTER(ctypes.c_uint64), # parent_hash
ctypes.c_uint64, # lora_id
]
self.lib.triton_kv_event_publish_stored.restype = (
ctypes.c_uint32
) # triton_llm_result_t
self.lib.triton_kv_event_publish_removed.argtypes = [
ctypes.c_uint64, # event_id
ctypes.POINTER(ctypes.c_uint64), # block_ids
ctypes.c_size_t, # num_blocks
]
self.lib.triton_kv_event_publish_removed.restype = (
ctypes.c_uint32
) # triton_llm_result_t
# KV metrics
self.metrics_publisher = metrics_publisher
self.request_active_slots = 0
self.request_total_slots = 4
self.kv_active_block = 0
self.kv_total_blocks = 4
        # [NOTE] The component must have proper metrics reported
        # to be properly selected by the router
self.metrics_publisher.publish(
self.request_active_slots,
self.request_total_slots,
self.kv_active_block,
self.kv_total_blocks,
)
self.event_id_counter = 0
self.tokens = [3] * 64
@triton_endpoint(Request, Response)
async def generate(self, request):
print(f"Received request: {request}")
self.request_active_slots = min(
self.request_active_slots + 1, self.request_total_slots
)
self.kv_active_block = min(self.kv_active_block + 1, self.kv_total_blocks)
self.metrics_publisher.publish(
self.request_active_slots,
self.request_total_slots,
self.kv_active_block,
self.kv_total_blocks,
)
self.store_event()
yield "Hello, World!"
def store_event(self):
parent_hash = (
(ctypes.c_uint64 * 1)(self.event_id_counter)
if self.event_id_counter > 0
else None
)
result = self.lib.triton_kv_event_publish_stored(
self.event_id_counter, # uint64_t event_id
(ctypes.c_uint32 * len(self.tokens))(
*self.tokens
), # const uint32_t *token_ids
(ctypes.c_size_t * 1)(
len(self.tokens)
), # const uintptr_t *num_block_tokens
(ctypes.c_uint64 * 1)(self.event_id_counter), # const uint64_t *block_ids
1, # uintptr_t num_blocks
parent_hash, # const uint64_t *parent_hash
0, # uint64_t lora_id
)
self.event_id_counter += 1
if result == TritonResult.OK:
vllm_logger.debug(f"Store - Published KV Event: {self.event_id_counter}")
else:
vllm_logger.debug(
f"Store - Failed to Publish KV Event: {self.event_id_counter}"
)
async def cooldown(self):
while True:
await asyncio.sleep(5)
self.request_active_slots = max(0, self.request_active_slots - 1)
self.kv_active_block = max(0, self.kv_active_block - 1)
self.metrics_publisher.publish(
self.request_active_slots,
self.request_total_slots,
self.kv_active_block,
self.kv_total_blocks,
)
@triton_worker()
async def worker(runtime: DistributedRuntime):
"""
Instantiate a `backend` component and serve the `generate` endpoint
A `Component` can serve multiple endpoints
"""
component = runtime.namespace("triton-init").component("vllm")
metrics_publisher = KvMetricsPublisher()
await metrics_publisher.create_service(component)
endpoint = component.endpoint("generate")
engine = MockEngine(metrics_publisher, endpoint.lease_id())
await asyncio.gather(
engine.cooldown(),
endpoint.serve_endpoint(engine.generate),
)
if __name__ == "__main__":
uvloop.install()
asyncio.run(worker())
......@@ -31,8 +31,8 @@ from vllm.logger import logger as vllm_logger
from vllm.outputs import RequestOutput
from vllm.transformers_utils.tokenizer import AnyTokenizer
from triton_distributed._core import Client
from triton_distributed.runtime import (
Client,
DistributedRuntime,
triton_endpoint,
triton_worker,
......@@ -126,7 +126,7 @@ class Processor(ProcessMixIn):
sampling_params=sampling_params,
request_id=request_id,
).model_dump_json(),
uuid.UUID(worker_id).int,
int(worker_id),
)
output = self.generate_responses(engine_generator)
......
......@@ -58,20 +58,21 @@ class Router:
@triton_endpoint(Tokens, WorkerId)
async def generate(self, request) -> AsyncIterator[WorkerId]:
lora_id = 0
worker_id = ""
worker_id = None
if self.routing_strategy == RoutingStrategy.PREFIX:
try:
worker_id = await self.router.schedule(request.tokens, lora_id)
                # [NOTE][TODO] The scheduler may now return more detailed error
                # messages; we catch and log all exceptions for now. We should
                # catch specific router exceptions once dedicated types exist.
except Exception as e:
vllm_logger.info(f"{e}")
if "No worker found" in str(e):
worker_id = ""
else:
worker_id = None
vllm_logger.exception(f"Error during worker selection: {e}")
vllm_logger.info(f"Scheduling to worker_id: {worker_id}")
yield worker_id
yield str(worker_id)
else:
# TODO: Do we implement round_robin and random here?
......@@ -113,8 +114,7 @@ async def worker(runtime: DistributedRuntime, args: Namespace):
+ "\n".join(f"id: {id}" for id in workers_client.endpoint_ids())
)
# TODO Router is a fixed namespace separate from the others
kv_listener = runtime.namespace("router").component(args.model_name)
kv_listener = runtime.namespace("triton-init").component("vllm")
await kv_listener.create_service()
router_component = runtime.namespace("triton-init").component("router")
......
......@@ -15,7 +15,6 @@
import asyncio
import os
import uuid
from typing import AsyncIterator
import uvloop
......@@ -26,6 +25,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.logger import logger as vllm_logger
from vllm.sampling_params import RequestOutputKind
from triton_distributed.llm import KvMetricsPublisher
from triton_distributed.runtime import (
DistributedRuntime,
triton_endpoint,
......@@ -40,10 +40,18 @@ class VllmEngine(BaseVllmEngine):
vLLM Inference Engine
"""
def __init__(self, engine_args: AsyncEngineArgs):
def __init__(
self, engine_args: AsyncEngineArgs, metrics_publisher: KvMetricsPublisher
):
self.metrics_publisher = metrics_publisher
self.engine_args = engine_args
super().__init__(engine_args)
async def initialize(self):
await super().initialize()
assert self.engine_client is not None, "engine_client was not initialized"
self.engine_client.set_metrics_publisher(self.metrics_publisher)
@triton_endpoint(vLLMGenerateRequest, MyRequestOutput)
async def generate(self, request) -> AsyncIterator:
assert (
......@@ -74,21 +82,32 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
"""
Serve the triton-init.vllm.generate endpoint.
"""
metrics_publisher = KvMetricsPublisher()
worker_component = runtime.namespace("triton-init").component("vllm")
await worker_component.create_service()
await metrics_publisher.create_service(worker_component)
worker_endpoint = worker_component.endpoint("generate")
# KV Publisher and Aggregator requires a UUID (str)
# KV Router requires a lease_id (int)
# This allows us to please both, until they are unified
# If VLLM_WORKER_ID is not set, KV Routing will fail
VLLM_WORKER_ID = uuid.UUID(int=worker_endpoint.lease_id())
VLLM_WORKER_ID = worker_endpoint.lease_id()
os.environ["VLLM_WORKER_ID"] = str(VLLM_WORKER_ID)
vllm_logger.info(f"Generate endpoint ID: {VLLM_WORKER_ID}")
vllm_engine = VllmEngine(engine_args)
VLLM_KV_NAMESPACE = "triton-init"
os.environ["VLLM_KV_NAMESPACE"] = str(VLLM_KV_NAMESPACE)
VLLM_KV_COMPONENT = "vllm"
os.environ["VLLM_KV_COMPONENT"] = str(VLLM_KV_COMPONENT)
vllm_engine = VllmEngine(engine_args, metrics_publisher)
await vllm_engine.initialize()
# Initially send dummy metrics to kick start,
# vLLM will not update stat until forward pass is triggered
metrics_publisher.publish(
0,
1024,
0,
1024,
)
await worker_endpoint.serve_endpoint(vllm_engine.generate)
......
......@@ -22,13 +22,13 @@ use tracing as log;
use uuid::Uuid;
use triton_distributed_llm::kv_router::{
indexer::compute_block_hash_for_seq, protocols::*, publisher::KvPublisher,
indexer::compute_block_hash_for_seq, protocols::*, publisher::KvEventPublisher,
};
use triton_distributed_runtime::{DistributedRuntime, Worker};
static WK: OnceCell<Worker> = OnceCell::new();
static DRT: AsyncOnceCell<DistributedRuntime> = AsyncOnceCell::new();
// [FIXME] shouldn't the publisher be instance passing between API calls?
static KV_PUB: OnceCell<KvPublisher> = OnceCell::new();
static KV_PUB: OnceCell<KvEventPublisher> = OnceCell::new();
fn initialize_tracing() {
// Sets up RUST_LOG environment variable for logging while KV Publishing
......@@ -49,11 +49,12 @@ pub enum TritonLlmResult {
}
/// # Safety
/// the model_name_c_str and worker_id_c_str are passed as pointers to C strings
/// the namespace_c_str and component_c_str are passed as pointers to C strings
#[no_mangle]
pub unsafe extern "C" fn triton_llm_init(
model_name_c_str: *const c_char,
worker_id_c_str: *const c_char,
namespace_c_str: *const c_char,
component_c_str: *const c_char,
worker_id: i64,
) -> TritonLlmResult {
initialize_tracing();
let wk = match WK.get_or_try_init(Worker::from_settings) {
......@@ -78,7 +79,7 @@ pub unsafe extern "C" fn triton_llm_init(
}
}
});
let model_name = match unsafe { CStr::from_ptr(model_name_c_str) }.to_str() {
let namespace = match unsafe { CStr::from_ptr(namespace_c_str) }.to_str() {
Ok(s) => s.to_string(),
Err(e) => {
eprintln!("Failed to convert C string to Rust string: {:?}", e);
......@@ -86,24 +87,17 @@ pub unsafe extern "C" fn triton_llm_init(
}
};
let worker_id_str = match unsafe { CStr::from_ptr(worker_id_c_str) }.to_str() {
Ok(s) => s,
let component = match unsafe { CStr::from_ptr(component_c_str) }.to_str() {
Ok(s) => s.to_string(),
Err(e) => {
eprintln!("Failed to convert C string to Rust string: {:?}", e);
return TritonLlmResult::ERR;
}
};
let worker_id_uuid = match Uuid::parse_str(worker_id_str) {
Ok(uuid) => uuid,
Err(e) => {
eprintln!("Failed to parse worker_id as UUID: {:?}", e);
return TritonLlmResult::ERR;
}
};
match result {
Ok(_) => match KV_PUB
.get_or_try_init(move || triton_create_kv_publisher(model_name, worker_id_uuid))
.get_or_try_init(move || triton_create_kv_publisher(namespace, component, worker_id))
{
Ok(_) => TritonLlmResult::OK,
Err(e) => {
......@@ -143,17 +137,18 @@ pub extern "C" fn triton_llm_load_publisher_create() -> TritonLlmResult {
// c++ executor api
fn triton_create_kv_publisher(
model_name: String,
worker_id: Uuid,
) -> Result<KvPublisher, anyhow::Error> {
log::info!("Creating KV Publisher for model: {}", model_name);
namespace: String,
component: String,
worker_id: i64,
) -> Result<KvEventPublisher, anyhow::Error> {
log::info!("Creating KV Publisher for model: {}", component);
match DRT
.get()
.ok_or(anyhow::Error::msg("Could not get Distributed Runtime"))
{
Ok(drt) => {
let backend = drt.namespace("router")?.component(model_name)?;
KvPublisher::new(drt.clone(), backend, worker_id)
let backend = drt.namespace(namespace)?.component(component)?;
KvEventPublisher::new(drt.clone(), backend, worker_id)
}
Err(e) => Err(e),
}
......
......@@ -64,6 +64,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<Client>()?;
m.add_class::<AsyncResponseStream>()?;
m.add_class::<llm::kv::KvRouter>()?;
m.add_class::<llm::kv::KvMetricsPublisher>()?;
engine::add_to_module(m)?;
......
......@@ -23,6 +23,7 @@ pub(crate) struct KvRouter {
#[pymethods]
impl KvRouter {
#[new]
    // [FIXME] 'drt' can be obtained from 'component'
fn new(drt: DistributedRuntime, component: Component) -> PyResult<Self> {
let runtime = pyo3_async_runtimes::tokio::get_runtime();
runtime.block_on(async {
......@@ -44,11 +45,64 @@ impl KvRouter {
) -> PyResult<Bound<'p, PyAny>> {
let router = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let uuid = router
let worker_id = router
.schedule(&token_ids, lora_id)
.await
.map_err(to_pyerr)?;
Ok(uuid.to_string())
Ok(worker_id)
})
}
}
#[pyclass]
pub(crate) struct KvMetricsPublisher {
inner: Arc<llm_rs::kv_router::publisher::KvMetricsPublisher>,
}
#[pymethods]
impl KvMetricsPublisher {
#[new]
fn new() -> PyResult<Self> {
let inner = llm_rs::kv_router::publisher::KvMetricsPublisher::new().map_err(to_pyerr)?;
Ok(Self {
inner: inner.into(),
})
}
fn create_service<'p>(
&self,
py: Python<'p>,
component: Component,
) -> PyResult<Bound<'p, PyAny>> {
let rs_publisher = self.inner.clone();
let rs_component = component.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let _ = rs_publisher
.create_service(rs_component)
.await
.map_err(to_pyerr)?;
Ok(())
})
}
fn publish<'p>(
&self,
py: Python<'p>,
request_active_slots: u64,
request_total_slots: u64,
kv_active_blocks: u64,
kv_total_blocks: u64,
) -> PyResult<()> {
self.inner
.publish(
llm_rs::kv_router::protocols::ForwardPassMetrics {
request_active_slots,
request_total_slots,
kv_active_blocks,
kv_total_blocks,
}
.into(),
)
.map_err(to_pyerr)
}
}
......@@ -128,7 +128,7 @@ class Client:
class KvRouter:
"""
The runtime object for a distributed NOVA applications
A router will determine which worker should handle a given request.
"""
...
......@@ -138,9 +138,36 @@ class KvRouter:
Create a `KvRouter` object that is associated with the `component`
"""
def schedule(self, token_ids: List[int], lora_id: int) -> str:
def schedule(self, token_ids: List[int], lora_id: int) -> int:
"""
Return the worker id that should handle the given token ids,
        an exception will be raised if no worker is available.
"""
...
class KvMetricsPublisher:
"""
A metrics publisher will provide KV metrics to the router.
"""
...
def __init__(self) -> None:
"""
Create a `KvMetricsPublisher` object
"""
def create_service(self, component: Component) -> None:
"""
        Similar to Component.create_service, but only services created through
        this method will interact with the KV router of the same component.
"""
def publish(self, request_active_slots: int,
request_total_slots: int,
kv_active_blocks: int,
kv_total_blocks: int) -> None:
"""
Update the KV metrics being reported.
"""
...
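Taken together, a hedged usage sketch of the updated stubs: the router is constructed from a runtime and component as in the binding above, and `schedule` now yields an int worker id rather than a UUID string (the token ids and the awaited call are illustrative):
from triton_distributed._core import KvRouter
async def route(drt, component, token_ids: list[int]) -> int:
    # Scheduling picks the worker whose cached KV blocks best
    # overlap the prompt.
    router = KvRouter(drt, component)
    worker_id = await router.schedule(token_ids, lora_id=0)
    return worker_id  # i64 worker id, usable for direct routing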
......@@ -13,4 +13,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from triton_distributed._core import KvMetricsPublisher as KvMetricsPublisher
from triton_distributed._core import KvRouter as KvRouter
......@@ -20,7 +20,8 @@ from typing import Any, AsyncGenerator, Callable, Type
from pydantic import BaseModel, ValidationError
from triton_distributed._core import DistributedRuntime
from triton_distributed._core import Client as Client
from triton_distributed._core import DistributedRuntime as DistributedRuntime
def triton_worker():
......
......@@ -23,14 +23,11 @@ use triton_distributed_runtime::{component::Component, DistributedRuntime};
pub mod indexer;
pub mod protocols;
pub mod publisher;
// [WIP] enable service_builder() through worker for metrics reporting
// pub mod worker;
mod scheduler;
mod scoring;
use crate::kv_router::{
indexer::{KvIndexer, KvIndexerInterface, RouterEvent},
protocols::KV_BLOCK_SIZE,
scheduler::{Endpoint, KvScheduler, Service},
scoring::ProcessedEndpoints,
};
......@@ -113,7 +110,7 @@ impl KvRouter {
}
// [TODO] indexer needs to take 'lora_id' as a parameter
pub async fn schedule(&self, token_ids: &Vec<u32>, _lora_id: u64) -> Result<String> {
pub async fn schedule(&self, token_ids: &Vec<u32>, _lora_id: u64) -> Result<i64> {
// Extracting part of the code in KvRouter::generate() for only
// the decision making part, routing is done by the caller
let isl_tokens = token_ids.len();
......@@ -122,25 +119,8 @@ impl KvRouter {
.find_matches_for_request(token_ids.as_slice())
.await?;
log::debug!("KV router overlap_scores: {:?}", overlap_scores);
// [FIXME] Python binding results in "endpoint subscriber shutdown" error,
// need to investigate whether it happens in pure rust as well and then
// root cause it. Before that, not doing intelligent scheduling for rapid
// development..
// [FIXME] also need to fix that scheduler returns worker subject which is not
// the same as worker id (uuid). Seems like it adds additional annotation on top of uuid.
// Need to double check
// 'worker_subject' should be the same as worker id used for direct routing
// let worker_subject = self.scheduler.schedule(overlap_scores, isl_tokens).await?;
let mut selected_worker_subject = Option::<String>::None;
for (worker_subject, overlap_score) in &overlap_scores.scores {
if ((*overlap_score as usize * KV_BLOCK_SIZE) as f64 / isl_tokens as f64) >= 0.5 {
selected_worker_subject = Some(worker_subject.to_string());
}
}
match selected_worker_subject {
None => Err(anyhow::anyhow!("No worker found")),
Some(worker_subject) => Ok(worker_subject),
}
let worker_id = self.scheduler.schedule(overlap_scores, isl_tokens).await?;
Ok(worker_id)
}
}
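For reference, the interim heuristic deleted above accepted any worker whose cached overlap covered at least half the input sequence length; a pure-Python rendering of that check (names mirror the removed Rust; the KV_BLOCK_SIZE value and scores are illustrative):
KV_BLOCK_SIZE = 16  # illustrative; the real constant comes from protocols
def select_by_overlap(scores: dict[str, int], isl_tokens: int) -> str | None:
    # scores maps worker subject -> number of matched KV blocks.
    selected = None
    for worker_subject, overlap_score in scores.items():
        if (overlap_score * KV_BLOCK_SIZE) / isl_tokens >= 0.5:
            selected = worker_subject
    return selected  # None triggers the "No worker found" error path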
......@@ -167,7 +147,7 @@ async fn collect_endpoints(
.unwrap();
// [FIXME] Endpoint is parsed from nats stats handler which may not include 'data' field
// if the service hasn't registered the handler.
// if the service hasn't registered the handler. Need to be tolerant of this.
// Another option is to make sure the router is configured properly that
// it listens to the right subject (where other publisher has stats).
let services: Vec<Service> = values
......
......@@ -79,7 +79,7 @@ pub enum KvRouterError {
}
/// Identifier of an LLM worker which emits events to the router.
pub type WorkerId = uuid::Uuid;
pub type WorkerId = i64;
/// A shared reference to a [`RadixBlock`].
type SharedRadixBlock = Rc<RefCell<RadixBlock>>;
......
......@@ -13,20 +13,20 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::kv_router::{indexer::RouterEvent, protocols::KvCacheEvent, KV_EVENT_SUBJECT};
use crate::kv_router::{indexer::RouterEvent, protocols::*, KV_EVENT_SUBJECT};
use std::sync::Arc;
use tokio::sync::mpsc;
use tracing as log;
use triton_distributed_runtime::{component::Component, DistributedRuntime, Result};
use uuid::Uuid;
pub struct KvPublisher {
pub struct KvEventPublisher {
tx: mpsc::UnboundedSender<KvCacheEvent>,
}
impl KvPublisher {
pub fn new(drt: DistributedRuntime, backend: Component, worker_id: Uuid) -> Result<Self> {
impl KvEventPublisher {
pub fn new(drt: DistributedRuntime, backend: Component, worker_id: i64) -> Result<Self> {
let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>();
let p = KvPublisher { tx };
let p = KvEventPublisher { tx };
start_publish_task(drt, backend, worker_id, rx);
Ok(p)
......@@ -41,12 +41,10 @@ impl KvPublisher {
fn start_publish_task(
drt: DistributedRuntime,
backend: Component,
worker_id: Uuid,
worker_id: i64,
mut rx: mpsc::UnboundedReceiver<KvCacheEvent>,
) {
let client = drt.nats_client().client().clone();
// [FIXME] service name is for metrics polling?
// let service_name = backend.service_name();
let kv_subject = backend.event_subject(KV_EVENT_SUBJECT);
log::info!("Publishing KV Events to subject: {}", kv_subject);
......@@ -61,3 +59,37 @@ fn start_publish_task(
}
});
}
pub struct KvMetricsPublisher {
tx: tokio::sync::watch::Sender<Arc<ForwardPassMetrics>>,
rx: tokio::sync::watch::Receiver<Arc<ForwardPassMetrics>>,
}
impl KvMetricsPublisher {
pub fn new() -> Result<Self> {
let (tx, rx) = tokio::sync::watch::channel(Arc::new(ForwardPassMetrics::default()));
Ok(KvMetricsPublisher { tx, rx })
}
pub fn publish(
&self,
metrics: Arc<ForwardPassMetrics>,
) -> Result<(), tokio::sync::watch::error::SendError<Arc<ForwardPassMetrics>>> {
log::debug!("Publish metrics: {:?}", metrics);
self.tx.send(metrics)
}
pub async fn create_service(&self, component: Component) -> Result<()> {
let mut metrics_rx = self.rx.clone();
let _ = component
.service_builder()
.stats_handler(Some(Box::new(move |name, stats| {
log::debug!("[IN worker?] Stats for service {}: {:?}", name, stats);
let metrics = metrics_rx.borrow_and_update().clone();
serde_json::to_value(&*metrics).unwrap()
})))
.create()
.await?;
Ok(())
}
}
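The watch channel here acts as a latest-value cell: publish() overwrites the current ForwardPassMetrics, and the stats handler reads whatever is newest at poll time instead of draining a queue. A rough Python analogue of that pattern, purely illustrative, using a lock in place of tokio::sync::watch:
import threading
class LatestMetrics:
    # Single-slot "watch channel": writers replace, readers snapshot.
    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._value: dict = {}
    def publish(self, metrics: dict) -> None:
        with self._lock:
            self._value = metrics
    def stats_handler(self) -> dict:
        # Called on each stats poll; always sees the most recent publish.
        with self._lock:
            return dict(self._value)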
......@@ -17,8 +17,6 @@ use serde::{Deserialize, Serialize};
use std::borrow::BorrowMut;
use std::cmp::min;
use uuid::Uuid;
use crate::kv_router::indexer::OverlapScores;
pub use crate::kv_router::protocols::{ForwardPassMetrics, KV_BLOCK_SIZE};
use crate::kv_router::scoring::ProcessedEndpoints;
......@@ -44,16 +42,17 @@ pub struct Endpoint {
}
impl Endpoint {
pub fn worker_id(&self) -> Uuid {
Uuid::parse_str(
pub fn worker_id(&self) -> i64 {
i64::from_str_radix(
self.subject
.split(".")
.split("-")
.last()
.expect("invalid subject")
.to_string()
.as_str(),
16,
)
.expect("invalid uuid")
.expect("invalid worker id")
}
}
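The new parsing treats the trailing '-'-delimited segment of the endpoint subject as a hexadecimal i64. The same logic in Python, with a made-up subject for illustration:
def worker_id_from_subject(subject: str) -> int:
    # Mirrors i64::from_str_radix(last_segment, 16) in the Rust above.
    return int(subject.split("-")[-1], 16)
assert worker_id_from_subject("backend-5f3a9c") == 0x5F3A9C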
......@@ -69,11 +68,11 @@ pub struct Service {
pub struct SchedulingRequest {
isl_tokens: usize,
overlap: OverlapScores,
resp_tx: tokio::sync::oneshot::Sender<String>,
resp_tx: tokio::sync::oneshot::Sender<i64>,
}
impl SchedulingRequest {
pub fn respond(self, worker_id: String) {
pub fn respond(self, worker_id: i64) {
if self.resp_tx.send(worker_id).is_err() {
tracing::trace!("failed to send response to requestor");
}
......@@ -174,7 +173,7 @@ impl KvScheduler {
&self,
overlap: OverlapScores,
isl_tokens: usize,
) -> Result<String, KvSchedulerError> {
) -> Result<i64, KvSchedulerError> {
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
let request = SchedulingRequest {
isl_tokens,
......@@ -199,7 +198,7 @@ impl KvScheduler {
pub fn select_worker(
workers: &mut ProcessedEndpoints,
request: &SchedulingRequest,
) -> Result<String, KvSchedulerError> {
) -> Result<i64, KvSchedulerError> {
// balance mode prioritizes balancing load across workers
let balance_threshold: f64 = 0.1;
let balance_mode = workers.load_std > balance_threshold * workers.load_avg;
......@@ -227,6 +226,7 @@ pub fn select_worker(
let kv_load_ratio = w.data.kv_active_blocks as f64 / w.data.kv_total_blocks as f64;
let load_deviation = kv_load_ratio - workers.load_avg;
// [FIXME] multiple endpoints of the same worker cause an out-of-bounds error
let worker_id = workers.worker_ids[i];
let overlap_score = request.overlap.scores.get(&worker_id).map_or(0, |x| *x);
let overlap_score = overlap_score as usize * KV_BLOCK_SIZE;
......@@ -267,10 +267,10 @@ pub fn select_worker(
Some(i) => {
tracing::info!(
"selected worker: {}; cost: {}",
workers.endpoints[i].subject,
workers.endpoints[i].worker_id(),
best_cost
);
Ok(workers.endpoints[i].subject.clone())
Ok(workers.endpoints[i].worker_id())
}
None => {
tracing::debug!("all workers busy");
......
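The full cost function is elided between hunks, so this sketch only restates the visible pieces: the balance-mode trigger (load_std > 0.1 * load_avg) computed over per-worker KV load ratios; the numbers are hypothetical:
import statistics
def balance_mode_triggered(kv_load_ratios: list[float], threshold: float = 0.1) -> bool:
    # Mirrors: balance_mode = workers.load_std > balance_threshold * workers.load_avg
    load_avg = statistics.fmean(kv_load_ratios)
    load_std = statistics.pstdev(kv_load_ratios)
    return load_std > threshold * load_avg
# e.g. ratios [0.30, 0.32, 0.70] give avg 0.44, std ~0.18 -> balance mode on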
......@@ -18,12 +18,11 @@
use std::collections::HashSet;
use crate::kv_router::scheduler::Endpoint;
use uuid::Uuid;
#[derive(Debug, Default)]
pub struct ProcessedEndpoints {
pub endpoints: Vec<Endpoint>,
pub worker_ids: Vec<Uuid>,
pub worker_ids: Vec<i64>,
pub load_avg: f64,
pub load_std: f64,
}
......@@ -43,8 +42,8 @@ impl ProcessedEndpoints {
/ load_values.len() as f64;
let load_std = variance.sqrt();
let worker_ids: HashSet<Uuid> = endpoints.iter().map(|x| x.worker_id()).collect();
let worker_ids: Vec<Uuid> = worker_ids.into_iter().collect();
let worker_ids: HashSet<i64> = endpoints.iter().map(|x| x.worker_id()).collect();
let worker_ids: Vec<i64> = worker_ids.into_iter().collect();
ProcessedEndpoints {
endpoints,
......
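A small Python restatement of the statistics computed in this hunk, assuming the loads are the kv_active/kv_total ratios gathered per endpoint: load_std is the square root of the population variance, and worker ids are deduplicated through a set as with the HashSet<i64> above.
import statistics
def summarize(loads: list[float], worker_ids: list[int]):
    load_avg = statistics.fmean(loads)
    load_std = statistics.pstdev(loads)  # sqrt of population variance
    unique_ids = list(set(worker_ids))   # HashSet<i64> -> Vec<i64> dedup
    return load_avg, load_std, unique_ids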
......@@ -34,7 +34,6 @@ use std::time::SystemTime;
use super::TokenIdType;
pub mod kv_routing;
pub mod llm_backend;
pub mod postprocessor;
pub mod preprocessor;
......
......@@ -74,9 +74,6 @@ addopts = [
"--mypy",
"--ignore-glob=*model.py",
# FIXME: Get relative/generic blob paths to work here
# Ignore rust<->python bindings until python package is built/installed in environment
"--ignore-glob=/workspace/python-wheel/python/triton_distributed_rs/*.py",
"--ignore-glob=/workspace/python-wheel/python/triton_distributed_rs/*.pyi",
]
xfail_strict = true
log_cli_level = "INFO"
......