lib.rs 10.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use async_once_cell::OnceCell as AsyncOnceCell;
use libc::c_char;
use once_cell::sync::OnceCell;
use std::ffi::CStr;
use std::sync::atomic::{AtomicU32, Ordering};

Neelay Shah's avatar
Neelay Shah committed
22
use dynamo_llm::kv_router::{
GuanLuo's avatar
GuanLuo committed
23
    indexer::compute_block_hash_for_seq, protocols::*, publisher::KvEventPublisher,
24
};
Neelay Shah's avatar
Neelay Shah committed
25
use dynamo_runtime::{DistributedRuntime, Worker};
26
27
28
static WK: OnceCell<Worker> = OnceCell::new();
static DRT: AsyncOnceCell<DistributedRuntime> = AsyncOnceCell::new();
// [FIXME] shouldn't the publisher be instance passing between API calls?
GuanLuo's avatar
GuanLuo committed
29
static KV_PUB: OnceCell<KvEventPublisher> = OnceCell::new();
30
31
32
33
34
35
36
37

fn initialize_tracing() {
    // Sets up RUST_LOG environment variable for logging while KV Publishing
    // Example: os.environ["RUST_LOG"] = "debug"
    let subscriber = tracing_subscriber::fmt()
        .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
        .finish();

Ryan Olson's avatar
Ryan Olson committed
38
    tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed");
39

40
    tracing::debug!("Tracing initialized");
41
42
43
}

/// Status code returned by every C entry point in this library.
///
/// The discriminants are part of the C ABI: `OK` is `0` and `ERR` is `1`.
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DynamoLlmResult {
    OK = 0,
    ERR = 1,
}

/// # Safety
GuanLuo's avatar
GuanLuo committed
50
/// the namespace_c_str and component_c_str are passed as pointers to C strings
51
#[unsafe(no_mangle)]
Neelay Shah's avatar
Neelay Shah committed
52
pub unsafe extern "C" fn dynamo_llm_init(
GuanLuo's avatar
GuanLuo committed
53
54
55
    namespace_c_str: *const c_char,
    component_c_str: *const c_char,
    worker_id: i64,
56
    kv_block_size: u32,
Neelay Shah's avatar
Neelay Shah committed
57
) -> DynamoLlmResult {
58
59
60
61
62
    initialize_tracing();
    let wk = match WK.get_or_try_init(Worker::from_settings) {
        Ok(wk) => wk.clone(),
        Err(e) => {
            eprintln!("Failed to initialize runtime: {:?}", e);
Neelay Shah's avatar
Neelay Shah committed
63
            return DynamoLlmResult::ERR;
64
65
66
67
68
69
70
71
72
73
74
75
76
        }
    };
    let rt = wk.runtime();
    let secondary = rt.secondary().clone();
    let result = secondary.block_on(async {
        // Initialize the distributed runtime
        match DRT
            .get_or_try_init(async { DistributedRuntime::from_settings(rt.clone()).await })
            .await
        {
            Ok(_) => Ok(()),
            Err(e) => {
                eprintln!("Failed to initialize distributed runtime: {:?}", e);
Neelay Shah's avatar
Neelay Shah committed
77
                Err(DynamoLlmResult::ERR)
78
79
80
            }
        }
    });
GuanLuo's avatar
GuanLuo committed
81
    let namespace = match unsafe { CStr::from_ptr(namespace_c_str) }.to_str() {
82
83
84
        Ok(s) => s.to_string(),
        Err(e) => {
            eprintln!("Failed to convert C string to Rust string: {:?}", e);
Neelay Shah's avatar
Neelay Shah committed
85
            return DynamoLlmResult::ERR;
86
87
88
        }
    };

GuanLuo's avatar
GuanLuo committed
89
90
    let component = match unsafe { CStr::from_ptr(component_c_str) }.to_str() {
        Ok(s) => s.to_string(),
91
92
        Err(e) => {
            eprintln!("Failed to convert C string to Rust string: {:?}", e);
Neelay Shah's avatar
Neelay Shah committed
93
            return DynamoLlmResult::ERR;
94
95
96
97
        }
    };

    match result {
98
        Ok(_) => match KV_PUB.get_or_try_init(move || {
99
            dynamo_create_kv_publisher(namespace, component, worker_id, kv_block_size)
100
        }) {
Neelay Shah's avatar
Neelay Shah committed
101
            Ok(_) => DynamoLlmResult::OK,
102
103
            Err(e) => {
                eprintln!("Failed to initialize distributed runtime: {:?}", e);
Neelay Shah's avatar
Neelay Shah committed
104
                DynamoLlmResult::ERR
105
106
107
108
109
110
            }
        },
        Err(e) => e,
    }
}

111
#[unsafe(no_mangle)]
Neelay Shah's avatar
Neelay Shah committed
112
pub extern "C" fn dynamo_llm_shutdown() -> DynamoLlmResult {
113
114
115
116
    let wk = match WK.get() {
        Some(wk) => wk,
        None => {
            eprintln!("Runtime not initialized");
Neelay Shah's avatar
Neelay Shah committed
117
            return DynamoLlmResult::ERR;
118
119
120
121
122
        }
    };

    wk.runtime().shutdown();

Neelay Shah's avatar
Neelay Shah committed
123
    DynamoLlmResult::OK
124
125
}

126
#[unsafe(no_mangle)]
Neelay Shah's avatar
Neelay Shah committed
127
128
pub extern "C" fn dynamo_llm_load_publisher_create() -> DynamoLlmResult {
    DynamoLlmResult::OK
129
130
131
132
}

// instantiate a kv publisher
// this will bring up the task to publish and the channels to await publishing events
Neelay Shah's avatar
Neelay Shah committed
133
134
// the [`dynamo_kv_publish_store_event`] call will use a handle to the publisher to send events
// store and the [`dynamo_kv_event_create_removed`] will create remove events
135
136
137
// these call mus be driving by external c++ threads that are consuming the kv events from the
// c++ executor api

Neelay Shah's avatar
Neelay Shah committed
138
fn dynamo_create_kv_publisher(
GuanLuo's avatar
GuanLuo committed
139
140
141
    namespace: String,
    component: String,
    worker_id: i64,
142
    kv_block_size: u32,
GuanLuo's avatar
GuanLuo committed
143
) -> Result<KvEventPublisher, anyhow::Error> {
144
    tracing::info!("Creating KV Publisher for model: {}", component);
145
146
147
148
149
    match DRT
        .get()
        .ok_or(anyhow::Error::msg("Could not get Distributed Runtime"))
    {
        Ok(drt) => {
GuanLuo's avatar
GuanLuo committed
150
            let backend = drt.namespace(namespace)?.component(component)?;
151
            KvEventPublisher::new(backend, worker_id, kv_block_size, None)
152
153
154
155
156
157
158
159
160
        }
        Err(e) => Err(e),
    }
}

/// Build the stored-block record for one block: the externally supplied `block_hash` plus the
/// first hash returned by `compute_block_hash_for_seq` over the block's `num_tokens` token ids.
///
/// `_lora_id` is accepted but currently unused.
///
/// NOTE(review): this safe fn dereferences `token_ids`; callers must ensure it points to at
/// least `num_tokens` readable `u32`s. It also indexes the first hash, so it panics if
/// `compute_block_hash_for_seq` returns none (presumably when `num_tokens < kv_block_size`).
/// The visible caller only passes blocks with exactly `kv_block_size` tokens.
fn kv_event_create_stored_block_from_parts(
    block_hash: u64,
    token_ids: *const u32,
    num_tokens: usize,
    kv_block_size: u32,
    _lora_id: u64,
) -> KvCacheStoredBlockData {
    // SAFETY: see the NOTE above — the caller provides `num_tokens` readable elements.
    let block_tokens = unsafe { std::slice::from_raw_parts(token_ids, num_tokens) };
    let hashes = compute_block_hash_for_seq(block_tokens, kv_block_size);
    let first_hash = hashes[0];
    KvCacheStoredBlockData {
        block_hash: ExternalSequenceBlockHash(block_hash),
        tokens_hash: first_hash,
    }
}
static WARN_COUNT: AtomicU32 = AtomicU32::new(0);

fn kv_event_create_stored_from_parts(
176
    kv_params: DynamoKvStoredEventParams,
177
    kv_block_size: u32,
178
179
180
181
) -> KvCacheEvent {
    let mut blocks: Vec<KvCacheStoredBlockData> = Vec::new();

    let mut token_offset: usize = 0;
182
183
184
185
186
187
188
189
190
    for block_idx in 0..kv_params.num_blocks {
        let block_hash = unsafe { *kv_params.block_ids.offset(block_idx.try_into().unwrap()) };
        let tokens = unsafe { kv_params.token_ids.offset(token_offset.try_into().unwrap()) };
        let num_toks = unsafe {
            *kv_params
                .num_block_tokens
                .offset(block_idx.try_into().unwrap())
        };

191
        if num_toks != (kv_block_size as usize) {
Ryan Olson's avatar
Ryan Olson committed
192
193
            if WARN_COUNT
                .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |c| {
194
                    if c < 3 { Some(c + 1) } else { None }
Ryan Olson's avatar
Ryan Olson committed
195
196
197
                })
                .is_ok()
            {
198
                tracing::warn!(
199
200
                    "Block not published. Block size must be {} tokens to be published. Block size is: {}",
                    kv_block_size,
Ryan Olson's avatar
Ryan Olson committed
201
202
                    num_toks
                );
203
204
205
206
207
            }
            break;
        }
        token_offset += num_toks;
        blocks.push(kv_event_create_stored_block_from_parts(
208
209
210
211
212
            block_hash,
            tokens,
            num_toks,
            kv_block_size,
            kv_params.lora_id,
213
214
215
216
217
218
        ));
    }

    KvCacheEvent {
        data: KvCacheEventData::Stored(KvCacheStoreData {
            blocks,
219
            parent_hash: kv_params.parent_hash.map(ExternalSequenceBlockHash),
220
        }),
221
        event_id: kv_params.event_id,
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
    }
}

fn kv_event_create_removed_from_parts(
    event_id: u64,
    block_ids: *const u64,
    num_blocks: usize,
) -> KvCacheEvent {
    let block_hashes: Vec<ExternalSequenceBlockHash> =
        unsafe { std::slice::from_raw_parts(block_ids, num_blocks) }
            .to_vec()
            .iter()
            .map(|&v| ExternalSequenceBlockHash(v))
            .collect();
    KvCacheEvent {
        event_id,
        data: KvCacheEventData::Removed(KvCacheRemoveData { block_hashes }),
    }
}

242
243
244
245
246
247
248
249
250
251
/// Raw parameters of a "stored" KV event, bundled as received from the C caller.
///
/// The pointer fields are borrowed from the caller and are only valid for the duration of the
/// FFI call that produced them.
#[derive(Debug, Clone, Copy)]
pub struct DynamoKvStoredEventParams {
    /// Identifier forwarded unchanged into the resulting event.
    pub event_id: u64,
    /// Token ids of all blocks, concatenated block after block.
    pub token_ids: *const u32,
    /// Per-block token counts (`num_blocks` entries).
    pub num_block_tokens: *const usize,
    /// Per-block hashes supplied by the caller (`num_blocks` entries).
    pub block_ids: *const u64,
    /// Number of blocks described by the arrays above.
    pub num_blocks: usize,
    /// Hash of the parent block, or `None` when the blocks have no parent.
    pub parent_hash: Option<u64>,
    /// LoRA adapter id; not used when the block hashes are computed.
    pub lora_id: u64,
}

252
253
254
/// # Safety
/// parent_hash is passed as pointer to indicate whether the blocks
/// has a parent hash or not. nullptr is used to represent no parent hash
255
#[unsafe(no_mangle)]
Neelay Shah's avatar
Neelay Shah committed
256
pub unsafe extern "C" fn dynamo_kv_event_publish_stored(
257
258
259
260
261
262
263
    event_id: u64,
    token_ids: *const u32,
    num_block_tokens: *const usize,
    block_ids: *const u64,
    num_blocks: usize,
    parent_hash: *const u64,
    lora_id: u64,
Neelay Shah's avatar
Neelay Shah committed
264
) -> DynamoLlmResult {
265
266
267
268
269
270
271
    let parent_hash = {
        if parent_hash.is_null() {
            None
        } else {
            Some(unsafe { *parent_hash })
        }
    };
272
    let kv_params = DynamoKvStoredEventParams {
273
274
275
276
277
278
279
        event_id,
        token_ids,
        num_block_tokens,
        block_ids,
        num_blocks,
        parent_hash,
        lora_id,
280
281
282
    };
    let publisher = KV_PUB.get().unwrap();
    let event = kv_event_create_stored_from_parts(kv_params, publisher.kv_block_size());
283
    match publisher.publish(event) {
Neelay Shah's avatar
Neelay Shah committed
284
        Ok(_) => DynamoLlmResult::OK,
285
286
        Err(e) => {
            eprintln!("Error publishing stored kv event {:?}", e);
Neelay Shah's avatar
Neelay Shah committed
287
            DynamoLlmResult::ERR
288
289
290
291
        }
    }
}

292
#[unsafe(no_mangle)]
Neelay Shah's avatar
Neelay Shah committed
293
pub extern "C" fn dynamo_kv_event_publish_removed(
294
295
296
    event_id: u64,
    block_ids: *const u64,
    num_blocks: usize,
Neelay Shah's avatar
Neelay Shah committed
297
) -> DynamoLlmResult {
298
299
300
    let publisher = KV_PUB.get().unwrap();
    let event = kv_event_create_removed_from_parts(event_id, block_ids, num_blocks);
    match publisher.publish(event) {
Neelay Shah's avatar
Neelay Shah committed
301
        Ok(_) => DynamoLlmResult::OK,
302
303
        Err(e) => {
            eprintln!("Error publishing removed kv event {:?}", e);
Neelay Shah's avatar
Neelay Shah committed
304
            DynamoLlmResult::ERR
305
306
307
308
        }
    }
}

309
310
311
312
313
// Need to setup etcd and nats to run these tests
// #[cfg(test)]
// mod tests {
//     use super::*;
//     use std::ffi::CString;
314

315
316
317
318
319
//     #[test]
//     fn test_dynamo_llm_init() {
//         // Create C-compatible strings
//         let namespace = CString::new("test_namespace").unwrap();
//         let component = CString::new("test_component").unwrap();
320

321
322
323
324
325
326
327
328
329
//         // Call the init function
//         let result = unsafe {
//             dynamo_llm_init(
//                 namespace.as_ptr(),
//                 component.as_ptr(),
//                 1,  // worker_id
//                 32, // kv_block_size
//             )
//         };
330

331
//         assert_eq!(result as u32, DynamoLlmResult::OK as u32);
332

333
//         assert!(WK.get().is_some());
334

335
336
337
//         let shutdown_result = dynamo_llm_shutdown();
//         assert_eq!(shutdown_result as u32, DynamoLlmResult::OK as u32);
//     }
338
// }