Commit 8588e33a authored by GuanLuo's avatar GuanLuo Committed by GitHub
Browse files

feat: Add KV publisher and receiver. Add KV aware routing example.


Signed-off-by: default avatarNeelay Shah <neelays@nvidia.com>
Co-authored-by: default avataraflowers <aflowers@nvidia.com>
Co-authored-by: default avatarRyan McCormick <rmccormick@nvidia.com>
Co-authored-by: default avatarhongkuanz <hongkuanz@nvidia.com>
Co-authored-by: default avatarNeelay Shah <neelays@nvidia.com>
parent d8aada0b
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import uuid
from argparse import Namespace
from enum import Enum
import uvloop
from common.protocol import Response, TokenizedRequest
from triton_distributed_rs import (
DistributedRuntime,
KvRouter,
triton_endpoint,
triton_worker,
)
from vllm.logger import logger as vllm_logger
class RoutingStrategy(Enum):
    """Strategy used to pick a vLLM worker for each request."""

    # KV-cache overlap (prefix matching) picks the best worker.
    PREFIX = "prefix"
    # Cycle through workers in order.
    ROUND_ROBIN = "round_robin"
    # Pick a worker uniformly at random.
    RANDOM = "random"

    def __str__(self) -> str:
        # Render as the CLI spelling ("prefix") rather than
        # "RoutingStrategy.PREFIX" in argparse help/choices output.
        return self.value
class Router:
    """
    Request handler for the generate endpoint.

    Routes tokenized requests to vLLM workers according to the configured
    ``RoutingStrategy``:

    * ``PREFIX`` — ask the KV-aware router for the worker with the best
      KV-cache overlap; falls back to random dispatch when none matches.
    * ``ROUND_ROBIN`` — cycle through workers via the workers client.
    * ``RANDOM`` — pick a worker uniformly at random.
    """

    def __init__(
        self,
        router,
        workers_client,
        routing_strategy: RoutingStrategy = RoutingStrategy.PREFIX,
    ):
        # router: KvRouter consulted only under PREFIX routing (the caller
        #   passes None for other strategies — see worker() in this file).
        # workers_client: client for the workers' "generate_from_tokens"
        #   endpoint; provides round_robin/random/direct dispatch.
        vllm_logger.info(
            f"Initializing KV Router with strategy: {routing_strategy.value}"
        )
        self.router = router
        self.workers_client = workers_client
        self.routing_strategy = routing_strategy

    @triton_endpoint(TokenizedRequest, Response)
    async def generate(self, request):
        # LoRA is not plumbed through the request schema yet; 0 is a
        # placeholder id — TODO confirm once LoRA support lands.
        lora_id = 0
        worker_id = ""
        if self.routing_strategy == RoutingStrategy.PREFIX:
            try:
                worker_id = await self.router.schedule(request.tokens, lora_id)
            except Exception as e:
                vllm_logger.info(f"{e}")
                if "No worker found" in str(e):
                    # No worker has enough KV overlap; keep worker_id empty
                    # so the random fallback below is taken.
                    worker_id = ""
                else:
                    vllm_logger.exception(f"Error during worker selection: {e}")
        vllm_logger.info(f"Scheduling to worker_id: {worker_id}")
        if self.routing_strategy == RoutingStrategy.ROUND_ROBIN:
            engine_generator = await self.workers_client.round_robin(
                request.model_dump_json()
            )
        elif self.routing_strategy == RoutingStrategy.RANDOM or worker_id == "":
            engine_generator = await self.workers_client.random(
                request.model_dump_json()
            )
        else:
            # extract back lease_id
            # worker_id is a UUID string built from the worker's integer
            # lease id (see the worker file); convert back for direct routing.
            engine_generator = await self.workers_client.direct(
                request.model_dump_json(), uuid.UUID(worker_id).int
            )
        async for resp in engine_generator:
            # Responses may arrive wrapped in an object exposing .data();
            # unwrap before yielding downstream.
            resp = resp.data() if hasattr(resp, "data") else resp
            yield resp
@triton_worker()
async def worker(runtime: DistributedRuntime, args: Namespace):
    """
    Start the KV-aware router service.

    Waits until at least ``args.min_workers`` vLLM workers serve the
    ``generate_from_tokens`` endpoint, then serves the router's ``generate``
    endpoint with the requested routing strategy.
    """
    workers_client = (
        await runtime.namespace("triton-init")
        .component("vllm")
        .endpoint("generate_from_tokens")
        .client()
    )
    vllm_logger.info("Waiting for workers to be ready")
    await workers_client.wait_for_endpoints()
    # Workers register asynchronously; poll until the minimum count is met.
    while len(workers_client.endpoint_ids()) < args.min_workers:
        vllm_logger.info(
            f"Waiting for more workers... Current: {len(workers_client.endpoint_ids())}, Required: {args.min_workers}"
        )
        await asyncio.sleep(5)
    vllm_logger.info(
        f"Required number of workers ({args.min_workers}) are ready:\n"
        + "\n".join(f"id: {id}" for id in workers_client.endpoint_ids())
    )
    # TODO Router is a fixed namespace separate from the others
    # The KV listener receives KV-cache events published by workers; the
    # component name must match the model being served — TODO confirm it
    # should come from configuration instead of being hard-coded.
    kv_listener = runtime.namespace("router").component(
        "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
    )
    await kv_listener.create_service()
    router_component = runtime.namespace("triton-init").component("router")
    await router_component.create_service()
    # The KvRouter is only needed for PREFIX routing; other strategies
    # never consult it, so leave it as None.
    router = None
    if args.routing_strategy == RoutingStrategy.PREFIX:
        router = KvRouter(runtime, kv_listener)
    endpoint = router_component.endpoint("generate")
    await endpoint.serve_endpoint(
        Router(router, workers_client, args.routing_strategy).generate
    )
if __name__ == "__main__":
    # uvloop replaces the default asyncio event loop for performance.
    uvloop.install()
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--routing-strategy",
        # The enum values ("prefix", "round_robin", "random") double as the
        # accepted CLI spellings via RoutingStrategy("<value>").
        type=RoutingStrategy,
        default=RoutingStrategy.PREFIX,
        choices=list(RoutingStrategy),
        help="Routing strategy to use",
    )
    parser.add_argument(
        "--min-workers",
        type=int,
        default=1,
        help="Minimum number of workers required before proceeding",
    )
    args = parser.parse_args()
    asyncio.run(worker(args))
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import os
import uuid
from typing import Optional
import uvloop
import vllm
from common.parser import parse_vllm_args
from common.protocol import Request, Response, TokenizedRequest
from triton_distributed_rs import (
DistributedRuntime,
KvRouter,
triton_endpoint,
triton_worker,
)
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.inputs import TokensPrompt
from vllm.logger import logger as vllm_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
# Log the KV C-API shared-library path up front. os.environ[...] raises
# KeyError at import time when VLLM_KV_CAPI_PATH is unset — presumably a
# deliberate fail-fast; TODO confirm a friendlier error is not wanted here.
vllm_logger.info(f"VLLM_KV_CAPI_PATH: {os.environ['VLLM_KV_CAPI_PATH']}")
class VllmEngine:
    """
    Request handler wrapping a vLLM ``AsyncLLMEngine``.

    Exposes three endpoints: ``generate_from_tokens`` (pre-tokenized input),
    ``generate_from_prompt`` (raw prompt), and ``preprocess`` (tokenize
    locally, then forward to the router for KV-aware dispatch).
    """

    def __init__(self, engine_args: AsyncEngineArgs, router: KvRouter):
        self.engine = vllm.AsyncLLMEngine.from_engine_args(engine_args)
        self.router = router
        # Set lazily in init(); fetching the tokenizer is async.
        self.tokenizer: Optional[AnyTokenizer] = None

    # Pattern to initialize async object as python __init__ is not async
    async def init(self):
        self.tokenizer = await self.engine.get_tokenizer()
        return self

    @triton_endpoint(TokenizedRequest, Response)
    async def generate_from_tokens(self, request):
        """Stream generated text for an already-tokenized request."""
        tokens_prompt = TokensPrompt(prompt_token_ids=request.tokens)
        sampling_params = vllm.SamplingParams(**request.sampling_params)
        request_id = str(uuid.uuid4())
        async for response in self.engine.generate(
            tokens_prompt, sampling_params, request_id
        ):
            # NOTE(review): yields the first output's text each step —
            # confirm downstream expects cumulative text, not deltas.
            yield response.outputs[0].text

    @triton_endpoint(Request, Response)
    async def generate_from_prompt(self, request):
        """Stream generated text for a raw (untokenized) prompt."""
        sampling_params = vllm.SamplingParams(**request.sampling_params)
        request_id = str(uuid.uuid4())
        async for response in self.engine.generate(
            request.prompt, sampling_params, request_id
        ):
            yield response.outputs[0].text

    @triton_endpoint(Request, Response)
    async def preprocess(self, request):
        """Tokenize the prompt locally and forward it to the router."""
        if self.tokenizer is None:
            raise RuntimeError("Tokenizer not initialized. Must run init().")
        tokens = self.tokenizer.encode(request.prompt)
        engine_generator = await self.router.generate(
            TokenizedRequest(tokens=tokens, **request.model_dump()).model_dump_json()
        )
        async for resp in engine_generator:
            yield resp.data()
@triton_worker()
async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
    """
    Instantiate a `backend` component and serve the `generate` endpoint
    A `Component` can serve multiple endpoints
    """
    worker_component = runtime.namespace("triton-init").component("vllm")
    await worker_component.create_service()
    preprocess_component = runtime.namespace("triton-init").component("preprocess")
    await preprocess_component.create_service()
    # Client for the router's generate endpoint; used by preprocess() to
    # delegate KV-aware worker selection.
    router_client = (
        await runtime.namespace("triton-init")
        .component("router")
        .endpoint("generate")
        .client()
    )
    worker_from_tokens_endpoint = worker_component.endpoint("generate_from_tokens")
    worker_from_prompt_endpoint = worker_component.endpoint("generate")
    preprocess_endpoint = preprocess_component.endpoint("generate")
    # TODO Hack until we unify lease_id and worker_id
    # The KV publisher (C API) identifies this worker by UUID, derived here
    # from the endpoint's integer lease id and exported via the environment.
    VLLM_WORKER_ID = uuid.UUID(int=worker_from_tokens_endpoint.lease_id())
    os.environ["VLLM_WORKER_ID"] = str(VLLM_WORKER_ID)
    vllm_logger.info(f"Generate endpoint ID: {VLLM_WORKER_ID}")
    vllm_engine = VllmEngine(engine_args, router_client)
    vllm_engine = await vllm_engine.init()
    # Serve all three endpoints concurrently until shutdown.
    await asyncio.gather(
        worker_from_tokens_endpoint.serve_endpoint(vllm_engine.generate_from_tokens),
        worker_from_prompt_endpoint.serve_endpoint(vllm_engine.generate_from_prompt),
        preprocess_endpoint.serve_endpoint(vllm_engine.preprocess),
    )
if __name__ == "__main__":
    # uvloop replaces the default asyncio event loop for performance.
    uvloop.install()
    engine_args = parse_vllm_args()
    asyncio.run(worker(engine_args))
...@@ -1637,6 +1637,12 @@ version = "0.3.31" ...@@ -1637,6 +1637,12 @@ version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988"
[[package]]
name = "futures-timer"
version = "3.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24"
[[package]] [[package]]
name = "futures-util" name = "futures-util"
version = "0.3.31" version = "0.3.31"
...@@ -2472,7 +2478,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" ...@@ -2472,7 +2478,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34" checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34"
dependencies = [ dependencies = [
"cfg-if 1.0.0", "cfg-if 1.0.0",
"windows-targets 0.48.5", "windows-targets 0.52.6",
] ]
[[package]] [[package]]
...@@ -2491,6 +2497,27 @@ dependencies = [ ...@@ -2491,6 +2497,27 @@ dependencies = [
"libc", "libc",
] ]
[[package]]
name = "libtriton-llm"
version = "0.1.1"
dependencies = [
"anyhow",
"async-once-cell",
"cbindgen",
"futures",
"libc",
"once_cell",
"serde",
"serde_json",
"tokio",
"tokio-stream",
"tracing",
"tracing-subscriber",
"triton-distributed",
"triton-llm",
"uuid 1.13.1",
]
[[package]] [[package]]
name = "linked-hash-map" name = "linked-hash-map"
version = "0.5.6" version = "0.5.6"
...@@ -3152,9 +3179,9 @@ dependencies = [ ...@@ -3152,9 +3179,9 @@ dependencies = [
[[package]] [[package]]
name = "once_cell" name = "once_cell"
version = "1.20.2" version = "1.20.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" checksum = "945462a4b81e43c4e3ba96bd7b49d834c6f61198356aa858733bc4acf3cbe62e"
[[package]] [[package]]
name = "onig" name = "onig"
...@@ -3909,6 +3936,12 @@ version = "0.8.5" ...@@ -3909,6 +3936,12 @@ version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
[[package]]
name = "relative-path"
version = "1.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba39f3699c378cd8970968dcbff9c43159ea4cfbd88d43c00b22f2ef10a435d2"
[[package]] [[package]]
name = "reqwest" name = "reqwest"
version = "0.12.12" version = "0.12.12"
...@@ -3976,6 +4009,35 @@ dependencies = [ ...@@ -3976,6 +4009,35 @@ dependencies = [
"windows-sys 0.52.0", "windows-sys 0.52.0",
] ]
[[package]]
name = "rstest"
version = "0.18.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "97eeab2f3c0a199bc4be135c36c924b6590b88c377d416494288c14f2db30199"
dependencies = [
"futures",
"futures-timer",
"rstest_macros",
"rustc_version",
]
[[package]]
name = "rstest_macros"
version = "0.18.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d428f8247852f894ee1be110b375111b586d4fa431f6c46e64ba5a0dcccbe605"
dependencies = [
"cfg-if 1.0.0",
"glob",
"proc-macro2",
"quote",
"regex",
"relative-path",
"rustc_version",
"syn 2.0.98",
"unicode-ident",
]
[[package]] [[package]]
name = "rustc-demangle" name = "rustc-demangle"
version = "0.1.24" version = "0.1.24"
...@@ -5233,6 +5295,7 @@ dependencies = [ ...@@ -5233,6 +5295,7 @@ dependencies = [
"proptest", "proptest",
"regex", "regex",
"reqwest", "reqwest",
"rstest",
"serde", "serde",
"serde_json", "serde_json",
"thiserror 2.0.11", "thiserror 2.0.11",
...@@ -5244,6 +5307,7 @@ dependencies = [ ...@@ -5244,6 +5307,7 @@ dependencies = [
"unicode-segmentation", "unicode-segmentation",
"uuid 1.13.1", "uuid 1.13.1",
"validator", "validator",
"xxhash-rust",
] ]
[[package]] [[package]]
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
[workspace] [workspace]
members = [ members = [
"triton-llm", "triton-llm",
"libtriton-llm",
] ]
resolver = "2" resolver = "2"
...@@ -47,3 +48,4 @@ tokio-util = { version = "0.7", features = ["codec", "net"] } ...@@ -47,3 +48,4 @@ tokio-util = { version = "0.7", features = ["codec", "net"] }
tracing = { version = "0.1" } tracing = { version = "0.1" }
validator = { version = "0.20.0", features = ["derive"] } validator = { version = "0.20.0", features = ["derive"] }
uuid = { version = "1", features = ["v4", "serde"] } uuid = { version = "1", features = ["v4", "serde"] }
xxhash-rust = { version = "0.8", features = ["xxh3", "const_xxh3"] }
This diff is collapsed.
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
[package]
name = "libtriton-llm"
version = "0.1.1"
edition = "2021"
authors = ["NVIDIA"]
license = "Apache-2.0"
homepage = "https://github.com/triton-inference-server/triton_distributed"
repository = "https://github.com/triton-inference-server/triton_distributed"
[lib]
name = "triton_llm_capi"
crate-type = ["cdylib"]
[build-dependencies]
cbindgen = "0.27"
[dependencies]
triton-llm = { path = "../triton-llm" }
triton-distributed = { workspace = true }
anyhow = { version = "1" }
futures = "0.3"
once_cell = "1"
serde = "1"
serde_json = "1.0.138"
tokio = { version = "1", features = ["full"] }
tokio-stream = "0"
tracing = "0"
libc = "0.2"
uuid = { version = "1", features = ["v4", "serde"] }
async-once-cell = "0.5.4"
tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] }
\ No newline at end of file
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::env;
use std::path::Path;
/// cbindgen build script: regenerate the public C header from this crate's
/// Rust API on every build, writing it under `include/nvidia/triton_llm/`.
fn main() {
    let crate_dir = env::var("CARGO_MANIFEST_DIR").unwrap();
    let mut header_path = Path::new(&crate_dir).to_path_buf();
    for component in ["include", "nvidia", "triton_llm", "llm_engine.h"] {
        header_path.push(component);
    }
    cbindgen::generate(crate_dir)
        .expect("Unable to generate bindings")
        .write_to_file(header_path);
}
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
language = "C++"
cpp_compat = true
include_guard = "__NVIDIA_TRITON_LLM_API__"
[enum]
rename_variants = "none"
prefix_with_name = false
enum_class = false
[export]
include = ["TritonLlmResult", "triton_llm_init", "triton_llm_shutdown"]
[export.rename]
"TritonLlmResult" = "triton_llm_result_t"
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_once_cell::OnceCell as AsyncOnceCell;
use libc::c_char;
use once_cell::sync::OnceCell;
use std::ffi::CStr;
use uuid::Uuid;
use std::sync::atomic::{AtomicU32, Ordering};
use tracing as log;
use triton_distributed::{DistributedRuntime, Worker};
use triton_llm::kv_router::{
indexer::compute_block_hash_for_seq, protocols::*, publisher::KvPublisher,
};
static WK: OnceCell<Worker> = OnceCell::new();
static DRT: AsyncOnceCell<DistributedRuntime> = AsyncOnceCell::new();
// [FIXME] shouldn't the publisher be instance passing between API calls?
static KV_PUB: OnceCell<KvPublisher> = OnceCell::new();
/// Install a global `tracing` subscriber whose filter comes from the
/// `RUST_LOG` environment variable (e.g. `os.environ["RUST_LOG"] = "debug"`
/// on the Python side).
///
/// Safe to call more than once: `set_global_default` fails if a subscriber
/// is already installed (e.g. `triton_llm_init` called twice), and the old
/// `.expect(...)` would then panic — unwinding across the FFI boundary is
/// undefined behavior. Log and continue instead.
fn initialize_tracing() {
    let subscriber = tracing_subscriber::fmt()
        .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
        .finish();
    match tracing::subscriber::set_global_default(subscriber) {
        Ok(()) => log::debug!("Tracing initialized"),
        Err(e) => eprintln!("Tracing subscriber already set, keeping it: {:?}", e),
    }
}
/// Status code returned by every `triton_llm_*` C-API entry point.
#[repr(u32)]
pub enum TritonLlmResult {
    /// The call completed successfully.
    OK = 0,
    /// The call failed; details are written to stderr.
    ERR = 1,
}
/// Initialize the worker runtime, the distributed runtime, and the
/// process-wide KV publisher. Must be called before any
/// `triton_kv_event_publish_*` call. Returns `ERR` (details on stderr)
/// on any failure; repeated calls reuse the already-initialized state.
///
/// # Safety
/// the model_name_c_str and worker_id_c_str are passed as pointers to C strings;
/// both must be valid, NUL-terminated, and `worker_id_c_str` must parse as a UUID.
#[no_mangle]
pub unsafe extern "C" fn triton_llm_init(
    model_name_c_str: *const c_char,
    worker_id_c_str: *const c_char,
) -> TritonLlmResult {
    initialize_tracing();
    // Worker is process-wide; a second init call reuses the existing one.
    let wk = match WK.get_or_try_init(Worker::from_settings) {
        Ok(wk) => wk.clone(),
        Err(e) => {
            eprintln!("Failed to initialize runtime: {:?}", e);
            return TritonLlmResult::ERR;
        }
    };
    let rt = wk.runtime();
    let secondary = rt.secondary().clone();
    // Bring up the distributed runtime on the secondary (blocking) executor.
    let result = secondary.block_on(async {
        // Initialize the distributed runtime
        match DRT
            .get_or_try_init(async { DistributedRuntime::from_settings(rt.clone()).await })
            .await
        {
            Ok(_) => Ok(()),
            Err(e) => {
                eprintln!("Failed to initialize distributed runtime: {:?}", e);
                Err(TritonLlmResult::ERR)
            }
        }
    });
    let model_name = match unsafe { CStr::from_ptr(model_name_c_str) }.to_str() {
        Ok(s) => s.to_string(),
        Err(e) => {
            eprintln!("Failed to convert C string to Rust string: {:?}", e);
            return TritonLlmResult::ERR;
        }
    };
    let worker_id_str = match unsafe { CStr::from_ptr(worker_id_c_str) }.to_str() {
        Ok(s) => s,
        Err(e) => {
            eprintln!("Failed to convert C string to Rust string: {:?}", e);
            return TritonLlmResult::ERR;
        }
    };
    let worker_id_uuid = match Uuid::parse_str(worker_id_str) {
        Ok(uuid) => uuid,
        Err(e) => {
            eprintln!("Failed to parse worker_id as UUID: {:?}", e);
            return TritonLlmResult::ERR;
        }
    };
    match result {
        Ok(_) => match KV_PUB
            .get_or_try_init(move || triton_create_kv_publisher(model_name, worker_id_uuid))
        {
            Ok(_) => TritonLlmResult::OK,
            Err(e) => {
                // Fixed: this message was a copy-paste of the runtime-init
                // error; report the step that actually failed.
                eprintln!("Failed to create KV publisher: {:?}", e);
                TritonLlmResult::ERR
            }
        },
        Err(e) => e,
    }
}
/// Shut down the runtime created by `triton_llm_init`.
/// Returns `ERR` if the runtime was never initialized.
#[no_mangle]
pub extern "C" fn triton_llm_shutdown() -> TritonLlmResult {
    if let Some(wk) = WK.get() {
        wk.runtime().shutdown();
        TritonLlmResult::OK
    } else {
        eprintln!("Runtime not initialized");
        TritonLlmResult::ERR
    }
}
/// NOTE(review): placeholder — currently a no-op that always reports
/// success; see the commented-out load-publisher sketch at the bottom of
/// this file for the intended design.
#[no_mangle]
pub extern "C" fn triton_llm_load_publisher_create() -> TritonLlmResult {
    TritonLlmResult::OK
}
// instantiate a kv publisher
// this will bring up the task to publish and the channels to await publishing events
// the [`triton_kv_publish_store_event`] call will use a handle to the publisher to send events
// store and the [`triton_kv_event_create_removed`] will create remove events
// these call mus be driving by external c++ threads that are consuming the kv events from the
// c++ executor api
fn triton_create_kv_publisher(
model_name: String,
worker_id: Uuid,
) -> Result<KvPublisher, anyhow::Error> {
log::info!("Creating KV Publisher for model: {}", model_name);
match DRT
.get()
.ok_or(anyhow::Error::msg("Could not get Distributed Runtime"))
{
Ok(drt) => {
let backend = drt.namespace("router")?.component(model_name)?;
KvPublisher::new(drt.clone(), backend, worker_id)
}
Err(e) => Err(e),
}
}
/// Build one `KvCacheStoredBlockData` entry from a C-side block.
///
/// `block_hash` is the engine's own identifier for the block; `tokens_hash`
/// is recomputed here from the token ids so the router can index prefixes.
/// `_lora_id` is accepted but currently unused — presumably reserved for
/// LoRA-aware hashing; TODO confirm.
///
/// Internal safety contract: `token_ids` must point to at least
/// `num_tokens` readable `u32`s — guaranteed by the C-API callers in this
/// file.
fn kv_event_create_stored_block_from_parts(
    block_hash: u64,
    token_ids: *const u32,
    num_tokens: usize,
    _lora_id: u64,
) -> KvCacheStoredBlockData {
    let tokens_hash =
        compute_block_hash_for_seq(unsafe { std::slice::from_raw_parts(token_ids, num_tokens) })[0];
    KvCacheStoredBlockData {
        block_hash: ExternalSequenceBlockHash(block_hash),
        tokens_hash,
    }
}
static WARN_COUNT: AtomicU32 = AtomicU32::new(0);
fn kv_event_create_stored_from_parts(
event_id: u64,
token_ids: *const u32,
num_block_tokens: *const usize,
block_ids: *const u64,
num_blocks: usize,
parent_hash: Option<u64>,
lora_id: u64,
) -> KvCacheEvent {
let mut blocks: Vec<KvCacheStoredBlockData> = Vec::new();
let mut token_offset: usize = 0;
for block_idx in 0..num_blocks {
let block_hash = unsafe { *block_ids.offset(block_idx.try_into().unwrap()) };
let tokens = unsafe { token_ids.offset(token_offset.try_into().unwrap()) };
let num_toks = unsafe { *num_block_tokens.offset(block_idx.try_into().unwrap()) };
// compute hash only apply to full block (KV_BLOCK_SIZE token)
if num_toks != 64 {
if WARN_COUNT.fetch_update(
Ordering::SeqCst,
Ordering::SeqCst,
|c| if c < 3 { Some(c + 1) } else { None }).is_ok() {
log::warn!("Block size must be 64 tokens to be published. Block size is: {}", num_toks);
}
break;
}
token_offset += num_toks;
blocks.push(kv_event_create_stored_block_from_parts(
block_hash, tokens, num_toks, lora_id,
));
}
KvCacheEvent {
data: KvCacheEventData::Stored(KvCacheStoreData {
blocks,
parent_hash: parent_hash.map(ExternalSequenceBlockHash),
}),
event_id,
}
}
/// Build a `Removed` KV-cache event from the raw C-side array of block
/// hashes.
///
/// Internal safety contract: `block_ids` must point to at least
/// `num_blocks` readable `u64`s — guaranteed by the C-API caller.
fn kv_event_create_removed_from_parts(
    event_id: u64,
    block_ids: *const u64,
    num_blocks: usize,
) -> KvCacheEvent {
    // Map the borrowed slice directly; the previous `.to_vec()` allocated
    // an unnecessary intermediate copy before mapping.
    let block_hashes: Vec<ExternalSequenceBlockHash> =
        unsafe { std::slice::from_raw_parts(block_ids, num_blocks) }
            .iter()
            .map(|&v| ExternalSequenceBlockHash(v))
            .collect();
    KvCacheEvent {
        event_id,
        data: KvCacheEventData::Removed(KvCacheRemoveData { block_hashes }),
    }
}
/// Publish a "stored" KV-cache event for a batch of blocks.
///
/// # Safety
/// parent_hash is passed as pointer to indicate whether the blocks
/// has a parent hash or not. nullptr is used to represent no parent hash.
/// The array pointers must satisfy the layout described in
/// [`kv_event_create_stored_from_parts`].
#[no_mangle]
pub unsafe extern "C" fn triton_kv_event_publish_stored(
    event_id: u64,
    token_ids: *const u32,
    num_block_tokens: *const usize,
    block_ids: *const u64,
    num_blocks: usize,
    parent_hash: *const u64,
    lora_id: u64,
) -> TritonLlmResult {
    // Never unwrap across the FFI boundary: a panic here would unwind into
    // C, which is undefined behavior. Report an error code instead.
    let publisher = match KV_PUB.get() {
        Some(publisher) => publisher,
        None => {
            eprintln!("KV publisher not initialized; call triton_llm_init first");
            return TritonLlmResult::ERR;
        }
    };
    let parent_hash = if parent_hash.is_null() {
        None
    } else {
        Some(unsafe { *parent_hash })
    };
    let event = kv_event_create_stored_from_parts(
        event_id,
        token_ids,
        num_block_tokens,
        block_ids,
        num_blocks,
        parent_hash,
        lora_id,
    );
    match publisher.publish(event) {
        Ok(_) => TritonLlmResult::OK,
        Err(e) => {
            eprintln!("Error publishing stored kv event {:?}", e);
            TritonLlmResult::ERR
        }
    }
}
/// Publish a "removed" KV-cache event for a batch of block hashes.
#[no_mangle]
pub extern "C" fn triton_kv_event_publish_removed(
    event_id: u64,
    block_ids: *const u64,
    num_blocks: usize,
) -> TritonLlmResult {
    // Never unwrap across the FFI boundary: a panic here would unwind into
    // C, which is undefined behavior. Report an error code instead.
    let publisher = match KV_PUB.get() {
        Some(publisher) => publisher,
        None => {
            eprintln!("KV publisher not initialized; call triton_llm_init first");
            return TritonLlmResult::ERR;
        }
    };
    let event = kv_event_create_removed_from_parts(event_id, block_ids, num_blocks);
    match publisher.publish(event) {
        Ok(_) => TritonLlmResult::OK,
        Err(e) => {
            eprintln!("Error publishing removed kv event {:?}", e);
            TritonLlmResult::ERR
        }
    }
}
// #[no_mangle]
// pub extern "C" fn triton_kv_publish_store_event(
// event_id: u64,
// token_ids: *const u32,
// num_tokens: usize,
// lora_id: u64,
// ) -> TritonLlmResult {
// // if event.is_null() || token_ids.is_null() {
// // return tritonKvErrorType::INVALID_TOKEN_IDS;
// // }
// // let tokens = unsafe { std::slice::from_raw_parts(token_ids, num_tokens) }.to_vec();
// // let new_event = Box::new(KvCacheStoreData {
// // event_id,
// // lora_id,
// // token_ids: tokens,
// // block_hashes: Vec::new(),
// // });
// // unsafe { *event = Box::into_raw(new_event) };
// TritonLlmResult::OK
// }
// #[no_mangle]
// pub extern "C" fn triton_kv_event_create_removed(
// event_id: u64,
// block_hashes: *const u64,
// num_hashes: usize,
// ) -> TritonLlmResult {
// // if event.is_null() || block_hashes.is_null() {
// // return -1;
// // }
// // let hashes = unsafe { std::slice::from_raw_parts(block_hashes, num_hashes) }.to_vec();
// // let new_event = Box::new(KvCacheRemoveData {
// // event_id,
// // lora_id: 0,
// // token_ids: Vec::new(),
// // block_hashes: hashes,
// // });
// // unsafe { *event = Box::into_raw(new_event) };
// // 0
// TritonLlmResult::OK
// }
// /// create load publisher object and return a handle
// /// load publisher will instantiate the nats service and tie its stats handler to
// /// a watch channel receiver. the watch channel sender will be attach to the
// /// handle and calls to [`triton_load_stats_publish`] issue the stats to the watch t
// pub extern "C" fn triton_load_publisher_create() -> *mut LoadPublisher {
// // let publisher = Box::new(LoadPublisher::new());
// // Box::into_raw(publisher)
// }
// pub extern "C" fn triton_load_stats_publish(
// publisher: *mut LoadPublisher,
// active_slots: u64,
// total_slots: u64,
// active_kv: u64,
// total_kv: u64,
// ) {
// // let publisher = unsafe { &mut *publisher };
// }
...@@ -46,6 +46,7 @@ tokio-util = { workspace = true } ...@@ -46,6 +46,7 @@ tokio-util = { workspace = true }
tracing = { workspace = true } tracing = { workspace = true }
validator = { workspace = true } validator = { workspace = true }
uuid = { workspace = true } uuid = { workspace = true }
xxhash-rust = { workspace = true }
# protocols # protocols
chrono = { version = "0.4" } chrono = { version = "0.4" }
...@@ -66,3 +67,4 @@ mistralrs = { git = "https://github.com/EricLBuehler/mistral.rs.git", rev = "5e6 ...@@ -66,3 +67,4 @@ mistralrs = { git = "https://github.com/EricLBuehler/mistral.rs.git", rev = "5e6
insta = { version = "1.41", features = ["glob", "json", "redactions"]} insta = { version = "1.41", features = ["glob", "json", "redactions"]}
proptest = "1.5.0" proptest = "1.5.0"
reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "rustls-tls"] } reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "rustls-tls"] }
rstest = "0.18.2"
\ No newline at end of file
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use anyhow::Result;
use futures::stream::StreamExt;
use std::{sync::Arc, time::Duration};
use tokio_util::sync::CancellationToken;
use tracing as log;
use triton_distributed::{component::Component, DistributedRuntime};
pub mod indexer;
pub mod protocols;
pub mod publisher;
// [WIP] enable service_builder() through worker for metrics reporting
// pub mod worker;
mod scheduler;
mod scoring;
use crate::kv_router::{
indexer::{KvIndexer, KvIndexerInterface, RouterEvent},
protocols::KV_BLOCK_SIZE,
scheduler::{Endpoint, KvScheduler, Service},
scoring::ProcessedEndpoints,
};
// this should be discovered from the backend
pub const KV_EVENT_SUBJECT: &str = "kv_events";
/// KV-aware router: consumes KV-cache events published by workers and uses
/// the resulting prefix index (plus periodically collected endpoint stats)
/// to pick a worker per request.
pub struct KvRouter {
    // properties of request plane
    // maybe rolled up into the generic object or not
    service_name: String,
    // Cancels the background endpoint-collection and indexing tasks.
    cancellation_token: CancellationToken,
    // Scores/schedules workers from processed endpoint stats.
    scheduler: KvScheduler,
    // Index of KV blocks per worker, fed by the NATS event subscription.
    indexer: KvIndexer,
}
impl KvRouter {
    /// Construct a router from a distributed runtime and the backend
    /// component whose KV events should be indexed.
    pub async fn from_runtime(
        runtime: DistributedRuntime,
        backend: Component,
    ) -> Result<Arc<Self>> {
        let nats_client = runtime.nats_client();
        let service_name = backend.service_name();
        let kv_subject = backend.event_subject(KV_EVENT_SUBJECT);
        log::info!("Component Service Name {}", service_name);
        log::info!("KV Subject {}", kv_subject);
        Self::new(nats_client, service_name, kv_subject).await
    }

    /// Start the router: spawns the endpoint-collection loop and the NATS
    /// subscription task that feeds KV events into the indexer.
    pub async fn new(
        nats_client: triton_distributed::transports::nats::Client,
        service_name: String,
        kv_subject: String,
    ) -> Result<Arc<Self>> {
        let cancellation_token = CancellationToken::new();
        let (ep_tx, ep_rx) = tokio::sync::mpsc::channel(128);
        // Periodically polls NATS for endpoint stats and forwards processed
        // snapshots to the scheduler via ep_tx.
        tokio::spawn(collect_endpoints(
            nats_client.clone(),
            service_name.clone(),
            ep_tx,
            cancellation_token.clone(),
        ));
        let indexer = KvIndexer::new(cancellation_token.clone());
        let scheduler = KvScheduler::start(ep_rx).await?;
        log::debug!("subscribing to kv events: {}", kv_subject);
        let mut kv_events_rx = nats_client.client().subscribe(kv_subject).await?;
        let kv_events_tx = indexer.event_sender();
        tokio::spawn(async move {
            while let Some(event) = kv_events_rx.next().await {
                // NOTE(review): a malformed payload panics this task via
                // unwrap — confirm publishers always send valid RouterEvent
                // JSON, or harden this path.
                let event: RouterEvent = serde_json::from_slice(&event.payload).unwrap();
                log::debug!("received kv event: {:?}", event);
                if let Err(e) = kv_events_tx.send(event).await {
                    log::trace!("failed to send kv event to indexer; shutting down: {:?}", e);
                }
            }
        });
        Ok(Arc::new(Self {
            service_name,
            cancellation_token,
            scheduler,
            indexer,
        }))
    }

    /// Token that stops the background tasks spawned by [`Self::new`].
    pub fn cancellation_token(&self) -> CancellationToken {
        self.cancellation_token.clone()
    }

    pub fn service_name(&self) -> &str {
        &self.service_name
    }

    // [TODO] indexer needs to take 'lora_id' as parameter
    /// Pick a worker for `token_ids` and return its subject/id string.
    ///
    /// Errs with a message containing "No worker found" when no worker's
    /// KV overlap covers at least half the input — callers match on that
    /// message to fall back to random routing.
    pub async fn schedule(&self, token_ids: &Vec<u32>, _lora_id: u64) -> Result<String> {
        // Extracting part of the code in KvRouter::generate() for only
        // the decision making part, routing is done by the caller
        let isl_tokens = token_ids.len();
        let overlap_scores = self
            .indexer
            .find_matches_for_request(token_ids.as_slice())
            .await?;
        log::debug!("KV router overlap_scores: {:?}", overlap_scores);
        // [FIXME] Python binding results in "endpoint subscriber shutdown" error,
        // need to investigate whether it happens in pure rust as well and then
        // root cause it. Before that, not doing intelligent scheduling for rapid
        // development..
        // [FIXME] also need to fix that scheduler returns worker subject which is not
        // the same as worker id (uuid). Seems like it adds additional annotation on top of uuid.
        // Need to double check
        // 'worker_subject' should be the same as worker id used for direct routing
        // let worker_subject = self.scheduler.schedule(overlap_scores, isl_tokens).await?;
        let mut selected_worker_subject = Option::<String>::None;
        // Scores count matched blocks; convert to token counts and require
        // >= 50% coverage of the input. Later qualifying entries overwrite
        // earlier ones, so the last qualifying worker wins.
        for (worker_subject, overlap_score) in &overlap_scores.scores {
            if ((*overlap_score as usize * KV_BLOCK_SIZE) as f64 / isl_tokens as f64) >= 0.5 {
                selected_worker_subject = Some(worker_subject.to_string());
            }
        }
        match selected_worker_subject {
            None => Err(anyhow::anyhow!("No worker found")),
            Some(worker_subject) => Ok(worker_subject),
        }
    }
}
/// Background task: once per second, poll NATS for the service's endpoint
/// stats, convert them into a [`ProcessedEndpoints`] snapshot, and forward
/// it to the scheduler. Exits when `cancel` fires or the receiver closes.
async fn collect_endpoints(
    nats_client: triton_distributed::transports::nats::Client,
    service_name: String,
    ep_tx: tokio::sync::mpsc::Sender<ProcessedEndpoints>,
    cancel: CancellationToken,
) {
    loop {
        tokio::select! {
            _ = cancel.cancelled() => {
                log::debug!("cancellation token triggered");
                break;
            }
            _ = tokio::time::sleep(Duration::from_secs(1)) => {
                log::trace!("collecting endpoints for service: {}", service_name);
            }
        }
        // A transient NATS error should not panic the collector task
        // (the original unwrap did); log it and retry on the next tick.
        let values = match nats_client
            .get_endpoints(&service_name, Duration::from_secs(1))
            .await
        {
            Ok(values) => values,
            Err(e) => {
                log::warn!("failed to collect endpoints: {:?}; retrying", e);
                continue;
            }
        };
        // [FIXME] Endpoint is parsed from nats stats handler which may not include 'data' field
        // if the service hasn't registered the handler.
        // Another option is to make sure the router is configured properly that
        // it listens to the right subject (where other publisher has stats).
        let services: Vec<Service> = values
            .into_iter()
            .filter(|v| !v.is_empty())
            // Parse each payload once; skip malformed entries instead of
            // panicking on a single bad stats record.
            .filter_map(|v| match serde_json::from_slice::<Service>(&v) {
                Ok(service) => Some(service),
                Err(e) => {
                    log::warn!("failed to parse service stats payload: {:?}", e);
                    None
                }
            })
            .collect();
        let endpoints: Vec<Endpoint> = services.into_iter().flat_map(|s| s.endpoints).collect();
        log::trace!(
            "found {} endpoints for service: {}",
            endpoints.len(),
            service_name
        );
        let processed = ProcessedEndpoints::new(endpoints);
        if ep_tx.send(processed).await.is_err() {
            log::trace!("failed to send processed endpoints; shutting down");
            break;
        }
    }
}
// NOTE: a separate file's diff was collapsed in the review view at this point.
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use serde::{Deserialize, Serialize};
// Currently hard-coding the block size to be 64 tokens, and
// assuming the LLM framework aligns with this size.
// For performance reasons, the KV publisher and subscriber exchange hash
// values of the tokens rather than the tokens themselves, so the block size
// must be consistent for the computed hash values to match on both sides.
pub const KV_BLOCK_SIZE: usize = 64;
/// Per-worker load metrics, reported through the worker's NATS stats
/// handler and consumed by the KV scheduler when costing routing decisions.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ForwardPassMetrics {
    /// Request slots currently occupied on the worker.
    pub request_active_slots: u64,
    /// Total request slots the worker offers.
    pub request_total_slots: u64,
    /// KV cache blocks currently in use.
    pub kv_active_blocks: u64,
    /// Total KV cache blocks available on the worker.
    pub kv_total_blocks: u64,
}
/// A [`BlockHash`] is a hash computed from the tokens_ids, extra_token_ids and the optional
/// lora_id of a block.
///
/// Serialized as a bare `u64` (see the manual serde impls below).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Ord, PartialOrd)]
pub struct LocalBlockHash(pub u64);
/// A sequence aware hash of a block where the hash is computed from the tokens_ids, extra_token_ids
/// and the optional lora_id of a block, PLUS the hash of the parent block.
///
/// In this case, the hashing function is external and unknown.
/// Serialized as a bare `u64` (see the manual serde impls below).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Ord, PartialOrd)]
pub struct ExternalSequenceBlockHash(pub u64);
/// Represents a collection of cache events and a shutdown flag.
///
/// Serializes as a JSON object with `events` and `shutdown` keys.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct KvCacheEvents {
    /// A list of cache events.
    pub events: Vec<KvCacheEvent>,
    /// A flag indicating whether the cache is shutting down.
    pub shutdown: bool,
}
/// Represents a single cache event with an ID and associated data.
///
/// Events are wrapped together with the worker id into a `RouterEvent`
/// before being published to the router.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct KvCacheEvent {
    /// The unique identifier of the event.
    pub event_id: u64,
    /// The data associated with the event.
    pub data: KvCacheEventData,
}
/// Represents the data associated with a cache event.
///
/// Data is either stored or removed. With the `snake_case` rename, the JSON
/// representation is externally tagged as `{"stored": ...}` or
/// `{"removed": ...}`.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
pub enum KvCacheEventData {
    /// Data for a stored cache event.
    Stored(KvCacheStoreData),
    /// Data for a removed cache event.
    Removed(KvCacheRemoveData),
}
/// Represents the data associated with a stored cache event.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct KvCacheStoreData {
    /// The optional hash of the parent block (`None` for a sequence's
    /// first block).
    pub parent_hash: Option<ExternalSequenceBlockHash>,
    /// A list of stored block data.
    pub blocks: Vec<KvCacheStoredBlockData>,
}
/// Represents data for a stored block.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct KvCacheStoredBlockData {
    /// The sequence-aware (framework-external) hash of the block.
    pub block_hash: ExternalSequenceBlockHash,
    /// The local hash of the tokens in the block.
    pub tokens_hash: LocalBlockHash,
}
/// Represents the data associated with a removed cache event.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct KvCacheRemoveData {
    /// A list of block hashes to remove.
    pub block_hashes: Vec<ExternalSequenceBlockHash>,
}
// Serialize as a bare u64 rather than a wrapped newtype struct.
impl Serialize for LocalBlockHash {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        self.0.serialize(serializer)
    }
}
impl<'de> Deserialize<'de> for LocalBlockHash {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let value = u64::deserialize(deserializer)?;
Ok(LocalBlockHash(value))
}
}
// Serialize as a bare u64 rather than a wrapped newtype struct.
impl Serialize for ExternalSequenceBlockHash {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        self.0.serialize(serializer)
    }
}
impl<'de> Deserialize<'de> for ExternalSequenceBlockHash {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let value = u64::deserialize(deserializer)?;
Ok(ExternalSequenceBlockHash(value))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json;

    // Hashes must round-trip as bare JSON numbers (not wrapped objects),
    // since both sides of the KV event channel compare raw u64 hashes.
    #[test]
    fn test_local_block_hash_serialization() {
        let hash = LocalBlockHash(12345);
        let serialized = serde_json::to_string(&hash).unwrap();
        assert_eq!(serialized, "12345");
        let deserialized: LocalBlockHash = serde_json::from_str(&serialized).unwrap();
        assert_eq!(deserialized, hash);
    }

    #[test]
    fn test_external_sequence_block_hash_serialization() {
        let hash = ExternalSequenceBlockHash(67890);
        let serialized = serde_json::to_string(&hash).unwrap();
        assert_eq!(serialized, "67890");
        let deserialized: ExternalSequenceBlockHash = serde_json::from_str(&serialized).unwrap();
        assert_eq!(deserialized, hash);
    }

    // Round-trip a full event batch containing a Stored event and verify
    // every field survives serialization.
    #[test]
    fn test_kv_cache_events_serialization() {
        let event_data = KvCacheEventData::Stored(KvCacheStoreData {
            parent_hash: Some(ExternalSequenceBlockHash(1)),
            blocks: vec![KvCacheStoredBlockData {
                block_hash: ExternalSequenceBlockHash(2),
                tokens_hash: LocalBlockHash(3),
            }],
        });
        let event = KvCacheEvent {
            event_id: 1,
            data: event_data,
        };
        let events = KvCacheEvents {
            events: vec![event],
            shutdown: false,
        };
        let serialized = serde_json::to_string(&events).unwrap();
        let deserialized: KvCacheEvents = serde_json::from_str(&serialized).unwrap();
        assert_eq!(deserialized.events.len(), 1);
        assert_eq!(deserialized.events[0].event_id, 1);
        if let KvCacheEventData::Stored(store_data) = &deserialized.events[0].data {
            assert_eq!(store_data.parent_hash.unwrap().0, 1);
            assert_eq!(store_data.blocks.len(), 1);
            assert_eq!(store_data.blocks[0].block_hash.0, 2);
            assert_eq!(store_data.blocks[0].tokens_hash.0, 3);
        } else {
            panic!("Expected KvCacheEventData::Stored variant");
        }
        assert!(!deserialized.shutdown);
    }

    // Round-trip a Removed payload and verify the hash list is preserved.
    #[test]
    fn test_kv_cache_remove_data_serialization() {
        let remove_data = KvCacheRemoveData {
            block_hashes: vec![ExternalSequenceBlockHash(4), ExternalSequenceBlockHash(5)],
        };
        let serialized = serde_json::to_string(&remove_data).unwrap();
        let deserialized: KvCacheRemoveData = serde_json::from_str(&serialized).unwrap();
        assert_eq!(deserialized.block_hashes.len(), 2);
        assert_eq!(deserialized.block_hashes[0].0, 4);
        assert_eq!(deserialized.block_hashes[1].0, 5);
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::kv_router::{indexer::RouterEvent, protocols::KvCacheEvent, KV_EVENT_SUBJECT};
use tokio::sync::mpsc;
use triton_distributed::{component::Component, DistributedRuntime, Result};
use uuid::Uuid;
use tracing as log;
/// Publishes a worker's KV cache events to the router over NATS.
pub struct KvPublisher {
    // Hands events off to the background publish task spawned in `new`.
    tx: mpsc::UnboundedSender<KvCacheEvent>,
}
impl KvPublisher {
    /// Create a publisher for `worker_id` and spawn the background task
    /// that forwards queued events to the backend's KV event subject.
    pub fn new(drt: DistributedRuntime, backend: Component, worker_id: Uuid) -> Result<Self> {
        let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>();
        let p = KvPublisher { tx };
        start_publish_task(drt, backend, worker_id, rx);
        Ok(p)
    }
    /// Queue a KV cache event for asynchronous publication.
    ///
    /// Errors only if the background publish task has shut down.
    pub fn publish(&self, event: KvCacheEvent) -> Result<(), mpsc::error::SendError<KvCacheEvent>> {
        log::debug!("Publish event: {:?}", event);
        self.tx.send(event)
    }
}
fn start_publish_task(
drt: DistributedRuntime,
backend: Component,
worker_id: Uuid,
mut rx: mpsc::UnboundedReceiver<KvCacheEvent>,
) {
let client = drt.nats_client().client().clone();
// [FIXME] service name is for metrics polling?
// let service_name = backend.service_name();
let kv_subject = backend.event_subject(KV_EVENT_SUBJECT);
log::info!("Publishing KV Events to subject: {}", kv_subject);
_ = drt.runtime().secondary().spawn(async move {
while let Some(event) = rx.recv().await {
let router_event = RouterEvent::new(worker_id, event);
let data = serde_json::to_string(&router_event).unwrap();
client
.publish(kv_subject.to_string(), data.into())
.await
.unwrap();
}
});
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use serde::{Deserialize, Serialize};
use std::borrow::BorrowMut;
use std::cmp::min;
use tracing as log;
use uuid::Uuid;
use crate::kv_router::indexer::OverlapScores;
pub use crate::kv_router::protocols::{ForwardPassMetrics, KV_BLOCK_SIZE};
use crate::kv_router::scoring::ProcessedEndpoints;
#[derive(Debug, thiserror::Error)]
pub enum KvSchedulerError {
#[error("no endpoints aviailable to route work")]
NoEndpoints,
#[error("endpoints existed, but no valid routes were found")]
NoRoutes,
#[error("all workers busy")]
AllWorkersBusy,
#[error("endpoint subscriber shutdown")]
SubscriberShutdown,
#[error("scheduler offline")]
SchedulerOffline,
}
/// A worker endpoint as reported by the NATS service stats.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Endpoint {
    /// Endpoint name.
    pub name: String,
    /// NATS subject of the endpoint; expected to end in `.<worker uuid>`
    /// (see `worker_id`).
    pub subject: String,
    /// Latest load metrics reported by the worker's stats handler.
    pub data: ForwardPassMetrics,
}
impl Endpoint {
    /// Extract the worker id (a UUID) from the endpoint's NATS subject,
    /// which is expected to end in `.<uuid>`.
    ///
    /// Panics if the final `.`-separated segment is not a valid UUID.
    pub fn worker_id(&self) -> Uuid {
        // rsplit('.') yields the last segment first, and the original
        // `.to_string().as_str()` round trip was a needless allocation.
        let id = self.subject.rsplit('.').next().expect("invalid subject");
        Uuid::parse_str(id).expect("invalid uuid")
    }
}
/// A NATS service stats record: service metadata plus its endpoints.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Service {
    pub name: String,
    pub id: String,
    pub version: String,
    pub started: String,
    /// Endpoints registered under this service.
    pub endpoints: Vec<Endpoint>,
}
/// A single scheduling request together with the channel used to return
/// the selected worker subject to the caller.
pub struct SchedulingRequest {
    /// Input sequence length in tokens.
    isl_tokens: usize,
    /// Per-worker KV cache overlap scores for this request.
    overlap: OverlapScores,
    /// One-shot channel on which the selected worker subject is delivered.
    resp_tx: tokio::sync::oneshot::Sender<String>,
}
impl SchedulingRequest {
    /// Deliver the selected worker id to the caller awaiting `schedule()`.
    /// A failed send means the caller dropped the receiver; that is benign.
    pub fn respond(self, worker_id: String) {
        match self.resp_tx.send(worker_id) {
            Ok(()) => {}
            Err(_) => log::trace!("failed to send response to requestor"),
        }
    }
}
/// Handle to the scheduler background task started by `KvScheduler::start`;
/// scheduling requests are submitted over a bounded channel.
pub struct KvScheduler {
    request_tx: tokio::sync::mpsc::Sender<SchedulingRequest>,
}
impl KvScheduler {
    /// Start the scheduler's background task.
    ///
    /// Waits for the endpoint subscriber to deliver a first
    /// `ProcessedEndpoints` snapshot, then spawns a task that serves
    /// scheduling requests against the most recent snapshot.
    ///
    /// Returns `SubscriberShutdown` if the endpoint channel closes before
    /// the first snapshot arrives.
    pub async fn start(
        endpoints_rx: tokio::sync::mpsc::Receiver<ProcessedEndpoints>,
    ) -> Result<Self, KvSchedulerError> {
        let mut endpoints_rx = endpoints_rx;
        log::trace!("awaiting the start of the background endpoint subscriber");
        let mut endpoints = match endpoints_rx.recv().await {
            Some(endpoints) => endpoints,
            None => {
                return Err(KvSchedulerError::SubscriberShutdown);
            }
        };
        // Channel to accept new scheduling requests
        let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(16);
        log::debug!("scheduler starting");
        // Background task to handle scheduling requests
        tokio::spawn(async move {
            let mut request: SchedulingRequest;
            let mut request_rx = request_rx;
            log::debug!("scheduler background task started");
            'outer: loop {
                // Wait for either a request to schedule or a fresh endpoint
                // snapshot; `biased` polls pending requests first.
                request = tokio::select! {
                    biased;
                    new_request = request_rx.recv() => {
                        match new_request {
                            Some(new_request) => {
                                log::trace!("received request to be scheduled");
                                new_request
                            },
                            None => {
                                log::trace!("scheduler shutdown");
                                break 'outer;
                            }
                        }
                    }
                    new_endpoints = endpoints_rx.recv() => {
                        match new_endpoints {
                            Some(new_endpoints) => {
                                log::trace!("updated endpoints");
                                endpoints = new_endpoints;
                                continue 'outer;
                            }
                            None => {
                                log::trace!("endpoint subscriber shutdown");
                                break 'outer;
                            }
                        }
                    }
                };
                log::debug!("selected");
                // Retry loop: if all workers are busy, block until the next
                // endpoint snapshot arrives and try the same request again.
                loop {
                    match select_worker(endpoints.borrow_mut(), &request) {
                        Ok(worker_id) => {
                            request.respond(worker_id);
                            continue 'outer;
                        }
                        Err(KvSchedulerError::AllWorkersBusy) => {
                            log::trace!("all workers busy; waiting for more capacity");
                            endpoints = match endpoints_rx.recv().await {
                                Some(endpoints) => endpoints,
                                None => {
                                    log::trace!("endpoint subscriber shutdown");
                                    break 'outer;
                                }
                            };
                        }
                        Err(e) => {
                            // Any other error is unrecoverable for this task.
                            log::error!("error scheduling request: {:?}", e);
                            break 'outer;
                        }
                    }
                }
            }
            log::trace!("background endpoint subscriber shutting down");
        });
        Ok(KvScheduler { request_tx })
    }
    /// Submit a request to the scheduler task and await the selected
    /// worker's subject.
    ///
    /// Returns `SubscriberShutdown` if the background task has exited.
    pub async fn schedule(
        &self,
        overlap: OverlapScores,
        isl_tokens: usize,
    ) -> Result<String, KvSchedulerError> {
        let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
        let request = SchedulingRequest {
            isl_tokens,
            overlap,
            resp_tx,
        };
        log::debug!("before sending request");
        self.request_tx
            .send(request)
            .await
            .map_err(|_| KvSchedulerError::SubscriberShutdown)?;
        log::debug!("after sending request");
        let res = resp_rx
            .await
            .map_err(|_| KvSchedulerError::SubscriberShutdown)?;
        log::debug!("after receiving response");
        Ok(res)
    }
}
/// Pick the lowest-cost worker for `request`, mutating the snapshot to
/// account for the load the scheduled request will add.
///
/// Cost blends (a) deviation of the worker's KV load from the fleet
/// average, (b) the fraction of input tokens that would be new on that
/// worker (not already cached), and (c) request-slot utilization.
///
/// Returns the selected worker's subject, `NoEndpoints` if the snapshot is
/// empty, or `AllWorkersBusy` if every worker is at capacity.
pub fn select_worker(
    workers: &mut ProcessedEndpoints,
    request: &SchedulingRequest,
) -> Result<String, KvSchedulerError> {
    if workers.endpoints.is_empty() {
        return Err(KvSchedulerError::NoEndpoints);
    }
    // balance mode prioritizes balancing load across workers
    let balance_threshold: f64 = 0.1;
    let balance_mode = workers.load_std > balance_threshold * workers.load_avg;
    // Weight load deviation higher when the fleet is imbalanced; otherwise
    // favor cache reuse (fewer new tokens).
    let alpha = if balance_mode { 0.7 } else { 0.3 };
    let gamma = 0.1; // example tuning param
    // Compute each worker's score
    let mut best_index = None;
    let mut best_cost = f64::INFINITY;
    for (i, w) in workers.endpoints.iter().enumerate() {
        // Exclude workers that are at capacity
        if w.data.request_active_slots >= w.data.request_total_slots
            || w.data.kv_active_blocks >= w.data.kv_total_blocks
        {
            continue;
        }
        let kv_load_ratio = w.data.kv_active_blocks as f64 / w.data.kv_total_blocks as f64;
        let load_deviation = kv_load_ratio - workers.load_avg;
        let worker_id = workers.worker_ids[i];
        // Overlap is reported in KV blocks; convert to tokens.
        let overlap_score = request.overlap.scores.get(&worker_id).map_or(0, |x| *x);
        let overlap_score = overlap_score as usize * KV_BLOCK_SIZE;
        let new_tokens = request.isl_tokens.saturating_sub(overlap_score);
        let normalized_new_tokens = new_tokens as f64 / request.isl_tokens as f64;
        let request_load_ratio =
            w.data.request_active_slots as f64 / w.data.request_total_slots as f64;
        // cost = alpha * load_deviation + (1 - alpha)*normalized_new_tokens + gamma * request_load_ratio
        let cost = alpha * load_deviation
            + (1.0 - alpha) * normalized_new_tokens
            + gamma * request_load_ratio;
        log::debug!("worker: {}; load_deviation: {}; normalized new blocks: {}; request_load_ratio: {} cost: {}",
            worker_id,
            load_deviation,
            normalized_new_tokens,
            request_load_ratio,
            cost
        );
        if cost < best_cost {
            best_cost = cost;
            best_index = Some(i);
        }
    }
    if let Some(best_index) = best_index {
        // Optimistically charge the request's load against the snapshot so
        // back-to-back decisions see it.
        // Bug fix: the original `min(isl_tokens / KV_BLOCK_SIZE, 1)` capped
        // the charged KV footprint at a single block; charge the real block
        // count, with a floor of one block.
        let total_blocks = (request.isl_tokens / KV_BLOCK_SIZE).max(1);
        workers.endpoints[best_index].data.request_active_slots += 1;
        workers.endpoints[best_index].data.kv_active_blocks += total_blocks as u64;
    }
    match best_index {
        Some(i) => {
            log::info!(
                "selected worker: {}; cost: {}",
                workers.endpoints[i].subject,
                best_cost
            );
            Ok(workers.endpoints[i].subject.clone())
        }
        None => {
            log::debug!("all workers busy");
            Err(KvSchedulerError::AllWorkersBusy)
        }
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Scoring functions for the KV router.
use std::collections::HashSet;
use crate::kv_router::scheduler::Endpoint;
use uuid::Uuid;
/// A processed snapshot of a service's endpoints plus aggregate KV load
/// statistics, consumed by `select_worker`.
#[derive(Debug, Default)]
pub struct ProcessedEndpoints {
    // Raw endpoints in the order received.
    pub endpoints: Vec<Endpoint>,
    // Worker UUIDs parsed from the endpoint subjects.
    // NOTE(review): scheduler::select_worker indexes worker_ids[i] in
    // parallel with endpoints[i]; construction must preserve that
    // one-to-one, same-order correspondence — verify.
    pub worker_ids: Vec<Uuid>,
    // Mean of kv_active_blocks across endpoints.
    pub load_avg: f64,
    // Standard deviation of kv_active_blocks across endpoints.
    pub load_std: f64,
}
impl ProcessedEndpoints {
    /// Build a processed snapshot: per-endpoint worker ids plus the mean
    /// and standard deviation of KV load across the endpoints.
    pub fn new(endpoints: Vec<Endpoint>) -> Self {
        // compute some basic statistics
        let load_values: Vec<f64> = endpoints
            .iter()
            .map(|x| x.data.kv_active_blocks as f64)
            .collect();
        // Guard the empty case: the original divided by zero and produced
        // NaN for both statistics.
        let (load_avg, load_std) = if load_values.is_empty() {
            (0.0, 0.0)
        } else {
            let avg = load_values.iter().copied().sum::<f64>() / load_values.len() as f64;
            let variance = load_values
                .iter()
                .map(|&x| (x - avg).powi(2))
                .sum::<f64>()
                / load_values.len() as f64;
            (avg, variance.sqrt())
        };
        // Bug fix: worker_ids must stay index-aligned with `endpoints`
        // (select_worker pairs worker_ids[i] with endpoints[i]). The
        // original routed the ids through a HashSet, which deduplicates
        // and randomizes order, breaking that correspondence.
        let worker_ids: Vec<Uuid> = endpoints.iter().map(|x| x.worker_id()).collect();
        ProcessedEndpoints {
            endpoints,
            worker_ids,
            load_avg,
            load_std,
        }
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
pub use crate::kv_router::protocols::ForwardPassMetrics;
use anyhow::Result;
use derive_builder::Builder;
use triton_distributed::pipeline::network::{
ingress::push_endpoint::PushEndpoint,
PushWorkHandler,
};
use triton_distributed::transports::nats::{self, ServiceExt};
use tokio::sync::watch;
use tokio_util::sync::CancellationToken;
use tracing as log;
/// Configuration for serving a worker as a KV-routed NATS service endpoint.
#[derive(Builder)]
pub struct KvRoutedIngress {
    /// NATS service (and group) name to register under.
    #[builder(setter(into))]
    pub service_name: String,
    /// This worker's unique id; also used as the endpoint name.
    #[builder(setter(into))]
    pub worker_id: String,
    /// NATS client used to host the service.
    pub nats: nats::Client,
    /// Handler invoked for each incoming request.
    pub service_handler: Arc<dyn PushWorkHandler>,
    /// Live feed of the worker's load metrics, reported via the stats handler.
    pub metrics_rx: watch::Receiver<Arc<ForwardPassMetrics>>,
    /// Token that shuts the endpoint down when cancelled.
    pub cancellation_token: CancellationToken,
}
/// Version of this crate, embedded at compile time from Cargo.toml.
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
impl KvRoutedIngress {
    /// Builder for configuring a [`KvRoutedIngress`].
    pub fn builder() -> KvRoutedIngressBuilder {
        KvRoutedIngressBuilder::default()
    }
    /// Register this worker as a NATS service endpoint and serve requests
    /// until the cancellation token fires.
    ///
    /// The stats handler reports the latest `ForwardPassMetrics` so the KV
    /// router can poll per-worker load.
    pub async fn start(self) -> Result<()> {
        let worker_id = self.worker_id;
        log::trace!(
            worker_id,
            "Starting nats service: {}:{}",
            self.service_name,
            VERSION
        );
        let mut metrics_rx = self.metrics_rx;
        let worker_id_clone = worker_id.clone();
        let service = self
            .nats
            .client()
            .service_builder()
            // Fixed copy-pasted description (was "A handy min max service").
            .description("KV-routed LLM worker ingress")
            .stats_handler(move |name, stats| {
                log::debug!(
                    worker_id = worker_id_clone.as_str(),
                    "[IN worker?] Stats for service {}: {:?}",
                    name,
                    stats
                );
                // Expose the most recent forward-pass metrics as this
                // service's stats payload for the router to poll.
                let metrics = metrics_rx.borrow_and_update().clone();
                serde_json::to_value(&*metrics).unwrap()
            })
            .start(self.service_name.as_str(), VERSION)
            .await
            .map_err(|e| anyhow::anyhow!("Failed to start service: {e}"))?;
        let group = service.group(self.service_name.as_str());
        log::trace!(worker_id, "Starting endpoint: {}", worker_id);
        // creates an endpoint for the service, named after the worker id so
        // the router can address this worker directly
        let service_endpoint = group
            .endpoint(worker_id.clone())
            .await
            .map_err(|e| anyhow::anyhow!("Failed to start endpoint: {e}"))?;
        let push_endpoint = PushEndpoint::builder()
            .service_handler(self.service_handler)
            .cancellation_token(self.cancellation_token)
            .build()
            .map_err(|e| anyhow::anyhow!("Failed to build push endpoint: {e}"))?;
        push_endpoint.start(service_endpoint).await
    }
}
...@@ -18,7 +18,8 @@ ...@@ -18,7 +18,8 @@
//! The `triton-llm` crate is a Rust library that provides a set of traits and types for building //! The `triton-llm` crate is a Rust library that provides a set of traits and types for building
//! distributed LLM inference solutions. //! distributed LLM inference solutions.
pub mod engines;
pub mod http; pub mod http;
pub mod kv_router;
pub mod protocols; pub mod protocols;
pub mod types; pub mod types;
pub mod engines;
...@@ -61,8 +61,8 @@ addopts = [ ...@@ -61,8 +61,8 @@ addopts = [
"--ignore-glob=*model.py", "--ignore-glob=*model.py",
# FIXME: Get relative/generic blob paths to work here # FIXME: Get relative/generic blob paths to work here
# Ignore rust<->python bindings until python package is built/installed in environment # Ignore rust<->python bindings until python package is built/installed in environment
"--ignore-glob=/workspace/runtime/rust/python-wheel/python/triton_distributed_rs/*.py", "--ignore-glob=/workspace/python-wheel/python/triton_distributed_rs/*.py",
"--ignore-glob=/workspace/runtime/rust/python-wheel/python/triton_distributed_rs/*.pyi", "--ignore-glob=/workspace/python-wheel/python/triton_distributed_rs/*.pyi",
] ]
xfail_strict = true xfail_strict = true
log_cli_level = "INFO" log_cli_level = "INFO"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment