lib.rs 9.84 KB
Newer Older
1
2
3
4
5
6
7
8
9
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use async_once_cell::OnceCell as AsyncOnceCell;
use libc::c_char;
use once_cell::sync::OnceCell;
use std::ffi::CStr;
use std::sync::atomic::{AtomicU32, Ordering};

Neelay Shah's avatar
Neelay Shah committed
10
use dynamo_llm::kv_router::{
GuanLuo's avatar
GuanLuo committed
11
    indexer::compute_block_hash_for_seq, protocols::*, publisher::KvEventPublisher,
12
};
Neelay Shah's avatar
Neelay Shah committed
13
use dynamo_runtime::{DistributedRuntime, Worker};
14
15
16
/// Process-wide worker. It is created on first use by `dynamo_llm_init` (via
/// `Worker::from_settings`) and read by `dynamo_llm_shutdown` to shut the
/// runtime down.
static WK: OnceCell<Worker> = OnceCell::new();
/// Process-wide distributed runtime. It is initialized asynchronously by
/// `dynamo_llm_init` and read by `dynamo_create_kv_publisher`.
static DRT: AsyncOnceCell<DistributedRuntime> = AsyncOnceCell::new();
// FIXME: should the publisher be an instance that is passed between API calls
// instead of a process-wide global?
GuanLuo's avatar
GuanLuo committed
17
/// Process-wide KV event publisher. It is created once by `dynamo_llm_init`
/// and used by the `dynamo_kv_event_publish_*` entry points to publish events.
static KV_PUB: OnceCell<KvEventPublisher> = OnceCell::new();
18
19
20
21
22
23
24
25

fn initialize_tracing() {
    // Sets up RUST_LOG environment variable for logging while KV Publishing
    // Example: os.environ["RUST_LOG"] = "debug"
    let subscriber = tracing_subscriber::fmt()
        .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
        .finish();

Ryan Olson's avatar
Ryan Olson committed
26
    tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed");
27

28
    tracing::debug!("Tracing initialized");
29
30
31
}

/// Status code returned by the exported `extern "C"` functions in this file.
///
/// `#[repr(u32)]` pins the representation to a 32-bit unsigned integer so the
/// numeric values below are what the C caller observes.
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DynamoLlmResult {
    /// The call succeeded.
    OK = 0,
    /// The call failed; details are printed to stderr by the failing function.
    ERR = 1,
}

/// # Safety
GuanLuo's avatar
GuanLuo committed
38
/// the namespace_c_str and component_c_str are passed as pointers to C strings
39
#[unsafe(no_mangle)]
Neelay Shah's avatar
Neelay Shah committed
40
pub unsafe extern "C" fn dynamo_llm_init(
GuanLuo's avatar
GuanLuo committed
41
42
43
    namespace_c_str: *const c_char,
    component_c_str: *const c_char,
    worker_id: i64,
44
    kv_block_size: u32,
Neelay Shah's avatar
Neelay Shah committed
45
) -> DynamoLlmResult {
46
47
48
49
50
    initialize_tracing();
    let wk = match WK.get_or_try_init(Worker::from_settings) {
        Ok(wk) => wk.clone(),
        Err(e) => {
            eprintln!("Failed to initialize runtime: {:?}", e);
Neelay Shah's avatar
Neelay Shah committed
51
            return DynamoLlmResult::ERR;
52
53
54
55
56
57
58
59
60
61
62
63
64
        }
    };
    let rt = wk.runtime();
    let secondary = rt.secondary().clone();
    let result = secondary.block_on(async {
        // Initialize the distributed runtime
        match DRT
            .get_or_try_init(async { DistributedRuntime::from_settings(rt.clone()).await })
            .await
        {
            Ok(_) => Ok(()),
            Err(e) => {
                eprintln!("Failed to initialize distributed runtime: {:?}", e);
Neelay Shah's avatar
Neelay Shah committed
65
                Err(DynamoLlmResult::ERR)
66
67
68
            }
        }
    });
GuanLuo's avatar
GuanLuo committed
69
    let namespace = match unsafe { CStr::from_ptr(namespace_c_str) }.to_str() {
70
71
72
        Ok(s) => s.to_string(),
        Err(e) => {
            eprintln!("Failed to convert C string to Rust string: {:?}", e);
Neelay Shah's avatar
Neelay Shah committed
73
            return DynamoLlmResult::ERR;
74
75
76
        }
    };

GuanLuo's avatar
GuanLuo committed
77
78
    let component = match unsafe { CStr::from_ptr(component_c_str) }.to_str() {
        Ok(s) => s.to_string(),
79
80
        Err(e) => {
            eprintln!("Failed to convert C string to Rust string: {:?}", e);
Neelay Shah's avatar
Neelay Shah committed
81
            return DynamoLlmResult::ERR;
82
83
84
85
        }
    };

    match result {
86
        Ok(_) => match KV_PUB.get_or_try_init(move || {
87
            dynamo_create_kv_publisher(namespace, component, worker_id, kv_block_size)
88
        }) {
Neelay Shah's avatar
Neelay Shah committed
89
            Ok(_) => DynamoLlmResult::OK,
90
91
            Err(e) => {
                eprintln!("Failed to initialize distributed runtime: {:?}", e);
Neelay Shah's avatar
Neelay Shah committed
92
                DynamoLlmResult::ERR
93
94
95
96
97
98
            }
        },
        Err(e) => e,
    }
}

99
#[unsafe(no_mangle)]
Neelay Shah's avatar
Neelay Shah committed
100
pub extern "C" fn dynamo_llm_shutdown() -> DynamoLlmResult {
101
102
103
104
    let wk = match WK.get() {
        Some(wk) => wk,
        None => {
            eprintln!("Runtime not initialized");
Neelay Shah's avatar
Neelay Shah committed
105
            return DynamoLlmResult::ERR;
106
107
108
109
110
        }
    };

    wk.runtime().shutdown();

Neelay Shah's avatar
Neelay Shah committed
111
    DynamoLlmResult::OK
112
113
}

114
#[unsafe(no_mangle)]
Neelay Shah's avatar
Neelay Shah committed
115
116
pub extern "C" fn dynamo_llm_load_publisher_create() -> DynamoLlmResult {
    DynamoLlmResult::OK
117
118
119
120
}

// Instantiate a KV publisher.
// This brings up the publishing task and the channels on which it awaits
// events to publish.
// [`dynamo_kv_event_publish_stored`] uses the global publisher to send stored
// events, and [`dynamo_kv_event_publish_removed`] uses it to send removed events.
// These calls must be driven by external C++ threads that consume the KV
// events from the C++ executor API.

Neelay Shah's avatar
Neelay Shah committed
126
fn dynamo_create_kv_publisher(
GuanLuo's avatar
GuanLuo committed
127
128
129
    namespace: String,
    component: String,
    worker_id: i64,
130
    kv_block_size: u32,
GuanLuo's avatar
GuanLuo committed
131
) -> Result<KvEventPublisher, anyhow::Error> {
132
    tracing::info!("Creating KV Publisher for model: {}", component);
133
134
135
136
137
    match DRT
        .get()
        .ok_or(anyhow::Error::msg("Could not get Distributed Runtime"))
    {
        Ok(drt) => {
GuanLuo's avatar
GuanLuo committed
138
            let backend = drt.namespace(namespace)?.component(component)?;
139
            KvEventPublisher::new(backend, worker_id, kv_block_size, None)
140
141
142
143
144
145
146
147
148
        }
        Err(e) => Err(e),
    }
}

/// Builds the stored-block record for a single KV block.
///
/// `token_ids` and `num_tokens` describe the block's tokens. Their hash is the
/// first entry returned by `compute_block_hash_for_seq` for `kv_block_size`.
/// `block_hash` is wrapped as the block's external hash. `_lora_id` is
/// accepted but not used.
///
/// Panics if `compute_block_hash_for_seq` returns no hashes.
fn kv_event_create_stored_block_from_parts(
    block_hash: u64,
    token_ids: *const u32,
    num_tokens: usize,
    kv_block_size: u32,
    _lora_id: u64,
) -> KvCacheStoredBlockData {
    // SAFETY: relies on the caller passing a pointer to `num_tokens` readable
    // `u32` values.
    let block_tokens = unsafe { std::slice::from_raw_parts(token_ids, num_tokens) };
    let sequence_hashes = compute_block_hash_for_seq(block_tokens, kv_block_size);
    let tokens_hash = sequence_hashes[0];

    KvCacheStoredBlockData {
        block_hash: ExternalSequenceBlockHash(block_hash),
        tokens_hash,
    }
}
static WARN_COUNT: AtomicU32 = AtomicU32::new(0);

fn kv_event_create_stored_from_parts(
164
    kv_params: DynamoKvStoredEventParams,
165
    kv_block_size: u32,
166
167
168
169
) -> KvCacheEvent {
    let mut blocks: Vec<KvCacheStoredBlockData> = Vec::new();

    let mut token_offset: usize = 0;
170
171
172
173
174
175
176
177
178
    for block_idx in 0..kv_params.num_blocks {
        let block_hash = unsafe { *kv_params.block_ids.offset(block_idx.try_into().unwrap()) };
        let tokens = unsafe { kv_params.token_ids.offset(token_offset.try_into().unwrap()) };
        let num_toks = unsafe {
            *kv_params
                .num_block_tokens
                .offset(block_idx.try_into().unwrap())
        };

179
        if num_toks != (kv_block_size as usize) {
Ryan Olson's avatar
Ryan Olson committed
180
181
            if WARN_COUNT
                .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |c| {
182
                    if c < 3 { Some(c + 1) } else { None }
Ryan Olson's avatar
Ryan Olson committed
183
184
185
                })
                .is_ok()
            {
186
                tracing::warn!(
187
188
                    "Block not published. Block size must be {} tokens to be published. Block size is: {}",
                    kv_block_size,
Ryan Olson's avatar
Ryan Olson committed
189
190
                    num_toks
                );
191
192
193
194
195
            }
            break;
        }
        token_offset += num_toks;
        blocks.push(kv_event_create_stored_block_from_parts(
196
197
198
199
200
            block_hash,
            tokens,
            num_toks,
            kv_block_size,
            kv_params.lora_id,
201
202
203
204
205
206
        ));
    }

    KvCacheEvent {
        data: KvCacheEventData::Stored(KvCacheStoreData {
            blocks,
207
            parent_hash: kv_params.parent_hash.map(ExternalSequenceBlockHash),
208
        }),
209
        event_id: kv_params.event_id,
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
    }
}

fn kv_event_create_removed_from_parts(
    event_id: u64,
    block_ids: *const u64,
    num_blocks: usize,
) -> KvCacheEvent {
    let block_hashes: Vec<ExternalSequenceBlockHash> =
        unsafe { std::slice::from_raw_parts(block_ids, num_blocks) }
            .to_vec()
            .iter()
            .map(|&v| ExternalSequenceBlockHash(v))
            .collect();
    KvCacheEvent {
        event_id,
        data: KvCacheEventData::Removed(KvCacheRemoveData { block_hashes }),
    }
}

230
231
232
233
234
235
236
237
238
239
/// Raw inputs for building a "stored" KV cache event, as received from the C
/// caller of `dynamo_kv_event_publish_stored`.
pub struct DynamoKvStoredEventParams {
    /// Identifier copied into the resulting event.
    pub event_id: u64,
    /// Start of the token array. Each block's tokens begin where the previous
    /// block's tokens end.
    pub token_ids: *const u32,
    /// Per-block token counts; read for indices `0..num_blocks`.
    pub num_block_tokens: *const usize,
    /// Per-block hashes, used as the external block hash; read for indices
    /// `0..num_blocks`.
    pub block_ids: *const u64,
    /// Number of blocks described by `num_block_tokens` and `block_ids`.
    pub num_blocks: usize,
    /// Hash of the parent block, or `None` when the blocks have no parent.
    pub parent_hash: Option<u64>,
    /// LoRA identifier, passed on to the per-block builder.
    pub lora_id: u64,
}

240
241
242
/// # Safety
/// parent_hash is passed as pointer to indicate whether the blocks
/// has a parent hash or not. nullptr is used to represent no parent hash
243
#[unsafe(no_mangle)]
Neelay Shah's avatar
Neelay Shah committed
244
pub unsafe extern "C" fn dynamo_kv_event_publish_stored(
245
246
247
248
249
250
251
    event_id: u64,
    token_ids: *const u32,
    num_block_tokens: *const usize,
    block_ids: *const u64,
    num_blocks: usize,
    parent_hash: *const u64,
    lora_id: u64,
Neelay Shah's avatar
Neelay Shah committed
252
) -> DynamoLlmResult {
253
254
255
256
257
258
259
    let parent_hash = {
        if parent_hash.is_null() {
            None
        } else {
            Some(unsafe { *parent_hash })
        }
    };
260
    let kv_params = DynamoKvStoredEventParams {
261
262
263
264
265
266
267
        event_id,
        token_ids,
        num_block_tokens,
        block_ids,
        num_blocks,
        parent_hash,
        lora_id,
268
269
270
    };
    let publisher = KV_PUB.get().unwrap();
    let event = kv_event_create_stored_from_parts(kv_params, publisher.kv_block_size());
271
    match publisher.publish(event) {
Neelay Shah's avatar
Neelay Shah committed
272
        Ok(_) => DynamoLlmResult::OK,
273
274
        Err(e) => {
            eprintln!("Error publishing stored kv event {:?}", e);
Neelay Shah's avatar
Neelay Shah committed
275
            DynamoLlmResult::ERR
276
277
278
279
        }
    }
}

280
#[unsafe(no_mangle)]
Neelay Shah's avatar
Neelay Shah committed
281
pub extern "C" fn dynamo_kv_event_publish_removed(
282
283
284
    event_id: u64,
    block_ids: *const u64,
    num_blocks: usize,
Neelay Shah's avatar
Neelay Shah committed
285
) -> DynamoLlmResult {
286
287
288
    let publisher = KV_PUB.get().unwrap();
    let event = kv_event_create_removed_from_parts(event_id, block_ids, num_blocks);
    match publisher.publish(event) {
Neelay Shah's avatar
Neelay Shah committed
289
        Ok(_) => DynamoLlmResult::OK,
290
291
        Err(e) => {
            eprintln!("Error publishing removed kv event {:?}", e);
Neelay Shah's avatar
Neelay Shah committed
292
            DynamoLlmResult::ERR
293
294
295
296
        }
    }
}

297
298
299
300
301
// etcd and NATS need to be set up to run these tests
// #[cfg(test)]
// mod tests {
//     use super::*;
//     use std::ffi::CString;
302

303
304
305
306
307
//     #[test]
//     fn test_dynamo_llm_init() {
//         // Create C-compatible strings
//         let namespace = CString::new("test_namespace").unwrap();
//         let component = CString::new("test_component").unwrap();
308

309
310
311
312
313
314
315
316
317
//         // Call the init function
//         let result = unsafe {
//             dynamo_llm_init(
//                 namespace.as_ptr(),
//                 component.as_ptr(),
//                 1,  // worker_id
//                 32, // kv_block_size
//             )
//         };
318

319
//         assert_eq!(result as u32, DynamoLlmResult::OK as u32);
320

321
//         assert!(WK.get().is_some());
322

323
324
325
//         let shutdown_result = dynamo_llm_shutdown();
//         assert_eq!(shutdown_result as u32, DynamoLlmResult::OK as u32);
//     }
326
// }