"pcdet/git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "0642cf06d0fd84f50cc4c6c01ea28edbc72ea810"
Commit 861c5098 authored by GuanLuo's avatar GuanLuo Committed by GitHub
Browse files
parent eb022ec9
...@@ -180,7 +180,7 @@ kv-router-run.sh <number_of_workers> <routing_strategy> Optional[<model_name>] ...@@ -180,7 +180,7 @@ kv-router-run.sh <number_of_workers> <routing_strategy> Optional[<model_name>]
Example: Example:
```bash ```bash
# Launch 8 workers with prefix routing strategy and use deepseek-ai/DeepSeek-R1-Distill-Llama-8B as the model # Launch 8 workers with prefix routing strategy and use deepseek-ai/DeepSeek-R1-Distill-Llama-8B as the model
bash /workspace/examples/python_rs/llm/vllm/kv-router-run.sh 8 prefix deepseek-ai/DeepSeek-R1-Distill-Llama-8B bash /workspace/examples/python_rs/llm/vllm/scripts/kv-router-run.sh 8 prefix deepseek-ai/DeepSeek-R1-Distill-Llama-8B
# List tmux sessions # List tmux sessions
tmux ls tmux ls
...@@ -252,7 +252,7 @@ llmctl http add chat-models deepseek-ai/DeepSeek-R1-Distill-Llama-8B triton-init ...@@ -252,7 +252,7 @@ llmctl http add chat-models deepseek-ai/DeepSeek-R1-Distill-Llama-8B triton-init
``` ```
```bash ```bash
curl localhost:9992/v1/chat/completions -H "Content-Type: application/json" -d '{ curl localhost:8080/v1/chat/completions -H "Content-Type: application/json" -d '{
"model": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B", "model": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
"messages": [ "messages": [
{ {
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import uvloop
from common.protocol import Request, Response
from vllm.logger import logger as vllm_logger
from triton_distributed.llm import KvRouter
from triton_distributed.runtime import (
DistributedRuntime,
triton_endpoint,
triton_worker,
)
class Router:
    """
    Request handler for the frontend `generate` endpoint.

    Routes each incoming request to a vLLM worker: asks the KV router to
    schedule it onto the best-matching worker, and falls back to random
    worker selection when scheduling fails.
    """

    def __init__(
        self,
        router,
        workers_client,
    ):
        # router: KvRouter used to pick a worker id from token overlap.
        # workers_client: client for the workers' "generate" endpoint,
        #   supporting both direct (targeted) and random dispatch.
        self.router = router
        self.workers_client = workers_client

    @triton_endpoint(Request, Response)
    async def generate(self, request):
        """Schedule the request onto a worker and stream back its responses."""
        lora_id = 0
        worker_id = None
        # Mock token sequence; real token ids would come from the request.
        tokens = [3] * 64
        try:
            worker_id = await self.router.schedule(tokens, lora_id)
        # [NOTE][TODO] Now that the scheduler may return more error messages,
        # now we are catching all exceptions and logging them. Should have
        # catch specific router exceptions once we have dedicated types.
        except Exception as e:
            # Log once with traceback; worker_id stays None so we fall back
            # to random selection below. (Previously this path logged the
            # same error twice and redundantly re-assigned worker_id.)
            vllm_logger.exception(f"Error during worker selection: {e}")
        vllm_logger.info(f"Scheduling to worker_id: {worker_id}")
        if worker_id is None:
            vllm_logger.info("randomly select worker")
            engine_generator = await self.workers_client.random(
                request.model_dump_json()
            )
        else:
            vllm_logger.info(f"directly select worker: {worker_id}")
            engine_generator = await self.workers_client.direct(
                request.model_dump_json(), worker_id
            )
        async for resp in engine_generator:
            # Unwrap transport envelopes that expose a .data() accessor.
            resp = resp.data() if hasattr(resp, "data") else resp
            yield resp

    @triton_endpoint(Request, Response)
    async def mock_generate(self, request):
        """Trivial endpoint used for wiring tests; performs no routing."""
        print(f"Received request: {request}")
        yield "Hello, World!"
# NOTE(review): appears unused in this module — confirm before removing.
ROUTE_SELF = True
@triton_worker()
async def worker(runtime: DistributedRuntime):
    """
    Build the KV router and serve the frontend `generate` endpoint.
    """
    workers_client = (
        await runtime.namespace("triton-init")
        .component("vllm")
        .endpoint("generate")
        .client()
    )
    worker_ids = workers_client.endpoint_ids()
    id_lines = "\n".join(f"id: {id}" for id in worker_ids)
    vllm_logger.info(
        f"Have number of workers ({len(worker_ids)}) are ready:\n" + id_lines
    )
    # [TODO] Collect endpoint implementation expects services to provide
    # ForwardPassMetrics as part of stats handling and it will panic if
    # otherwise. This needs to be fixed so that non-providing endpoints will
    # simply be ignored, but before that, we will make sure that the services
    # of the same namespace::component are created via KvMetricsPublisher,
    # if it is also used to create endpoints.
    kv_listener = runtime.namespace("triton-init").component("vllm")
    await kv_listener.create_service()
    router = KvRouter(runtime, kv_listener)
    # i.e. below will cause panic
    # endpoint = kv_listener.endpoint("generate")
    # await endpoint.serve_endpoint(
    #     Router(router, workers_client).mock_generate
    # )
    frontend = runtime.namespace("triton-init").component("frontend")
    await frontend.create_service()
    generate_endpoint = frontend.endpoint("generate")
    request_handler = Router(router, workers_client)
    await generate_endpoint.serve_endpoint(request_handler.generate)
if __name__ == "__main__":
    # uvloop replaces the default asyncio event loop implementation.
    uvloop.install()
    asyncio.run(worker())
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import ctypes
from ctypes import c_char_p, c_int64, c_uint32
import uvloop
from common.protocol import Request, Response
from vllm.logger import logger as vllm_logger
from triton_distributed.llm import KvMetricsPublisher
from triton_distributed.runtime import (
DistributedRuntime,
triton_endpoint,
triton_worker,
)
class TritonResult:
    """Status codes mirroring the C API's triton_llm_result_t return values."""

    OK = 0  # call succeeded
    ERR = 1  # call failed
class MockEngine:
    """
    Request handler for the generate endpoint.

    Mock vLLM engine that exercises the KV-routing plumbing: it publishes
    KV cache events through the Triton LLM C API (via ctypes) and reports
    ForwardPassMetrics through a KvMetricsPublisher, without running a
    real model.
    """

    def __init__(self, metrics_publisher, worker_id):
        # worker_id: lease id of the serving endpoint; forwarded to the
        # C library so published KV events are attributed to this worker.
        self.worker_id = worker_id
        # KV events: load the Triton LLM C API shared library and declare
        # the FFI signatures we call. The argtypes/restype declarations must
        # match the C header exactly.
        self.lib = ctypes.CDLL("/opt/triton/llm_binding/lib/libtriton_llm_capi.so")
        self.lib.triton_llm_init.argtypes = [c_char_p, c_char_p, c_int64]
        self.lib.triton_llm_init.restype = c_uint32
        # Initialize with (namespace, component, worker_id) so the publisher
        # targets the same component the router listens on.
        result = self.lib.triton_llm_init(
            "triton-init".encode(), "vllm".encode(), worker_id
        )
        if result == TritonResult.OK:
            vllm_logger.info(
                "KVCacheEventManager initialized successfully. Ready to publish KV Cache Events"
            )
        else:
            # Failure is logged but not raised; subsequent publish calls
            # will simply report errors.
            vllm_logger.info("KVCacheEventManager initialization failed!")

        self.lib.triton_kv_event_publish_stored.argtypes = [
            ctypes.c_uint64,  # event_id
            ctypes.POINTER(ctypes.c_uint32),  # token_ids
            ctypes.POINTER(ctypes.c_size_t),  # num_block_tokens
            ctypes.POINTER(ctypes.c_uint64),  # block_ids
            ctypes.c_size_t,  # num_blocks
            ctypes.POINTER(ctypes.c_uint64),  # parent_hash
            ctypes.c_uint64,  # lora_id
        ]
        self.lib.triton_kv_event_publish_stored.restype = (
            ctypes.c_uint32
        )  # triton_llm_result_t

        self.lib.triton_kv_event_publish_removed.argtypes = [
            ctypes.c_uint64,  # event_id
            ctypes.POINTER(ctypes.c_uint64),  # block_ids
            ctypes.c_size_t,  # num_blocks
        ]
        self.lib.triton_kv_event_publish_removed.restype = (
            ctypes.c_uint32
        )  # triton_llm_result_t

        # KV metrics: fake slot/block counters that generate() increments
        # and cooldown() decays.
        self.metrics_publisher = metrics_publisher
        self.request_active_slots = 0
        self.request_total_slots = 4
        self.kv_active_block = 0
        self.kv_total_blocks = 4
        # [NOTE] Now that the component must has proper metrics reported
        # to be properly selected by the router
        self.metrics_publisher.publish(
            self.request_active_slots,
            self.request_total_slots,
            self.kv_active_block,
            self.kv_total_blocks,
        )

        self.event_id_counter = 0
        # Fixed mock token block matching what the router schedules with.
        self.tokens = [3] * 64

    @triton_endpoint(Request, Response)
    async def generate(self, request):
        """Mock generation: bump load metrics, publish a KV event, reply."""
        print(f"Received request: {request}")
        # Saturating increments — counters never exceed their totals.
        self.request_active_slots = min(
            self.request_active_slots + 1, self.request_total_slots
        )
        self.kv_active_block = min(self.kv_active_block + 1, self.kv_total_blocks)
        self.metrics_publisher.publish(
            self.request_active_slots,
            self.request_total_slots,
            self.kv_active_block,
            self.kv_total_blocks,
        )
        self.store_event()
        yield "Hello, World!"

    def store_event(self):
        """Publish a single 'stored' KV cache event via the C API."""
        # NOTE(review): event_id_counter is reused as event id, block id and
        # parent hash — presumably a chained mock sequence; confirm against
        # the C API's expectations for real block hashes.
        parent_hash = (
            (ctypes.c_uint64 * 1)(self.event_id_counter)
            if self.event_id_counter > 0
            else None  # first event has no parent (NULL pointer)
        )
        result = self.lib.triton_kv_event_publish_stored(
            self.event_id_counter,  # uint64_t event_id
            (ctypes.c_uint32 * len(self.tokens))(
                *self.tokens
            ),  # const uint32_t *token_ids
            (ctypes.c_size_t * 1)(
                len(self.tokens)
            ),  # const uintptr_t *num_block_tokens
            (ctypes.c_uint64 * 1)(self.event_id_counter),  # const uint64_t *block_ids
            1,  # uintptr_t num_blocks
            parent_hash,  # const uint64_t *parent_hash
            0,  # uint64_t lora_id
        )
        self.event_id_counter += 1

        # NOTE(review): the log below prints the counter after the increment,
        # so the reported id is one greater than the event just published.
        if result == TritonResult.OK:
            vllm_logger.debug(f"Store - Published KV Event: {self.event_id_counter}")
        else:
            vllm_logger.debug(
                f"Store - Failed to Publish KV Event: {self.event_id_counter}"
            )

    async def cooldown(self):
        """Periodically decay the fake load metrics and republish them."""
        while True:
            await asyncio.sleep(5)
            # Floor at zero so the counters never go negative.
            self.request_active_slots = max(0, self.request_active_slots - 1)
            self.kv_active_block = max(0, self.kv_active_block - 1)
            self.metrics_publisher.publish(
                self.request_active_slots,
                self.request_total_slots,
                self.kv_active_block,
                self.kv_total_blocks,
            )
@triton_worker()
async def worker(runtime: DistributedRuntime):
    """
    Create the mock vLLM component and serve its `generate` endpoint.

    The cooldown task and the endpoint server run concurrently; a single
    `Component` can serve multiple endpoints.
    """
    vllm_component = runtime.namespace("triton-init").component("vllm")

    # Register the service via the metrics publisher so the KV router
    # receives ForwardPassMetrics from this component.
    publisher = KvMetricsPublisher()
    await publisher.create_service(vllm_component)

    generate_endpoint = vllm_component.endpoint("generate")
    engine = MockEngine(publisher, generate_endpoint.lease_id())

    cooldown_loop = engine.cooldown()
    serve = generate_endpoint.serve_endpoint(engine.generate)
    await asyncio.gather(cooldown_loop, serve)
if __name__ == "__main__":
    # uvloop replaces the default asyncio event loop implementation.
    uvloop.install()
    asyncio.run(worker())
...@@ -31,8 +31,8 @@ from vllm.logger import logger as vllm_logger ...@@ -31,8 +31,8 @@ from vllm.logger import logger as vllm_logger
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from triton_distributed._core import Client
from triton_distributed.runtime import ( from triton_distributed.runtime import (
Client,
DistributedRuntime, DistributedRuntime,
triton_endpoint, triton_endpoint,
triton_worker, triton_worker,
...@@ -126,7 +126,7 @@ class Processor(ProcessMixIn): ...@@ -126,7 +126,7 @@ class Processor(ProcessMixIn):
sampling_params=sampling_params, sampling_params=sampling_params,
request_id=request_id, request_id=request_id,
).model_dump_json(), ).model_dump_json(),
uuid.UUID(worker_id).int, int(worker_id),
) )
output = self.generate_responses(engine_generator) output = self.generate_responses(engine_generator)
......
...@@ -58,20 +58,21 @@ class Router: ...@@ -58,20 +58,21 @@ class Router:
@triton_endpoint(Tokens, WorkerId) @triton_endpoint(Tokens, WorkerId)
async def generate(self, request) -> AsyncIterator[WorkerId]: async def generate(self, request) -> AsyncIterator[WorkerId]:
lora_id = 0 lora_id = 0
worker_id = "" worker_id = None
if self.routing_strategy == RoutingStrategy.PREFIX: if self.routing_strategy == RoutingStrategy.PREFIX:
try: try:
worker_id = await self.router.schedule(request.tokens, lora_id) worker_id = await self.router.schedule(request.tokens, lora_id)
# [NOTE][TODO] Now that the scheduler may return more error messages,
# now we are catching all exceptions and logging them. Should have
# catch specific router exceptions once we have dedicated types.
except Exception as e: except Exception as e:
vllm_logger.info(f"{e}") vllm_logger.info(f"{e}")
if "No worker found" in str(e): worker_id = None
worker_id = "" vllm_logger.exception(f"Error during worker selection: {e}")
else:
vllm_logger.exception(f"Error during worker selection: {e}")
vllm_logger.info(f"Scheduling to worker_id: {worker_id}") vllm_logger.info(f"Scheduling to worker_id: {worker_id}")
yield worker_id yield str(worker_id)
else: else:
# TODO: Do we implement round_robin and random here? # TODO: Do we implement round_robin and random here?
...@@ -113,8 +114,7 @@ async def worker(runtime: DistributedRuntime, args: Namespace): ...@@ -113,8 +114,7 @@ async def worker(runtime: DistributedRuntime, args: Namespace):
+ "\n".join(f"id: {id}" for id in workers_client.endpoint_ids()) + "\n".join(f"id: {id}" for id in workers_client.endpoint_ids())
) )
# TODO Router is a fixed namespace separate from the others kv_listener = runtime.namespace("triton-init").component("vllm")
kv_listener = runtime.namespace("router").component(args.model_name)
await kv_listener.create_service() await kv_listener.create_service()
router_component = runtime.namespace("triton-init").component("router") router_component = runtime.namespace("triton-init").component("router")
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
import asyncio import asyncio
import os import os
import uuid
from typing import AsyncIterator from typing import AsyncIterator
import uvloop import uvloop
...@@ -26,6 +25,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs ...@@ -26,6 +25,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.logger import logger as vllm_logger from vllm.logger import logger as vllm_logger
from vllm.sampling_params import RequestOutputKind from vllm.sampling_params import RequestOutputKind
from triton_distributed.llm import KvMetricsPublisher
from triton_distributed.runtime import ( from triton_distributed.runtime import (
DistributedRuntime, DistributedRuntime,
triton_endpoint, triton_endpoint,
...@@ -40,10 +40,18 @@ class VllmEngine(BaseVllmEngine): ...@@ -40,10 +40,18 @@ class VllmEngine(BaseVllmEngine):
vLLM Inference Engine vLLM Inference Engine
""" """
def __init__(self, engine_args: AsyncEngineArgs): def __init__(
self, engine_args: AsyncEngineArgs, metrics_publisher: KvMetricsPublisher
):
self.metrics_publisher = metrics_publisher
self.engine_args = engine_args self.engine_args = engine_args
super().__init__(engine_args) super().__init__(engine_args)
async def initialize(self):
await super().initialize()
assert self.engine_client is not None, "engine_client was not initialized"
self.engine_client.set_metrics_publisher(self.metrics_publisher)
@triton_endpoint(vLLMGenerateRequest, MyRequestOutput) @triton_endpoint(vLLMGenerateRequest, MyRequestOutput)
async def generate(self, request) -> AsyncIterator: async def generate(self, request) -> AsyncIterator:
assert ( assert (
...@@ -74,21 +82,32 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs): ...@@ -74,21 +82,32 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
""" """
Serve the triton-init.vllm.generate endpoint. Serve the triton-init.vllm.generate endpoint.
""" """
metrics_publisher = KvMetricsPublisher()
worker_component = runtime.namespace("triton-init").component("vllm") worker_component = runtime.namespace("triton-init").component("vllm")
await worker_component.create_service() await metrics_publisher.create_service(worker_component)
worker_endpoint = worker_component.endpoint("generate") worker_endpoint = worker_component.endpoint("generate")
# KV Publisher and Aggregator requires a UUID (str) VLLM_WORKER_ID = worker_endpoint.lease_id()
# KV Router requires a lease_id (int)
# This allows us to please both, until they are unified
# If VLLM_WORKER_ID is not set, KV Routing will fail
VLLM_WORKER_ID = uuid.UUID(int=worker_endpoint.lease_id())
os.environ["VLLM_WORKER_ID"] = str(VLLM_WORKER_ID) os.environ["VLLM_WORKER_ID"] = str(VLLM_WORKER_ID)
vllm_logger.info(f"Generate endpoint ID: {VLLM_WORKER_ID}") vllm_logger.info(f"Generate endpoint ID: {VLLM_WORKER_ID}")
vllm_engine = VllmEngine(engine_args) VLLM_KV_NAMESPACE = "triton-init"
os.environ["VLLM_KV_NAMESPACE"] = str(VLLM_KV_NAMESPACE)
VLLM_KV_COMPONENT = "vllm"
os.environ["VLLM_KV_COMPONENT"] = str(VLLM_KV_COMPONENT)
vllm_engine = VllmEngine(engine_args, metrics_publisher)
await vllm_engine.initialize() await vllm_engine.initialize()
# Initially send dummy metrics to kick start,
# vLLM will not update stat until forward pass is triggered
metrics_publisher.publish(
0,
1024,
0,
1024,
)
await worker_endpoint.serve_endpoint(vllm_engine.generate) await worker_endpoint.serve_endpoint(vllm_engine.generate)
......
...@@ -22,13 +22,13 @@ use tracing as log; ...@@ -22,13 +22,13 @@ use tracing as log;
use uuid::Uuid; use uuid::Uuid;
use triton_distributed_llm::kv_router::{ use triton_distributed_llm::kv_router::{
indexer::compute_block_hash_for_seq, protocols::*, publisher::KvPublisher, indexer::compute_block_hash_for_seq, protocols::*, publisher::KvEventPublisher,
}; };
use triton_distributed_runtime::{DistributedRuntime, Worker}; use triton_distributed_runtime::{DistributedRuntime, Worker};
static WK: OnceCell<Worker> = OnceCell::new(); static WK: OnceCell<Worker> = OnceCell::new();
static DRT: AsyncOnceCell<DistributedRuntime> = AsyncOnceCell::new(); static DRT: AsyncOnceCell<DistributedRuntime> = AsyncOnceCell::new();
// [FIXME] shouldn't the publisher be instance passing between API calls? // [FIXME] shouldn't the publisher be instance passing between API calls?
static KV_PUB: OnceCell<KvPublisher> = OnceCell::new(); static KV_PUB: OnceCell<KvEventPublisher> = OnceCell::new();
fn initialize_tracing() { fn initialize_tracing() {
// Sets up RUST_LOG environment variable for logging while KV Publishing // Sets up RUST_LOG environment variable for logging while KV Publishing
...@@ -49,11 +49,12 @@ pub enum TritonLlmResult { ...@@ -49,11 +49,12 @@ pub enum TritonLlmResult {
} }
/// # Safety /// # Safety
/// the model_name_c_str and worker_id_c_str are passed as pointers to C strings /// the namespace_c_str and component_c_str are passed as pointers to C strings
#[no_mangle] #[no_mangle]
pub unsafe extern "C" fn triton_llm_init( pub unsafe extern "C" fn triton_llm_init(
model_name_c_str: *const c_char, namespace_c_str: *const c_char,
worker_id_c_str: *const c_char, component_c_str: *const c_char,
worker_id: i64,
) -> TritonLlmResult { ) -> TritonLlmResult {
initialize_tracing(); initialize_tracing();
let wk = match WK.get_or_try_init(Worker::from_settings) { let wk = match WK.get_or_try_init(Worker::from_settings) {
...@@ -78,7 +79,7 @@ pub unsafe extern "C" fn triton_llm_init( ...@@ -78,7 +79,7 @@ pub unsafe extern "C" fn triton_llm_init(
} }
} }
}); });
let model_name = match unsafe { CStr::from_ptr(model_name_c_str) }.to_str() { let namespace = match unsafe { CStr::from_ptr(namespace_c_str) }.to_str() {
Ok(s) => s.to_string(), Ok(s) => s.to_string(),
Err(e) => { Err(e) => {
eprintln!("Failed to convert C string to Rust string: {:?}", e); eprintln!("Failed to convert C string to Rust string: {:?}", e);
...@@ -86,24 +87,17 @@ pub unsafe extern "C" fn triton_llm_init( ...@@ -86,24 +87,17 @@ pub unsafe extern "C" fn triton_llm_init(
} }
}; };
let worker_id_str = match unsafe { CStr::from_ptr(worker_id_c_str) }.to_str() { let component = match unsafe { CStr::from_ptr(component_c_str) }.to_str() {
Ok(s) => s, Ok(s) => s.to_string(),
Err(e) => { Err(e) => {
eprintln!("Failed to convert C string to Rust string: {:?}", e); eprintln!("Failed to convert C string to Rust string: {:?}", e);
return TritonLlmResult::ERR; return TritonLlmResult::ERR;
} }
}; };
let worker_id_uuid = match Uuid::parse_str(worker_id_str) {
Ok(uuid) => uuid,
Err(e) => {
eprintln!("Failed to parse worker_id as UUID: {:?}", e);
return TritonLlmResult::ERR;
}
};
match result { match result {
Ok(_) => match KV_PUB Ok(_) => match KV_PUB
.get_or_try_init(move || triton_create_kv_publisher(model_name, worker_id_uuid)) .get_or_try_init(move || triton_create_kv_publisher(namespace, component, worker_id))
{ {
Ok(_) => TritonLlmResult::OK, Ok(_) => TritonLlmResult::OK,
Err(e) => { Err(e) => {
...@@ -143,17 +137,18 @@ pub extern "C" fn triton_llm_load_publisher_create() -> TritonLlmResult { ...@@ -143,17 +137,18 @@ pub extern "C" fn triton_llm_load_publisher_create() -> TritonLlmResult {
// c++ executor api // c++ executor api
fn triton_create_kv_publisher( fn triton_create_kv_publisher(
model_name: String, namespace: String,
worker_id: Uuid, component: String,
) -> Result<KvPublisher, anyhow::Error> { worker_id: i64,
log::info!("Creating KV Publisher for model: {}", model_name); ) -> Result<KvEventPublisher, anyhow::Error> {
log::info!("Creating KV Publisher for model: {}", component);
match DRT match DRT
.get() .get()
.ok_or(anyhow::Error::msg("Could not get Distributed Runtime")) .ok_or(anyhow::Error::msg("Could not get Distributed Runtime"))
{ {
Ok(drt) => { Ok(drt) => {
let backend = drt.namespace("router")?.component(model_name)?; let backend = drt.namespace(namespace)?.component(component)?;
KvPublisher::new(drt.clone(), backend, worker_id) KvEventPublisher::new(drt.clone(), backend, worker_id)
} }
Err(e) => Err(e), Err(e) => Err(e),
} }
......
...@@ -64,6 +64,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { ...@@ -64,6 +64,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<Client>()?; m.add_class::<Client>()?;
m.add_class::<AsyncResponseStream>()?; m.add_class::<AsyncResponseStream>()?;
m.add_class::<llm::kv::KvRouter>()?; m.add_class::<llm::kv::KvRouter>()?;
m.add_class::<llm::kv::KvMetricsPublisher>()?;
engine::add_to_module(m)?; engine::add_to_module(m)?;
......
...@@ -23,6 +23,7 @@ pub(crate) struct KvRouter { ...@@ -23,6 +23,7 @@ pub(crate) struct KvRouter {
#[pymethods] #[pymethods]
impl KvRouter { impl KvRouter {
#[new] #[new]
// [FXIME] 'drt' can be obtained from 'component'
fn new(drt: DistributedRuntime, component: Component) -> PyResult<Self> { fn new(drt: DistributedRuntime, component: Component) -> PyResult<Self> {
let runtime = pyo3_async_runtimes::tokio::get_runtime(); let runtime = pyo3_async_runtimes::tokio::get_runtime();
runtime.block_on(async { runtime.block_on(async {
...@@ -44,11 +45,64 @@ impl KvRouter { ...@@ -44,11 +45,64 @@ impl KvRouter {
) -> PyResult<Bound<'p, PyAny>> { ) -> PyResult<Bound<'p, PyAny>> {
let router = self.inner.clone(); let router = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
let uuid = router let worker_id = router
.schedule(&token_ids, lora_id) .schedule(&token_ids, lora_id)
.await .await
.map_err(to_pyerr)?; .map_err(to_pyerr)?;
Ok(uuid.to_string()) Ok(worker_id)
}) })
} }
} }
#[pyclass]
pub(crate) struct KvMetricsPublisher {
inner: Arc<llm_rs::kv_router::publisher::KvMetricsPublisher>,
}
#[pymethods]
impl KvMetricsPublisher {
#[new]
fn new() -> PyResult<Self> {
let inner = llm_rs::kv_router::publisher::KvMetricsPublisher::new().map_err(to_pyerr)?;
Ok(Self {
inner: inner.into(),
})
}
fn create_service<'p>(
&self,
py: Python<'p>,
component: Component,
) -> PyResult<Bound<'p, PyAny>> {
let rs_publisher = self.inner.clone();
let rs_component = component.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let _ = rs_publisher
.create_service(rs_component)
.await
.map_err(to_pyerr)?;
Ok(())
})
}
fn publish<'p>(
&self,
py: Python<'p>,
request_active_slots: u64,
request_total_slots: u64,
kv_active_blocks: u64,
kv_total_blocks: u64,
) -> PyResult<()> {
self.inner
.publish(
llm_rs::kv_router::protocols::ForwardPassMetrics {
request_active_slots,
request_total_slots,
kv_active_blocks,
kv_total_blocks,
}
.into(),
)
.map_err(to_pyerr)
}
}
...@@ -128,7 +128,7 @@ class Client: ...@@ -128,7 +128,7 @@ class Client:
class KvRouter: class KvRouter:
""" """
The runtime object for a distributed NOVA applications A router will determine which worker should handle a given request.
""" """
... ...
...@@ -138,9 +138,36 @@ class KvRouter: ...@@ -138,9 +138,36 @@ class KvRouter:
Create a `KvRouter` object that is associated with the `component` Create a `KvRouter` object that is associated with the `component`
""" """
def schedule(self, token_ids: List[int], lora_id: int) -> str: def schedule(self, token_ids: List[int], lora_id: int) -> int:
""" """
Return the worker id that should handle the given token ids, Return the worker id that should handle the given token ids,
exception will be raised if there is no worker available. exception will be raised if there is no worker available.
""" """
... ...
class KvMetricsPublisher:
"""
A metrics publisher will provide KV metrics to the router.
"""
...
def __init__(self) -> None:
"""
Create a `KvMetricsPublisher` object
"""
def create_service(self, component: Component) -> None:
"""
Similar to Component.create_service, but only service created through
this method will interact with KV router of the same component.
"""
def publish(self, request_active_slots: int,
request_total_slots: int,
kv_active_blocks: int,
kv_total_blocks: int) -> None:
"""
Update the KV metrics being reported.
"""
...
...@@ -13,4 +13,5 @@ ...@@ -13,4 +13,5 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from triton_distributed._core import KvMetricsPublisher as KvMetricsPublisher
from triton_distributed._core import KvRouter as KvRouter from triton_distributed._core import KvRouter as KvRouter
...@@ -20,7 +20,8 @@ from typing import Any, AsyncGenerator, Callable, Type ...@@ -20,7 +20,8 @@ from typing import Any, AsyncGenerator, Callable, Type
from pydantic import BaseModel, ValidationError from pydantic import BaseModel, ValidationError
from triton_distributed._core import DistributedRuntime from triton_distributed._core import Client as Client
from triton_distributed._core import DistributedRuntime as DistributedRuntime
def triton_worker(): def triton_worker():
......
...@@ -23,14 +23,11 @@ use triton_distributed_runtime::{component::Component, DistributedRuntime}; ...@@ -23,14 +23,11 @@ use triton_distributed_runtime::{component::Component, DistributedRuntime};
pub mod indexer; pub mod indexer;
pub mod protocols; pub mod protocols;
pub mod publisher; pub mod publisher;
// [WIP] enable service_builder() through worker for metrics reporting
// pub mod worker;
mod scheduler; mod scheduler;
mod scoring; mod scoring;
use crate::kv_router::{ use crate::kv_router::{
indexer::{KvIndexer, KvIndexerInterface, RouterEvent}, indexer::{KvIndexer, KvIndexerInterface, RouterEvent},
protocols::KV_BLOCK_SIZE,
scheduler::{Endpoint, KvScheduler, Service}, scheduler::{Endpoint, KvScheduler, Service},
scoring::ProcessedEndpoints, scoring::ProcessedEndpoints,
}; };
...@@ -113,7 +110,7 @@ impl KvRouter { ...@@ -113,7 +110,7 @@ impl KvRouter {
} }
// [TODO] indexer needs to take 'lora_id' as parameter // [TODO] indexer needs to take 'lora_id' as parameter
pub async fn schedule(&self, token_ids: &Vec<u32>, _lora_id: u64) -> Result<String> { pub async fn schedule(&self, token_ids: &Vec<u32>, _lora_id: u64) -> Result<i64> {
// Extracting part of the code in KvRouter::generate() for only // Extracting part of the code in KvRouter::generate() for only
// the decision making part, routing is done by the caller // the decision making part, routing is done by the caller
let isl_tokens = token_ids.len(); let isl_tokens = token_ids.len();
...@@ -122,25 +119,8 @@ impl KvRouter { ...@@ -122,25 +119,8 @@ impl KvRouter {
.find_matches_for_request(token_ids.as_slice()) .find_matches_for_request(token_ids.as_slice())
.await?; .await?;
log::debug!("KV router overlap_scores: {:?}", overlap_scores); log::debug!("KV router overlap_scores: {:?}", overlap_scores);
// [FIXME] Python binding results in "endpoint subscriber shutdown" error, let worker_id = self.scheduler.schedule(overlap_scores, isl_tokens).await?;
// need to investigate whether it happens in pure rust as well and then Ok(worker_id)
// root cause it. Before that, not doing intelligent scheduling for rapid
// development..
// [FIXME] also need to fix that scheduler returns worker subject which is not
// the same as worker id (uuid). Seems like it adds additional annotation on top of uuid.
// Need to double check
// 'worker_subject' should be the same as worker id used for direct routing
// let worker_subject = self.scheduler.schedule(overlap_scores, isl_tokens).await?;
let mut selected_worker_subject = Option::<String>::None;
for (worker_subject, overlap_score) in &overlap_scores.scores {
if ((*overlap_score as usize * KV_BLOCK_SIZE) as f64 / isl_tokens as f64) >= 0.5 {
selected_worker_subject = Some(worker_subject.to_string());
}
}
match selected_worker_subject {
None => Err(anyhow::anyhow!("No worker found")),
Some(worker_subject) => Ok(worker_subject),
}
} }
} }
...@@ -167,7 +147,7 @@ async fn collect_endpoints( ...@@ -167,7 +147,7 @@ async fn collect_endpoints(
.unwrap(); .unwrap();
// [FIXME] Endpoint is parsed from nats stats handler which may not include 'data' field // [FIXME] Endpoint is parsed from nats stats handler which may not include 'data' field
// if the service hasn't registered the handler. // if the service hasn't registered the handler. Need to be tolerant to this.
// Another option is to make sure the router is configured properly that // Another option is to make sure the router is configured properly that
// it listens to the right subject (where other publisher has stats). // it listens to the right subject (where other publisher has stats).
let services: Vec<Service> = values let services: Vec<Service> = values
......
...@@ -79,7 +79,7 @@ pub enum KvRouterError { ...@@ -79,7 +79,7 @@ pub enum KvRouterError {
} }
/// Identifier of a LLM worker which emits events to the router. /// Identifier of a LLM worker which emits events to the router.
pub type WorkerId = uuid::Uuid; pub type WorkerId = i64;
/// A shared reference to a [`RadixBlock`]. /// A shared reference to a [`RadixBlock`].
type SharedRadixBlock = Rc<RefCell<RadixBlock>>; type SharedRadixBlock = Rc<RefCell<RadixBlock>>;
......
...@@ -13,20 +13,20 @@ ...@@ -13,20 +13,20 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use crate::kv_router::{indexer::RouterEvent, protocols::KvCacheEvent, KV_EVENT_SUBJECT}; use crate::kv_router::{indexer::RouterEvent, protocols::*, KV_EVENT_SUBJECT};
use std::sync::Arc;
use tokio::sync::mpsc; use tokio::sync::mpsc;
use tracing as log; use tracing as log;
use triton_distributed_runtime::{component::Component, DistributedRuntime, Result}; use triton_distributed_runtime::{component::Component, DistributedRuntime, Result};
use uuid::Uuid;
pub struct KvPublisher { pub struct KvEventPublisher {
tx: mpsc::UnboundedSender<KvCacheEvent>, tx: mpsc::UnboundedSender<KvCacheEvent>,
} }
impl KvPublisher { impl KvEventPublisher {
pub fn new(drt: DistributedRuntime, backend: Component, worker_id: Uuid) -> Result<Self> { pub fn new(drt: DistributedRuntime, backend: Component, worker_id: i64) -> Result<Self> {
let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>(); let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>();
let p = KvPublisher { tx }; let p = KvEventPublisher { tx };
start_publish_task(drt, backend, worker_id, rx); start_publish_task(drt, backend, worker_id, rx);
Ok(p) Ok(p)
...@@ -41,12 +41,10 @@ impl KvPublisher { ...@@ -41,12 +41,10 @@ impl KvPublisher {
fn start_publish_task( fn start_publish_task(
drt: DistributedRuntime, drt: DistributedRuntime,
backend: Component, backend: Component,
worker_id: Uuid, worker_id: i64,
mut rx: mpsc::UnboundedReceiver<KvCacheEvent>, mut rx: mpsc::UnboundedReceiver<KvCacheEvent>,
) { ) {
let client = drt.nats_client().client().clone(); let client = drt.nats_client().client().clone();
// [FIXME] service name is for metrics polling?
// let service_name = backend.service_name();
let kv_subject = backend.event_subject(KV_EVENT_SUBJECT); let kv_subject = backend.event_subject(KV_EVENT_SUBJECT);
log::info!("Publishing KV Events to subject: {}", kv_subject); log::info!("Publishing KV Events to subject: {}", kv_subject);
...@@ -61,3 +59,37 @@ fn start_publish_task( ...@@ -61,3 +59,37 @@ fn start_publish_task(
} }
}); });
} }
pub struct KvMetricsPublisher {
tx: tokio::sync::watch::Sender<Arc<ForwardPassMetrics>>,
rx: tokio::sync::watch::Receiver<Arc<ForwardPassMetrics>>,
}
impl KvMetricsPublisher {
pub fn new() -> Result<Self> {
let (tx, rx) = tokio::sync::watch::channel(Arc::new(ForwardPassMetrics::default()));
Ok(KvMetricsPublisher { tx, rx })
}
pub fn publish(
&self,
metrics: Arc<ForwardPassMetrics>,
) -> Result<(), tokio::sync::watch::error::SendError<Arc<ForwardPassMetrics>>> {
log::debug!("Publish metrics: {:?}", metrics);
self.tx.send(metrics)
}
pub async fn create_service(&self, component: Component) -> Result<()> {
let mut metrics_rx = self.rx.clone();
let _ = component
.service_builder()
.stats_handler(Some(Box::new(move |name, stats| {
log::debug!("[IN worker?] Stats for service {}: {:?}", name, stats);
let metrics = metrics_rx.borrow_and_update().clone();
serde_json::to_value(&*metrics).unwrap()
})))
.create()
.await?;
Ok(())
}
}
...@@ -17,8 +17,6 @@ use serde::{Deserialize, Serialize}; ...@@ -17,8 +17,6 @@ use serde::{Deserialize, Serialize};
use std::borrow::BorrowMut; use std::borrow::BorrowMut;
use std::cmp::min; use std::cmp::min;
use uuid::Uuid;
use crate::kv_router::indexer::OverlapScores; use crate::kv_router::indexer::OverlapScores;
pub use crate::kv_router::protocols::{ForwardPassMetrics, KV_BLOCK_SIZE}; pub use crate::kv_router::protocols::{ForwardPassMetrics, KV_BLOCK_SIZE};
use crate::kv_router::scoring::ProcessedEndpoints; use crate::kv_router::scoring::ProcessedEndpoints;
...@@ -44,16 +42,17 @@ pub struct Endpoint { ...@@ -44,16 +42,17 @@ pub struct Endpoint {
} }
impl Endpoint { impl Endpoint {
pub fn worker_id(&self) -> Uuid { pub fn worker_id(&self) -> i64 {
Uuid::parse_str( i64::from_str_radix(
self.subject self.subject
.split(".") .split("-")
.last() .last()
.expect("invalid subject") .expect("invalid subject")
.to_string() .to_string()
.as_str(), .as_str(),
16,
) )
.expect("invalid uuid") .expect("invalid worker id")
} }
} }
...@@ -69,11 +68,11 @@ pub struct Service { ...@@ -69,11 +68,11 @@ pub struct Service {
pub struct SchedulingRequest { pub struct SchedulingRequest {
isl_tokens: usize, isl_tokens: usize,
overlap: OverlapScores, overlap: OverlapScores,
resp_tx: tokio::sync::oneshot::Sender<String>, resp_tx: tokio::sync::oneshot::Sender<i64>,
} }
impl SchedulingRequest { impl SchedulingRequest {
pub fn respond(self, worker_id: String) { pub fn respond(self, worker_id: i64) {
if self.resp_tx.send(worker_id).is_err() { if self.resp_tx.send(worker_id).is_err() {
tracing::trace!("failed to send response to requestor"); tracing::trace!("failed to send response to requestor");
} }
...@@ -174,7 +173,7 @@ impl KvScheduler { ...@@ -174,7 +173,7 @@ impl KvScheduler {
&self, &self,
overlap: OverlapScores, overlap: OverlapScores,
isl_tokens: usize, isl_tokens: usize,
) -> Result<String, KvSchedulerError> { ) -> Result<i64, KvSchedulerError> {
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel(); let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
let request = SchedulingRequest { let request = SchedulingRequest {
isl_tokens, isl_tokens,
...@@ -199,7 +198,7 @@ impl KvScheduler { ...@@ -199,7 +198,7 @@ impl KvScheduler {
pub fn select_worker( pub fn select_worker(
workers: &mut ProcessedEndpoints, workers: &mut ProcessedEndpoints,
request: &SchedulingRequest, request: &SchedulingRequest,
) -> Result<String, KvSchedulerError> { ) -> Result<i64, KvSchedulerError> {
// balance mode prioritizes balancing load across workers // balance mode prioritizes balancing load across workers
let balance_threshold: f64 = 0.1; let balance_threshold: f64 = 0.1;
let balance_mode = workers.load_std > balance_threshold * workers.load_avg; let balance_mode = workers.load_std > balance_threshold * workers.load_avg;
...@@ -227,6 +226,7 @@ pub fn select_worker( ...@@ -227,6 +226,7 @@ pub fn select_worker(
let kv_load_ratio = w.data.kv_active_blocks as f64 / w.data.kv_total_blocks as f64; let kv_load_ratio = w.data.kv_active_blocks as f64 / w.data.kv_total_blocks as f64;
let load_deviation = kv_load_ratio - workers.load_avg; let load_deviation = kv_load_ratio - workers.load_avg;
// [FIXME] multiple endpoints of the same worker cause out of bound error
let worker_id = workers.worker_ids[i]; let worker_id = workers.worker_ids[i];
let overlap_score = request.overlap.scores.get(&worker_id).map_or(0, |x| *x); let overlap_score = request.overlap.scores.get(&worker_id).map_or(0, |x| *x);
let overlap_score = overlap_score as usize * KV_BLOCK_SIZE; let overlap_score = overlap_score as usize * KV_BLOCK_SIZE;
...@@ -267,10 +267,10 @@ pub fn select_worker( ...@@ -267,10 +267,10 @@ pub fn select_worker(
Some(i) => { Some(i) => {
tracing::info!( tracing::info!(
"selected worker: {}; cost: {}", "selected worker: {}; cost: {}",
workers.endpoints[i].subject, workers.endpoints[i].worker_id(),
best_cost best_cost
); );
Ok(workers.endpoints[i].subject.clone()) Ok(workers.endpoints[i].worker_id())
} }
None => { None => {
tracing::debug!("all workers busy"); tracing::debug!("all workers busy");
......
...@@ -18,12 +18,11 @@ ...@@ -18,12 +18,11 @@
use std::collections::HashSet; use std::collections::HashSet;
use crate::kv_router::scheduler::Endpoint; use crate::kv_router::scheduler::Endpoint;
use uuid::Uuid;
#[derive(Debug, Default)] #[derive(Debug, Default)]
pub struct ProcessedEndpoints { pub struct ProcessedEndpoints {
pub endpoints: Vec<Endpoint>, pub endpoints: Vec<Endpoint>,
pub worker_ids: Vec<Uuid>, pub worker_ids: Vec<i64>,
pub load_avg: f64, pub load_avg: f64,
pub load_std: f64, pub load_std: f64,
} }
...@@ -43,8 +42,8 @@ impl ProcessedEndpoints { ...@@ -43,8 +42,8 @@ impl ProcessedEndpoints {
/ load_values.len() as f64; / load_values.len() as f64;
let load_std = variance.sqrt(); let load_std = variance.sqrt();
let worker_ids: HashSet<Uuid> = endpoints.iter().map(|x| x.worker_id()).collect(); let worker_ids: HashSet<i64> = endpoints.iter().map(|x| x.worker_id()).collect();
let worker_ids: Vec<Uuid> = worker_ids.into_iter().collect(); let worker_ids: Vec<i64> = worker_ids.into_iter().collect();
ProcessedEndpoints { ProcessedEndpoints {
endpoints, endpoints,
......
...@@ -34,7 +34,6 @@ use std::time::SystemTime; ...@@ -34,7 +34,6 @@ use std::time::SystemTime;
use super::TokenIdType; use super::TokenIdType;
pub mod kv_routing;
pub mod llm_backend; pub mod llm_backend;
pub mod postprocessor; pub mod postprocessor;
pub mod preprocessor; pub mod preprocessor;
......
...@@ -74,9 +74,6 @@ addopts = [ ...@@ -74,9 +74,6 @@ addopts = [
"--mypy", "--mypy",
"--ignore-glob=*model.py", "--ignore-glob=*model.py",
# FIXME: Get relative/generic blob paths to work here # FIXME: Get relative/generic blob paths to work here
# Ignore rust<->python bindings until python package is built/installed in environment
"--ignore-glob=/workspace/python-wheel/python/triton_distributed_rs/*.py",
"--ignore-glob=/workspace/python-wheel/python/triton_distributed_rs/*.pyi",
] ]
xfail_strict = true xfail_strict = true
log_cli_level = "INFO" log_cli_level = "INFO"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment