Commit 861c5098 authored by GuanLuo's avatar GuanLuo Committed by GitHub
Browse files
parent eb022ec9
......@@ -180,7 +180,7 @@ kv-router-run.sh <number_of_workers> <routing_strategy> Optional[<model_name>]
Example:
```bash
# Launch 8 workers with prefix routing strategy and use deepseek-ai/DeepSeek-R1-Distill-Llama-8B as the model
bash /workspace/examples/python_rs/llm/vllm/kv-router-run.sh 8 prefix deepseek-ai/DeepSeek-R1-Distill-Llama-8B
bash /workspace/examples/python_rs/llm/vllm/scripts/kv-router-run.sh 8 prefix deepseek-ai/DeepSeek-R1-Distill-Llama-8B
# List tmux sessions
tmux ls
......@@ -252,7 +252,7 @@ llmctl http add chat-models deepseek-ai/DeepSeek-R1-Distill-Llama-8B triton-init
```
```bash
curl localhost:9992/v1/chat/completions -H "Content-Type: application/json" -d '{
curl localhost:8080/v1/chat/completions -H "Content-Type: application/json" -d '{
"model": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
"messages": [
{
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import uvloop
from common.protocol import Request, Response
from vllm.logger import logger as vllm_logger
from triton_distributed.llm import KvRouter
from triton_distributed.runtime import (
DistributedRuntime,
triton_endpoint,
triton_worker,
)
class Router:
    """
    Request handler for the frontend `generate` endpoint.

    Asks the KV router to schedule each request onto a worker based on
    (mock) token overlap; falls back to random worker selection when
    scheduling fails.
    """

    def __init__(
        self,
        router,
        workers_client,
    ):
        # KvRouter used to pick a worker for each request.
        self.router = router
        # Client over the workers' "generate" endpoint (direct/random routing).
        self.workers_client = workers_client

    @triton_endpoint(Request, Response)
    async def generate(self, request):
        lora_id = 0
        worker_id = None
        # Mock token sequence; a real frontend would tokenize the request.
        tokens = [3] * 64
        try:
            worker_id = await self.router.schedule(tokens, lora_id)
        # [NOTE][TODO] The scheduler may raise for several reasons; for now we
        # catch all exceptions and fall back to random routing. Catch specific
        # router exceptions once dedicated exception types exist.
        except Exception as e:
            # worker_id remains None, which triggers the random fallback below.
            # (Previously this branch logged the error twice and redundantly
            # re-assigned worker_id = None.)
            vllm_logger.exception(f"Error during worker selection: {e}")
        vllm_logger.info(f"Scheduling to worker_id: {worker_id}")

        if worker_id is None:
            vllm_logger.info("randomly select worker")
            engine_generator = await self.workers_client.random(
                request.model_dump_json()
            )
        else:
            vllm_logger.info(f"directly select worker: {worker_id}")
            engine_generator = await self.workers_client.direct(
                request.model_dump_json(), worker_id
            )

        async for resp in engine_generator:
            # Unwrap the transport envelope if the stream yields wrapped items.
            resp = resp.data() if hasattr(resp, "data") else resp
            yield resp

    @triton_endpoint(Request, Response)
    async def mock_generate(self, request):
        # Debug endpoint: returns a canned response without touching workers.
        print(f"Received request: {request}")
        yield "Hello, World!"
# NOTE(review): appears unused in this file — possibly a leftover development
# toggle; confirm before removing.
ROUTE_SELF = True


@triton_worker()
async def worker(runtime: DistributedRuntime):
    """
    Set up KV routing and serve the frontend `generate` endpoint.

    Connects to the `triton-init/vllm/generate` worker pool, creates a KV
    listener service on the same component, and exposes a routing
    `generate` endpoint on the `triton-init/frontend` component.
    """
    # Client over the worker pool's generate endpoint.
    workers_client = (
        await runtime.namespace("triton-init")
        .component("vllm")
        .endpoint("generate")
        .client()
    )
    vllm_logger.info(
        f"Have number of workers ({len(workers_client.endpoint_ids())}) are ready:\n"
        + "\n".join(f"id: {id}" for id in workers_client.endpoint_ids())
    )

    # [TODO] The collect-endpoint implementation expects services to provide
    # ForwardPassMetrics as part of stats handling and will panic otherwise.
    # This needs to be fixed so that non-providing endpoints are simply
    # ignored; until then, ensure services of the same namespace::component
    # are created via KvMetricsPublisher if it is also used to create
    # endpoints.
    kv_listener = runtime.namespace("triton-init").component("vllm")
    await kv_listener.create_service()
    router = KvRouter(runtime, kv_listener)
    # i.e. below will cause panic
    # endpoint = kv_listener.endpoint("generate")
    # await endpoint.serve_endpoint(
    #     Router(router, workers_client).mock_generate
    # )

    # The routing endpoint lives on a separate "frontend" component.
    router_component = runtime.namespace("triton-init").component("frontend")
    await router_component.create_service()
    endpoint = router_component.endpoint("generate")
    await endpoint.serve_endpoint(Router(router, workers_client).generate)
if __name__ == "__main__":
    # uvloop replaces the default asyncio event loop implementation.
    uvloop.install()
    asyncio.run(worker())
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import ctypes
from ctypes import c_char_p, c_int64, c_uint32
import uvloop
from common.protocol import Request, Response
from vllm.logger import logger as vllm_logger
from triton_distributed.llm import KvMetricsPublisher
from triton_distributed.runtime import (
DistributedRuntime,
triton_endpoint,
triton_worker,
)
class TritonResult:
    # Mirrors the C API's result codes returned by the libtriton_llm_capi
    # functions bound below (0 = success, nonzero = failure).
    OK = 0
    ERR = 1
class MockEngine:
    """
    Mock request handler for the generate endpoint.

    Exercises the KV router without a real engine by publishing synthetic
    KV cache events through the Triton LLM C API (bound via ctypes) and
    synthetic KV metrics through the provided metrics publisher.
    """

    def __init__(self, metrics_publisher, worker_id):
        self.worker_id = worker_id

        # KV events: bind the C API used to publish KV cache events.
        # NOTE(review): hard-coded library path — assumes the example
        # container layout; confirm for other environments.
        self.lib = ctypes.CDLL("/opt/triton/llm_binding/lib/libtriton_llm_capi.so")
        self.lib.triton_llm_init.argtypes = [c_char_p, c_int64]
        self.lib.triton_llm_init.argtypes = [c_char_p, c_char_p, c_int64]
        self.lib.triton_llm_init.restype = c_uint32
        # Initialize the publisher for namespace "triton-init", component
        # "vllm", identified by this worker's lease id.
        result = self.lib.triton_llm_init(
            "triton-init".encode(), "vllm".encode(), worker_id
        )
        if result == TritonResult.OK:
            vllm_logger.info(
                "KVCacheEventManager initialized successfully. Ready to publish KV Cache Events"
            )
        else:
            # Best-effort: failure is logged but not fatal for the mock.
            vllm_logger.info("KVCacheEventManager initialization failed!")

        self.lib.triton_kv_event_publish_stored.argtypes = [
            ctypes.c_uint64,  # event_id
            ctypes.POINTER(ctypes.c_uint32),  # token_ids
            ctypes.POINTER(ctypes.c_size_t),  # num_block_tokens
            ctypes.POINTER(ctypes.c_uint64),  # block_ids
            ctypes.c_size_t,  # num_blocks
            ctypes.POINTER(ctypes.c_uint64),  # parent_hash
            ctypes.c_uint64,  # lora_id
        ]
        self.lib.triton_kv_event_publish_stored.restype = (
            ctypes.c_uint32
        )  # triton_llm_result_t

        self.lib.triton_kv_event_publish_removed.argtypes = [
            ctypes.c_uint64,  # event_id
            ctypes.POINTER(ctypes.c_uint64),  # block_ids
            ctypes.c_size_t,  # num_blocks
        ]
        self.lib.triton_kv_event_publish_removed.restype = (
            ctypes.c_uint32
        )  # triton_llm_result_t

        # KV metrics: synthetic slot/block counters reported to the router.
        self.metrics_publisher = metrics_publisher
        self.request_active_slots = 0
        self.request_total_slots = 4
        self.kv_active_block = 0
        self.kv_total_blocks = 4
        # [NOTE] The component must have proper metrics reported to be
        # selected by the router, so publish an initial snapshot immediately.
        self.metrics_publisher.publish(
            self.request_active_slots,
            self.request_total_slots,
            self.kv_active_block,
            self.kv_total_blocks,
        )

        self.event_id_counter = 0
        # Fixed mock token sequence, matching the router's mock request.
        self.tokens = [3] * 64

    @triton_endpoint(Request, Response)
    async def generate(self, request):
        print(f"Received request: {request}")
        # Simulate load: bump counters (clamped to totals) and republish.
        self.request_active_slots = min(
            self.request_active_slots + 1, self.request_total_slots
        )
        self.kv_active_block = min(self.kv_active_block + 1, self.kv_total_blocks)
        self.metrics_publisher.publish(
            self.request_active_slots,
            self.request_total_slots,
            self.kv_active_block,
            self.kv_total_blocks,
        )
        # Emit a synthetic "block stored" KV event for this request.
        self.store_event()
        yield "Hello, World!"

    def store_event(self):
        # NOTE(review): parent_hash is set to the *current* event id; a chain
        # of blocks would normally reference the previous event — confirm
        # whether event_id_counter - 1 was intended.
        parent_hash = (
            (ctypes.c_uint64 * 1)(self.event_id_counter)
            if self.event_id_counter > 0
            else None
        )
        result = self.lib.triton_kv_event_publish_stored(
            self.event_id_counter,  # uint64_t event_id
            (ctypes.c_uint32 * len(self.tokens))(
                *self.tokens
            ),  # const uint32_t *token_ids
            (ctypes.c_size_t * 1)(
                len(self.tokens)
            ),  # const uintptr_t *num_block_tokens
            (ctypes.c_uint64 * 1)(self.event_id_counter),  # const uint64_t *block_ids
            1,  # uintptr_t num_blocks
            parent_hash,  # const uint64_t *parent_hash
            0,  # uint64_t lora_id
        )

        self.event_id_counter += 1
        # NOTE(review): counter is incremented before logging, so these
        # messages report event id N+1 for the event published as N.
        if result == TritonResult.OK:
            vllm_logger.debug(f"Store - Published KV Event: {self.event_id_counter}")
        else:
            vllm_logger.debug(
                f"Store - Failed to Publish KV Event: {self.event_id_counter}"
            )

    async def cooldown(self):
        # Background task: every 5 seconds decay the synthetic load counters
        # toward zero and republish, so routing scores recover over time.
        while True:
            await asyncio.sleep(5)
            self.request_active_slots = max(0, self.request_active_slots - 1)
            self.kv_active_block = max(0, self.kv_active_block - 1)
            self.metrics_publisher.publish(
                self.request_active_slots,
                self.request_total_slots,
                self.kv_active_block,
                self.kv_total_blocks,
            )
@triton_worker()
async def worker(runtime: DistributedRuntime):
    """
    Instantiate a `backend` component and serve the `generate` endpoint.

    A `Component` can serve multiple endpoints. The metrics-publisher
    service is created on the same component so the KV router receives
    ForwardPassMetrics for this worker.
    """
    component = runtime.namespace("triton-init").component("vllm")
    metrics_publisher = KvMetricsPublisher()
    # Create the service via the publisher so its stats handler is attached
    # (see the router-side note about panics for non-providing endpoints).
    await metrics_publisher.create_service(component)
    endpoint = component.endpoint("generate")
    # The endpoint's lease id doubles as the worker id for KV routing.
    engine = MockEngine(metrics_publisher, endpoint.lease_id())
    # Run the metrics cooldown loop alongside serving; both run forever.
    await asyncio.gather(
        engine.cooldown(),
        endpoint.serve_endpoint(engine.generate),
    )
if __name__ == "__main__":
    # uvloop replaces the default asyncio event loop implementation.
    uvloop.install()
    asyncio.run(worker())
......@@ -31,8 +31,8 @@ from vllm.logger import logger as vllm_logger
from vllm.outputs import RequestOutput
from vllm.transformers_utils.tokenizer import AnyTokenizer
from triton_distributed._core import Client
from triton_distributed.runtime import (
Client,
DistributedRuntime,
triton_endpoint,
triton_worker,
......@@ -126,7 +126,7 @@ class Processor(ProcessMixIn):
sampling_params=sampling_params,
request_id=request_id,
).model_dump_json(),
uuid.UUID(worker_id).int,
int(worker_id),
)
output = self.generate_responses(engine_generator)
......
......@@ -58,20 +58,21 @@ class Router:
@triton_endpoint(Tokens, WorkerId)
async def generate(self, request) -> AsyncIterator[WorkerId]:
lora_id = 0
worker_id = ""
worker_id = None
if self.routing_strategy == RoutingStrategy.PREFIX:
try:
worker_id = await self.router.schedule(request.tokens, lora_id)
# [NOTE][TODO] Now that the scheduler may return more error messages,
# now we are catching all exceptions and logging them. Should have
# catch specific router exceptions once we have dedicated types.
except Exception as e:
vllm_logger.info(f"{e}")
if "No worker found" in str(e):
worker_id = ""
else:
worker_id = None
vllm_logger.exception(f"Error during worker selection: {e}")
vllm_logger.info(f"Scheduling to worker_id: {worker_id}")
yield worker_id
yield str(worker_id)
else:
# TODO: Do we implement round_robin and random here?
......@@ -113,8 +114,7 @@ async def worker(runtime: DistributedRuntime, args: Namespace):
+ "\n".join(f"id: {id}" for id in workers_client.endpoint_ids())
)
# TODO Router is a fixed namespace separate from the others
kv_listener = runtime.namespace("router").component(args.model_name)
kv_listener = runtime.namespace("triton-init").component("vllm")
await kv_listener.create_service()
router_component = runtime.namespace("triton-init").component("router")
......
......@@ -15,7 +15,6 @@
import asyncio
import os
import uuid
from typing import AsyncIterator
import uvloop
......@@ -26,6 +25,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.logger import logger as vllm_logger
from vllm.sampling_params import RequestOutputKind
from triton_distributed.llm import KvMetricsPublisher
from triton_distributed.runtime import (
DistributedRuntime,
triton_endpoint,
......@@ -40,10 +40,18 @@ class VllmEngine(BaseVllmEngine):
vLLM Inference Engine
"""
def __init__(self, engine_args: AsyncEngineArgs):
def __init__(
self, engine_args: AsyncEngineArgs, metrics_publisher: KvMetricsPublisher
):
self.metrics_publisher = metrics_publisher
self.engine_args = engine_args
super().__init__(engine_args)
async def initialize(self):
await super().initialize()
assert self.engine_client is not None, "engine_client was not initialized"
self.engine_client.set_metrics_publisher(self.metrics_publisher)
@triton_endpoint(vLLMGenerateRequest, MyRequestOutput)
async def generate(self, request) -> AsyncIterator:
assert (
......@@ -74,21 +82,32 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
"""
Serve the triton-init.vllm.generate endpoint.
"""
metrics_publisher = KvMetricsPublisher()
worker_component = runtime.namespace("triton-init").component("vllm")
await worker_component.create_service()
await metrics_publisher.create_service(worker_component)
worker_endpoint = worker_component.endpoint("generate")
# KV Publisher and Aggregator requires a UUID (str)
# KV Router requires a lease_id (int)
# This allows us to please both, until they are unified
# If VLLM_WORKER_ID is not set, KV Routing will fail
VLLM_WORKER_ID = uuid.UUID(int=worker_endpoint.lease_id())
VLLM_WORKER_ID = worker_endpoint.lease_id()
os.environ["VLLM_WORKER_ID"] = str(VLLM_WORKER_ID)
vllm_logger.info(f"Generate endpoint ID: {VLLM_WORKER_ID}")
vllm_engine = VllmEngine(engine_args)
VLLM_KV_NAMESPACE = "triton-init"
os.environ["VLLM_KV_NAMESPACE"] = str(VLLM_KV_NAMESPACE)
VLLM_KV_COMPONENT = "vllm"
os.environ["VLLM_KV_COMPONENT"] = str(VLLM_KV_COMPONENT)
vllm_engine = VllmEngine(engine_args, metrics_publisher)
await vllm_engine.initialize()
# Initially send dummy metrics to kick start,
# vLLM will not update stat until forward pass is triggered
metrics_publisher.publish(
0,
1024,
0,
1024,
)
await worker_endpoint.serve_endpoint(vllm_engine.generate)
......
......@@ -22,13 +22,13 @@ use tracing as log;
use uuid::Uuid;
use triton_distributed_llm::kv_router::{
indexer::compute_block_hash_for_seq, protocols::*, publisher::KvPublisher,
indexer::compute_block_hash_for_seq, protocols::*, publisher::KvEventPublisher,
};
use triton_distributed_runtime::{DistributedRuntime, Worker};
static WK: OnceCell<Worker> = OnceCell::new();
static DRT: AsyncOnceCell<DistributedRuntime> = AsyncOnceCell::new();
// [FIXME] shouldn't the publisher be an instance passed between API calls?
static KV_PUB: OnceCell<KvPublisher> = OnceCell::new();
static KV_PUB: OnceCell<KvEventPublisher> = OnceCell::new();
fn initialize_tracing() {
// Sets up RUST_LOG environment variable for logging while KV Publishing
......@@ -49,11 +49,12 @@ pub enum TritonLlmResult {
}
/// # Safety
/// the model_name_c_str and worker_id_c_str are passed as pointers to C strings
/// the namespace_c_str and component_c_str are passed as pointers to C strings
#[no_mangle]
pub unsafe extern "C" fn triton_llm_init(
model_name_c_str: *const c_char,
worker_id_c_str: *const c_char,
namespace_c_str: *const c_char,
component_c_str: *const c_char,
worker_id: i64,
) -> TritonLlmResult {
initialize_tracing();
let wk = match WK.get_or_try_init(Worker::from_settings) {
......@@ -78,7 +79,7 @@ pub unsafe extern "C" fn triton_llm_init(
}
}
});
let model_name = match unsafe { CStr::from_ptr(model_name_c_str) }.to_str() {
let namespace = match unsafe { CStr::from_ptr(namespace_c_str) }.to_str() {
Ok(s) => s.to_string(),
Err(e) => {
eprintln!("Failed to convert C string to Rust string: {:?}", e);
......@@ -86,24 +87,17 @@ pub unsafe extern "C" fn triton_llm_init(
}
};
let worker_id_str = match unsafe { CStr::from_ptr(worker_id_c_str) }.to_str() {
Ok(s) => s,
let component = match unsafe { CStr::from_ptr(component_c_str) }.to_str() {
Ok(s) => s.to_string(),
Err(e) => {
eprintln!("Failed to convert C string to Rust string: {:?}", e);
return TritonLlmResult::ERR;
}
};
let worker_id_uuid = match Uuid::parse_str(worker_id_str) {
Ok(uuid) => uuid,
Err(e) => {
eprintln!("Failed to parse worker_id as UUID: {:?}", e);
return TritonLlmResult::ERR;
}
};
match result {
Ok(_) => match KV_PUB
.get_or_try_init(move || triton_create_kv_publisher(model_name, worker_id_uuid))
.get_or_try_init(move || triton_create_kv_publisher(namespace, component, worker_id))
{
Ok(_) => TritonLlmResult::OK,
Err(e) => {
......@@ -143,17 +137,18 @@ pub extern "C" fn triton_llm_load_publisher_create() -> TritonLlmResult {
// c++ executor api
fn triton_create_kv_publisher(
model_name: String,
worker_id: Uuid,
) -> Result<KvPublisher, anyhow::Error> {
log::info!("Creating KV Publisher for model: {}", model_name);
namespace: String,
component: String,
worker_id: i64,
) -> Result<KvEventPublisher, anyhow::Error> {
log::info!("Creating KV Publisher for model: {}", component);
match DRT
.get()
.ok_or(anyhow::Error::msg("Could not get Distributed Runtime"))
{
Ok(drt) => {
let backend = drt.namespace("router")?.component(model_name)?;
KvPublisher::new(drt.clone(), backend, worker_id)
let backend = drt.namespace(namespace)?.component(component)?;
KvEventPublisher::new(drt.clone(), backend, worker_id)
}
Err(e) => Err(e),
}
......
......@@ -64,6 +64,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<Client>()?;
m.add_class::<AsyncResponseStream>()?;
m.add_class::<llm::kv::KvRouter>()?;
m.add_class::<llm::kv::KvMetricsPublisher>()?;
engine::add_to_module(m)?;
......
......@@ -23,6 +23,7 @@ pub(crate) struct KvRouter {
#[pymethods]
impl KvRouter {
#[new]
// [FIXME] 'drt' can be obtained from 'component'
fn new(drt: DistributedRuntime, component: Component) -> PyResult<Self> {
let runtime = pyo3_async_runtimes::tokio::get_runtime();
runtime.block_on(async {
......@@ -44,11 +45,64 @@ impl KvRouter {
) -> PyResult<Bound<'p, PyAny>> {
let router = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let uuid = router
let worker_id = router
.schedule(&token_ids, lora_id)
.await
.map_err(to_pyerr)?;
Ok(uuid.to_string())
Ok(worker_id)
})
}
}
/// Python binding over the Rust `KvMetricsPublisher`.
///
/// Lets a Python worker report `ForwardPassMetrics` (slot/block usage)
/// that the KV router consumes when scheduling requests.
#[pyclass]
pub(crate) struct KvMetricsPublisher {
    inner: Arc<llm_rs::kv_router::publisher::KvMetricsPublisher>,
}

#[pymethods]
impl KvMetricsPublisher {
    /// Create a new publisher (not yet attached to any component service).
    #[new]
    fn new() -> PyResult<Self> {
        let inner = llm_rs::kv_router::publisher::KvMetricsPublisher::new().map_err(to_pyerr)?;
        Ok(Self {
            inner: inner.into(),
        })
    }

    /// Create the component's service with this publisher's stats handler
    /// attached, so published metrics are exposed to the router.
    /// Returns an awaitable to Python.
    fn create_service<'p>(
        &self,
        py: Python<'p>,
        component: Component,
    ) -> PyResult<Bound<'p, PyAny>> {
        let rs_publisher = self.inner.clone();
        let rs_component = component.inner.clone();
        pyo3_async_runtimes::tokio::future_into_py(py, async move {
            let _ = rs_publisher
                .create_service(rs_component)
                .await
                .map_err(to_pyerr)?;
            Ok(())
        })
    }

    /// Publish an updated metrics snapshot. Synchronous: it only updates
    /// the watch channel read by the stats handler.
    fn publish<'p>(
        &self,
        py: Python<'p>,
        request_active_slots: u64,
        request_total_slots: u64,
        kv_active_blocks: u64,
        kv_total_blocks: u64,
    ) -> PyResult<()> {
        self.inner
            .publish(
                llm_rs::kv_router::protocols::ForwardPassMetrics {
                    request_active_slots,
                    request_total_slots,
                    kv_active_blocks,
                    kv_total_blocks,
                }
                .into(),
            )
            .map_err(to_pyerr)
    }
}
......@@ -128,7 +128,7 @@ class Client:
class KvRouter:
"""
The runtime object for a distributed NOVA applications
A router will determine which worker should handle a given request.
"""
...
......@@ -138,9 +138,36 @@ class KvRouter:
Create a `KvRouter` object that is associated with the `component`
"""
def schedule(self, token_ids: List[int], lora_id: int) -> str:
def schedule(self, token_ids: List[int], lora_id: int) -> int:
"""
Return the worker id that should handle the given token ids,
exception will be raised if there is no worker available.
"""
...
class KvMetricsPublisher:
"""
A metrics publisher will provide KV metrics to the router.
"""
...
def __init__(self) -> None:
"""
Create a `KvMetricsPublisher` object
"""
def create_service(self, component: Component) -> None:
"""
Similar to Component.create_service, but only service created through
this method will interact with KV router of the same component.
"""
def publish(self, request_active_slots: int,
request_total_slots: int,
kv_active_blocks: int,
kv_total_blocks: int) -> None:
"""
Update the KV metrics being reported.
"""
...
......@@ -13,4 +13,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from triton_distributed._core import KvMetricsPublisher as KvMetricsPublisher
from triton_distributed._core import KvRouter as KvRouter
......@@ -20,7 +20,8 @@ from typing import Any, AsyncGenerator, Callable, Type
from pydantic import BaseModel, ValidationError
from triton_distributed._core import DistributedRuntime
from triton_distributed._core import Client as Client
from triton_distributed._core import DistributedRuntime as DistributedRuntime
def triton_worker():
......
......@@ -23,14 +23,11 @@ use triton_distributed_runtime::{component::Component, DistributedRuntime};
pub mod indexer;
pub mod protocols;
pub mod publisher;
// [WIP] enable service_builder() through worker for metrics reporting
// pub mod worker;
mod scheduler;
mod scoring;
use crate::kv_router::{
indexer::{KvIndexer, KvIndexerInterface, RouterEvent},
protocols::KV_BLOCK_SIZE,
scheduler::{Endpoint, KvScheduler, Service},
scoring::ProcessedEndpoints,
};
......@@ -113,7 +110,7 @@ impl KvRouter {
}
// [TODO] indexer needs to take 'lora_id' as parameter
pub async fn schedule(&self, token_ids: &Vec<u32>, _lora_id: u64) -> Result<String> {
pub async fn schedule(&self, token_ids: &Vec<u32>, _lora_id: u64) -> Result<i64> {
// Extracting part of the code in KvRouter::generate() for only
// the decision making part, routing is done by the caller
let isl_tokens = token_ids.len();
......@@ -122,25 +119,8 @@ impl KvRouter {
.find_matches_for_request(token_ids.as_slice())
.await?;
log::debug!("KV router overlap_scores: {:?}", overlap_scores);
// [FIXME] Python binding results in "endpoint subscriber shutdown" error,
// need to investigate whether it happens in pure rust as well and then
// root cause it. Before that, not doing intelligent scheduling for rapid
// development..
// [FIXME] also need to fix that scheduler returns worker subject which is not
// the same as worker id (uuid). Seems like it adds additional annotation on top of uuid.
// Need to double check
// 'worker_subject' should be the same as worker id used for direct routing
// let worker_subject = self.scheduler.schedule(overlap_scores, isl_tokens).await?;
let mut selected_worker_subject = Option::<String>::None;
for (worker_subject, overlap_score) in &overlap_scores.scores {
if ((*overlap_score as usize * KV_BLOCK_SIZE) as f64 / isl_tokens as f64) >= 0.5 {
selected_worker_subject = Some(worker_subject.to_string());
}
}
match selected_worker_subject {
None => Err(anyhow::anyhow!("No worker found")),
Some(worker_subject) => Ok(worker_subject),
}
let worker_id = self.scheduler.schedule(overlap_scores, isl_tokens).await?;
Ok(worker_id)
}
}
......@@ -167,7 +147,7 @@ async fn collect_endpoints(
.unwrap();
// [FIXME] Endpoint is parsed from nats stats handler which may not include 'data' field
// if the service hasn't registered the handler.
// if the service hasn't registered the handler. Need to be tolerant to this.
// Another option is to make sure the router is configured properly that
// it listens to the right subject (where other publisher has stats).
let services: Vec<Service> = values
......
......@@ -79,7 +79,7 @@ pub enum KvRouterError {
}
/// Identifier of a LLM worker which emits events to the router.
pub type WorkerId = uuid::Uuid;
pub type WorkerId = i64;
/// A shared reference to a [`RadixBlock`].
type SharedRadixBlock = Rc<RefCell<RadixBlock>>;
......
......@@ -13,20 +13,20 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::kv_router::{indexer::RouterEvent, protocols::KvCacheEvent, KV_EVENT_SUBJECT};
use crate::kv_router::{indexer::RouterEvent, protocols::*, KV_EVENT_SUBJECT};
use std::sync::Arc;
use tokio::sync::mpsc;
use tracing as log;
use triton_distributed_runtime::{component::Component, DistributedRuntime, Result};
use uuid::Uuid;
pub struct KvPublisher {
pub struct KvEventPublisher {
tx: mpsc::UnboundedSender<KvCacheEvent>,
}
impl KvPublisher {
pub fn new(drt: DistributedRuntime, backend: Component, worker_id: Uuid) -> Result<Self> {
impl KvEventPublisher {
pub fn new(drt: DistributedRuntime, backend: Component, worker_id: i64) -> Result<Self> {
let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>();
let p = KvPublisher { tx };
let p = KvEventPublisher { tx };
start_publish_task(drt, backend, worker_id, rx);
Ok(p)
......@@ -41,12 +41,10 @@ impl KvPublisher {
fn start_publish_task(
drt: DistributedRuntime,
backend: Component,
worker_id: Uuid,
worker_id: i64,
mut rx: mpsc::UnboundedReceiver<KvCacheEvent>,
) {
let client = drt.nats_client().client().clone();
// [FIXME] service name is for metrics polling?
// let service_name = backend.service_name();
let kv_subject = backend.event_subject(KV_EVENT_SUBJECT);
log::info!("Publishing KV Events to subject: {}", kv_subject);
......@@ -61,3 +59,37 @@ fn start_publish_task(
}
});
}
/// Publishes per-worker `ForwardPassMetrics` to the KV router.
///
/// Internally a `tokio::sync::watch` channel: `publish` overwrites the
/// latest snapshot, and the component's stats handler reads the most
/// recent value on demand.
pub struct KvMetricsPublisher {
    // Sender side: updated by `publish`.
    tx: tokio::sync::watch::Sender<Arc<ForwardPassMetrics>>,
    // Receiver side: cloned into the stats handler in `create_service`.
    rx: tokio::sync::watch::Receiver<Arc<ForwardPassMetrics>>,
}

impl KvMetricsPublisher {
    /// Create a publisher seeded with default (all-zero) metrics.
    pub fn new() -> Result<Self> {
        let (tx, rx) = tokio::sync::watch::channel(Arc::new(ForwardPassMetrics::default()));
        Ok(KvMetricsPublisher { tx, rx })
    }

    /// Replace the current metrics snapshot; readers see it on next poll.
    pub fn publish(
        &self,
        metrics: Arc<ForwassMetrics>,
    ) -> Result<(), tokio::sync::watch::error::SendError<Arc<ForwardPassMetrics>>> {
        log::debug!("Publish metrics: {:?}", metrics);
        self.tx.send(metrics)
    }

    /// Create the component's service with a stats handler that reports the
    /// latest published metrics as JSON (what the router's collector reads).
    pub async fn create_service(&self, component: Component) -> Result<()> {
        let mut metrics_rx = self.rx.clone();
        let _ = component
            .service_builder()
            .stats_handler(Some(Box::new(move |name, stats| {
                log::debug!("[IN worker?] Stats for service {}: {:?}", name, stats);
                let metrics = metrics_rx.borrow_and_update().clone();
                serde_json::to_value(&*metrics).unwrap()
            })))
            .create()
            .await?;
        Ok(())
    }
}
......@@ -17,8 +17,6 @@ use serde::{Deserialize, Serialize};
use std::borrow::BorrowMut;
use std::cmp::min;
use uuid::Uuid;
use crate::kv_router::indexer::OverlapScores;
pub use crate::kv_router::protocols::{ForwardPassMetrics, KV_BLOCK_SIZE};
use crate::kv_router::scoring::ProcessedEndpoints;
......@@ -44,16 +42,17 @@ pub struct Endpoint {
}
impl Endpoint {
pub fn worker_id(&self) -> Uuid {
Uuid::parse_str(
pub fn worker_id(&self) -> i64 {
i64::from_str_radix(
self.subject
.split(".")
.split("-")
.last()
.expect("invalid subject")
.to_string()
.as_str(),
16,
)
.expect("invalid uuid")
.expect("invalid worker id")
}
}
......@@ -69,11 +68,11 @@ pub struct Service {
pub struct SchedulingRequest {
isl_tokens: usize,
overlap: OverlapScores,
resp_tx: tokio::sync::oneshot::Sender<String>,
resp_tx: tokio::sync::oneshot::Sender<i64>,
}
impl SchedulingRequest {
pub fn respond(self, worker_id: String) {
pub fn respond(self, worker_id: i64) {
if self.resp_tx.send(worker_id).is_err() {
tracing::trace!("failed to send response to requestor");
}
......@@ -174,7 +173,7 @@ impl KvScheduler {
&self,
overlap: OverlapScores,
isl_tokens: usize,
) -> Result<String, KvSchedulerError> {
) -> Result<i64, KvSchedulerError> {
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
let request = SchedulingRequest {
isl_tokens,
......@@ -199,7 +198,7 @@ impl KvScheduler {
pub fn select_worker(
workers: &mut ProcessedEndpoints,
request: &SchedulingRequest,
) -> Result<String, KvSchedulerError> {
) -> Result<i64, KvSchedulerError> {
// balance mode prioritizes balancing load across workers
let balance_threshold: f64 = 0.1;
let balance_mode = workers.load_std > balance_threshold * workers.load_avg;
......@@ -227,6 +226,7 @@ pub fn select_worker(
let kv_load_ratio = w.data.kv_active_blocks as f64 / w.data.kv_total_blocks as f64;
let load_deviation = kv_load_ratio - workers.load_avg;
// [FIXME] multiple endpoints of the same worker cause out of bound error
let worker_id = workers.worker_ids[i];
let overlap_score = request.overlap.scores.get(&worker_id).map_or(0, |x| *x);
let overlap_score = overlap_score as usize * KV_BLOCK_SIZE;
......@@ -267,10 +267,10 @@ pub fn select_worker(
Some(i) => {
tracing::info!(
"selected worker: {}; cost: {}",
workers.endpoints[i].subject,
workers.endpoints[i].worker_id(),
best_cost
);
Ok(workers.endpoints[i].subject.clone())
Ok(workers.endpoints[i].worker_id())
}
None => {
tracing::debug!("all workers busy");
......
......@@ -18,12 +18,11 @@
use std::collections::HashSet;
use crate::kv_router::scheduler::Endpoint;
use uuid::Uuid;
#[derive(Debug, Default)]
pub struct ProcessedEndpoints {
pub endpoints: Vec<Endpoint>,
pub worker_ids: Vec<Uuid>,
pub worker_ids: Vec<i64>,
pub load_avg: f64,
pub load_std: f64,
}
......@@ -43,8 +42,8 @@ impl ProcessedEndpoints {
/ load_values.len() as f64;
let load_std = variance.sqrt();
let worker_ids: HashSet<Uuid> = endpoints.iter().map(|x| x.worker_id()).collect();
let worker_ids: Vec<Uuid> = worker_ids.into_iter().collect();
let worker_ids: HashSet<i64> = endpoints.iter().map(|x| x.worker_id()).collect();
let worker_ids: Vec<i64> = worker_ids.into_iter().collect();
ProcessedEndpoints {
endpoints,
......
......@@ -34,7 +34,6 @@ use std::time::SystemTime;
use super::TokenIdType;
pub mod kv_routing;
pub mod llm_backend;
pub mod postprocessor;
pub mod preprocessor;
......
......@@ -74,9 +74,6 @@ addopts = [
"--mypy",
"--ignore-glob=*model.py",
# FIXME: Get relative/generic blob paths to work here
# Ignore rust<->python bindings until python package is built/installed in environment
"--ignore-glob=/workspace/python-wheel/python/triton_distributed_rs/*.py",
"--ignore-glob=/workspace/python-wheel/python/triton_distributed_rs/*.pyi",
]
xfail_strict = true
log_cli_level = "INFO"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment