// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. use async_once_cell::OnceCell as AsyncOnceCell; use libc::c_char; use once_cell::sync::OnceCell; use std::ffi::CStr; use std::sync::atomic::{AtomicU32, Ordering}; use dynamo_llm::kv_router::{ indexer::compute_block_hash_for_seq, protocols::*, publisher::KvEventPublisher, }; use dynamo_runtime::{DistributedRuntime, Worker}; static WK: OnceCell = OnceCell::new(); static DRT: AsyncOnceCell = AsyncOnceCell::new(); // [FIXME] shouldn't the publisher be instance passing between API calls? static KV_PUB: OnceCell = OnceCell::new(); fn initialize_tracing() { // Sets up RUST_LOG environment variable for logging while KV Publishing // Example: os.environ["RUST_LOG"] = "debug" let subscriber = tracing_subscriber::fmt() .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) .finish(); tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed"); tracing::debug!("Tracing initialized"); } #[repr(u32)] pub enum DynamoLlmResult { OK = 0, ERR = 1, } /// # Safety /// the namespace_c_str and component_c_str are passed as pointers to C strings #[no_mangle] pub unsafe extern "C" fn dynamo_llm_init( namespace_c_str: *const c_char, component_c_str: *const c_char, worker_id: i64, kv_block_size: u32, ) -> DynamoLlmResult { initialize_tracing(); let wk = match WK.get_or_try_init(Worker::from_settings) { Ok(wk) => wk.clone(), Err(e) => { eprintln!("Failed to initialize runtime: {:?}", e); return DynamoLlmResult::ERR; } }; let rt = wk.runtime(); let secondary = rt.secondary().clone(); let result = secondary.block_on(async { // Initialize the distributed runtime match DRT .get_or_try_init(async { DistributedRuntime::from_settings(rt.clone()).await }) .await { Ok(_) => Ok(()), Err(e) => { eprintln!("Failed to initialize distributed runtime: {:?}", e); Err(DynamoLlmResult::ERR) } } }); let namespace = match unsafe { CStr::from_ptr(namespace_c_str) }.to_str() { Ok(s) => s.to_string(), Err(e) => { eprintln!("Failed to convert C string to Rust string: {:?}", e); return DynamoLlmResult::ERR; } }; let component = match unsafe { CStr::from_ptr(component_c_str) }.to_str() { Ok(s) => s.to_string(), Err(e) => { eprintln!("Failed to convert C string to Rust string: {:?}", e); return DynamoLlmResult::ERR; } }; match result { Ok(_) => match KV_PUB.get_or_try_init(move || { dynamo_create_kv_publisher(namespace, component, worker_id, kv_block_size as usize) }) { Ok(_) => DynamoLlmResult::OK, Err(e) => { eprintln!("Failed to initialize distributed runtime: {:?}", e); DynamoLlmResult::ERR } }, Err(e) => e, } } #[no_mangle] pub extern "C" fn dynamo_llm_shutdown() -> DynamoLlmResult { let wk = match WK.get() { Some(wk) => wk, None => { eprintln!("Runtime not initialized"); return DynamoLlmResult::ERR; } }; wk.runtime().shutdown(); DynamoLlmResult::OK } #[no_mangle] pub extern "C" fn dynamo_llm_load_publisher_create() -> DynamoLlmResult { DynamoLlmResult::OK } // instantiate a kv publisher // this will bring up the task to publish and the channels to await publishing events // the [`dynamo_kv_publish_store_event`] call will use a handle to the publisher to send events // store and the [`dynamo_kv_event_create_removed`] will create remove events // these call mus be driving by external c++ threads that are consuming the kv events from the // c++ executor api fn dynamo_create_kv_publisher( namespace: String, component: String, worker_id: i64, kv_block_size: usize, ) -> Result { tracing::info!("Creating KV Publisher for model: {}", component); match DRT .get() .ok_or(anyhow::Error::msg("Could not get Distributed Runtime")) { Ok(drt) => { let backend = drt.namespace(namespace)?.component(component)?; KvEventPublisher::new(drt.clone(), backend, worker_id, kv_block_size) } Err(e) => Err(e), } } fn kv_event_create_stored_block_from_parts( block_hash: u64, token_ids: *const u32, num_tokens: usize, kv_block_size: usize, _lora_id: u64, ) -> KvCacheStoredBlockData { let tokens_hash = compute_block_hash_for_seq( unsafe { std::slice::from_raw_parts(token_ids, num_tokens) }, kv_block_size, )[0]; KvCacheStoredBlockData { block_hash: ExternalSequenceBlockHash(block_hash), tokens_hash, } } static WARN_COUNT: AtomicU32 = AtomicU32::new(0); fn kv_event_create_stored_from_parts( kv_params: DynamoKvStoredEventParams, kv_block_size: usize, ) -> KvCacheEvent { let mut blocks: Vec = Vec::new(); let mut token_offset: usize = 0; for block_idx in 0..kv_params.num_blocks { let block_hash = unsafe { *kv_params.block_ids.offset(block_idx.try_into().unwrap()) }; let tokens = unsafe { kv_params.token_ids.offset(token_offset.try_into().unwrap()) }; let num_toks = unsafe { *kv_params .num_block_tokens .offset(block_idx.try_into().unwrap()) }; if num_toks != kv_block_size { if WARN_COUNT .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |c| { if c < 3 { Some(c + 1) } else { None } }) .is_ok() { tracing::warn!( "Block not published. Block size must be {} tokens to be published. Block size is: {}", kv_block_size, num_toks ); } break; } token_offset += num_toks; blocks.push(kv_event_create_stored_block_from_parts( block_hash, tokens, num_toks, kv_block_size, kv_params.lora_id, )); } KvCacheEvent { data: KvCacheEventData::Stored(KvCacheStoreData { blocks, parent_hash: kv_params.parent_hash.map(ExternalSequenceBlockHash), }), event_id: kv_params.event_id, } } fn kv_event_create_removed_from_parts( event_id: u64, block_ids: *const u64, num_blocks: usize, ) -> KvCacheEvent { let block_hashes: Vec = unsafe { std::slice::from_raw_parts(block_ids, num_blocks) } .to_vec() .iter() .map(|&v| ExternalSequenceBlockHash(v)) .collect(); KvCacheEvent { event_id, data: KvCacheEventData::Removed(KvCacheRemoveData { block_hashes }), } } pub struct DynamoKvStoredEventParams { pub event_id: u64, pub token_ids: *const u32, pub num_block_tokens: *const usize, pub block_ids: *const u64, pub num_blocks: usize, pub parent_hash: Option, pub lora_id: u64, } /// # Safety /// parent_hash is passed as pointer to indicate whether the blocks /// has a parent hash or not. nullptr is used to represent no parent hash #[no_mangle] pub unsafe extern "C" fn dynamo_kv_event_publish_stored( event_id: u64, token_ids: *const u32, num_block_tokens: *const usize, block_ids: *const u64, num_blocks: usize, parent_hash: *const u64, lora_id: u64, ) -> DynamoLlmResult { let parent_hash = { if parent_hash.is_null() { None } else { Some(unsafe { *parent_hash }) } }; let kv_params = DynamoKvStoredEventParams { event_id, token_ids, num_block_tokens, block_ids, num_blocks, parent_hash, lora_id, }; let publisher = KV_PUB.get().unwrap(); let event = kv_event_create_stored_from_parts(kv_params, publisher.kv_block_size()); match publisher.publish(event) { Ok(_) => DynamoLlmResult::OK, Err(e) => { eprintln!("Error publishing stored kv event {:?}", e); DynamoLlmResult::ERR } } } #[no_mangle] pub extern "C" fn dynamo_kv_event_publish_removed( event_id: u64, block_ids: *const u64, num_blocks: usize, ) -> DynamoLlmResult { let publisher = KV_PUB.get().unwrap(); let event = kv_event_create_removed_from_parts(event_id, block_ids, num_blocks); match publisher.publish(event) { Ok(_) => DynamoLlmResult::OK, Err(e) => { eprintln!("Error publishing removed kv event {:?}", e); DynamoLlmResult::ERR } } } // Need to setup etcd and nats to run these tests // #[cfg(test)] // mod tests { // use super::*; // use std::ffi::CString; // #[test] // fn test_dynamo_llm_init() { // // Create C-compatible strings // let namespace = CString::new("test_namespace").unwrap(); // let component = CString::new("test_component").unwrap(); // // Call the init function // let result = unsafe { // dynamo_llm_init( // namespace.as_ptr(), // component.as_ptr(), // 1, // worker_id // 32, // kv_block_size // ) // }; // assert_eq!(result as u32, DynamoLlmResult::OK as u32); // assert!(WK.get().is_some()); // let shutdown_result = dynamo_llm_shutdown(); // assert_eq!(shutdown_result as u32, DynamoLlmResult::OK as u32); // } // }