"lib/llm/src/lib.rs" did not exist on "9d6643b7a59220fc4f3ef599c002241dd0bf9965"
lib.rs 10.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use async_once_cell::OnceCell as AsyncOnceCell;
use libc::c_char;
use once_cell::sync::OnceCell;
use std::ffi::CStr;
use std::sync::atomic::{AtomicU32, Ordering};

Neelay Shah's avatar
Neelay Shah committed
22
use dynamo_llm::kv_router::{
GuanLuo's avatar
GuanLuo committed
23
    indexer::compute_block_hash_for_seq, protocols::*, publisher::KvEventPublisher,
24
};
Neelay Shah's avatar
Neelay Shah committed
25
use dynamo_runtime::{DistributedRuntime, Worker};
26
27
28
static WK: OnceCell<Worker> = OnceCell::new();
static DRT: AsyncOnceCell<DistributedRuntime> = AsyncOnceCell::new();
// [FIXME] shouldn't the publisher be instance passing between API calls?
GuanLuo's avatar
GuanLuo committed
29
static KV_PUB: OnceCell<KvEventPublisher> = OnceCell::new();
30
31
32
33
34
35
36
37

fn initialize_tracing() {
    // Sets up RUST_LOG environment variable for logging while KV Publishing
    // Example: os.environ["RUST_LOG"] = "debug"
    let subscriber = tracing_subscriber::fmt()
        .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
        .finish();

Ryan Olson's avatar
Ryan Olson committed
38
    tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed");
39

40
    tracing::debug!("Tracing initialized");
41
42
43
}

/// Status code returned by the `extern "C"` entry points of this library.
///
/// `#[repr(u32)]` pins the discriminants, so C callers can compare the returned
/// value against `0` (success) and `1` (failure).
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DynamoLlmResult {
    /// The call succeeded.
    OK = 0,
    /// The call failed; the reason is reported on stderr by the failing function.
    ERR = 1,
}

/// # Safety
GuanLuo's avatar
GuanLuo committed
50
/// the namespace_c_str and component_c_str are passed as pointers to C strings
51
#[no_mangle]
Neelay Shah's avatar
Neelay Shah committed
52
pub unsafe extern "C" fn dynamo_llm_init(
GuanLuo's avatar
GuanLuo committed
53
54
55
    namespace_c_str: *const c_char,
    component_c_str: *const c_char,
    worker_id: i64,
56
    kv_block_size: u32,
Neelay Shah's avatar
Neelay Shah committed
57
) -> DynamoLlmResult {
58
59
60
61
62
    initialize_tracing();
    let wk = match WK.get_or_try_init(Worker::from_settings) {
        Ok(wk) => wk.clone(),
        Err(e) => {
            eprintln!("Failed to initialize runtime: {:?}", e);
Neelay Shah's avatar
Neelay Shah committed
63
            return DynamoLlmResult::ERR;
64
65
66
67
68
69
70
71
72
73
74
75
76
        }
    };
    let rt = wk.runtime();
    let secondary = rt.secondary().clone();
    let result = secondary.block_on(async {
        // Initialize the distributed runtime
        match DRT
            .get_or_try_init(async { DistributedRuntime::from_settings(rt.clone()).await })
            .await
        {
            Ok(_) => Ok(()),
            Err(e) => {
                eprintln!("Failed to initialize distributed runtime: {:?}", e);
Neelay Shah's avatar
Neelay Shah committed
77
                Err(DynamoLlmResult::ERR)
78
79
80
            }
        }
    });
GuanLuo's avatar
GuanLuo committed
81
    let namespace = match unsafe { CStr::from_ptr(namespace_c_str) }.to_str() {
82
83
84
        Ok(s) => s.to_string(),
        Err(e) => {
            eprintln!("Failed to convert C string to Rust string: {:?}", e);
Neelay Shah's avatar
Neelay Shah committed
85
            return DynamoLlmResult::ERR;
86
87
88
        }
    };

GuanLuo's avatar
GuanLuo committed
89
90
    let component = match unsafe { CStr::from_ptr(component_c_str) }.to_str() {
        Ok(s) => s.to_string(),
91
92
        Err(e) => {
            eprintln!("Failed to convert C string to Rust string: {:?}", e);
Neelay Shah's avatar
Neelay Shah committed
93
            return DynamoLlmResult::ERR;
94
95
96
97
        }
    };

    match result {
98
99
100
        Ok(_) => match KV_PUB.get_or_try_init(move || {
            dynamo_create_kv_publisher(namespace, component, worker_id, kv_block_size as usize)
        }) {
Neelay Shah's avatar
Neelay Shah committed
101
            Ok(_) => DynamoLlmResult::OK,
102
103
            Err(e) => {
                eprintln!("Failed to initialize distributed runtime: {:?}", e);
Neelay Shah's avatar
Neelay Shah committed
104
                DynamoLlmResult::ERR
105
106
107
108
109
110
111
            }
        },
        Err(e) => e,
    }
}

#[no_mangle]
Neelay Shah's avatar
Neelay Shah committed
112
pub extern "C" fn dynamo_llm_shutdown() -> DynamoLlmResult {
113
114
115
116
    let wk = match WK.get() {
        Some(wk) => wk,
        None => {
            eprintln!("Runtime not initialized");
Neelay Shah's avatar
Neelay Shah committed
117
            return DynamoLlmResult::ERR;
118
119
120
121
122
        }
    };

    wk.runtime().shutdown();

Neelay Shah's avatar
Neelay Shah committed
123
    DynamoLlmResult::OK
124
125
126
}

#[no_mangle]
Neelay Shah's avatar
Neelay Shah committed
127
128
pub extern "C" fn dynamo_llm_load_publisher_create() -> DynamoLlmResult {
    DynamoLlmResult::OK
129
130
131
132
}

// instantiate a kv publisher
// this will bring up the task to publish and the channels to await publishing events
Neelay Shah's avatar
Neelay Shah committed
133
134
// the [`dynamo_kv_publish_store_event`] call will use a handle to the publisher to send events
// store and the [`dynamo_kv_event_create_removed`] will create remove events
135
136
137
// these call mus be driving by external c++ threads that are consuming the kv events from the
// c++ executor api

Neelay Shah's avatar
Neelay Shah committed
138
fn dynamo_create_kv_publisher(
GuanLuo's avatar
GuanLuo committed
139
140
141
    namespace: String,
    component: String,
    worker_id: i64,
142
    kv_block_size: usize,
GuanLuo's avatar
GuanLuo committed
143
) -> Result<KvEventPublisher, anyhow::Error> {
144
    tracing::info!("Creating KV Publisher for model: {}", component);
145
146
147
148
149
    match DRT
        .get()
        .ok_or(anyhow::Error::msg("Could not get Distributed Runtime"))
    {
        Ok(drt) => {
GuanLuo's avatar
GuanLuo committed
150
            let backend = drt.namespace(namespace)?.component(component)?;
151
            KvEventPublisher::new(drt.clone(), backend, worker_id, kv_block_size)
152
153
154
155
156
157
158
159
160
        }
        Err(e) => Err(e),
    }
}

/// Builds the stored-block record for a single KV block.
///
/// It hashes the `num_tokens` tokens at `token_ids` with `compute_block_hash_for_seq`
/// (chunked by `kv_block_size`) and uses the first resulting hash as `tokens_hash`.
/// `block_hash` is wrapped unchanged as the external block hash. `_lora_id` is
/// currently ignored.
///
/// NOTE(review): this is a safe fn that trusts `token_ids` to address `num_tokens`
/// readable values. Indexing `[0]` assumes at least one hash comes back, i.e. a full
/// block. The caller visible in this file only passes blocks with exactly
/// `kv_block_size` tokens.
fn kv_event_create_stored_block_from_parts(
    block_hash: u64,
    token_ids: *const u32,
    num_tokens: usize,
    kv_block_size: usize,
    _lora_id: u64,
) -> KvCacheStoredBlockData {
    // SAFETY: the caller guarantees `token_ids` points to `num_tokens` valid tokens.
    let block_tokens = unsafe { std::slice::from_raw_parts(token_ids, num_tokens) };
    let hashes = compute_block_hash_for_seq(block_tokens, kv_block_size);
    KvCacheStoredBlockData {
        tokens_hash: hashes[0],
        block_hash: ExternalSequenceBlockHash(block_hash),
    }
}
static WARN_COUNT: AtomicU32 = AtomicU32::new(0);

fn kv_event_create_stored_from_parts(
176
177
    kv_params: DynamoKvStoredEventParams,
    kv_block_size: usize,
178
179
180
181
) -> KvCacheEvent {
    let mut blocks: Vec<KvCacheStoredBlockData> = Vec::new();

    let mut token_offset: usize = 0;
182
183
184
185
186
187
188
189
190
191
    for block_idx in 0..kv_params.num_blocks {
        let block_hash = unsafe { *kv_params.block_ids.offset(block_idx.try_into().unwrap()) };
        let tokens = unsafe { kv_params.token_ids.offset(token_offset.try_into().unwrap()) };
        let num_toks = unsafe {
            *kv_params
                .num_block_tokens
                .offset(block_idx.try_into().unwrap())
        };

        if num_toks != kv_block_size {
Ryan Olson's avatar
Ryan Olson committed
192
193
194
195
196
197
198
199
200
201
            if WARN_COUNT
                .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |c| {
                    if c < 3 {
                        Some(c + 1)
                    } else {
                        None
                    }
                })
                .is_ok()
            {
202
                tracing::warn!(
203
204
                    "Block not published. Block size must be {} tokens to be published. Block size is: {}",
                    kv_block_size,
Ryan Olson's avatar
Ryan Olson committed
205
206
                    num_toks
                );
207
208
209
210
211
            }
            break;
        }
        token_offset += num_toks;
        blocks.push(kv_event_create_stored_block_from_parts(
212
213
214
215
216
            block_hash,
            tokens,
            num_toks,
            kv_block_size,
            kv_params.lora_id,
217
218
219
220
221
222
        ));
    }

    KvCacheEvent {
        data: KvCacheEventData::Stored(KvCacheStoreData {
            blocks,
223
            parent_hash: kv_params.parent_hash.map(ExternalSequenceBlockHash),
224
        }),
225
        event_id: kv_params.event_id,
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
    }
}

fn kv_event_create_removed_from_parts(
    event_id: u64,
    block_ids: *const u64,
    num_blocks: usize,
) -> KvCacheEvent {
    let block_hashes: Vec<ExternalSequenceBlockHash> =
        unsafe { std::slice::from_raw_parts(block_ids, num_blocks) }
            .to_vec()
            .iter()
            .map(|&v| ExternalSequenceBlockHash(v))
            .collect();
    KvCacheEvent {
        event_id,
        data: KvCacheEventData::Removed(KvCacheRemoveData { block_hashes }),
    }
}

/// Raw, caller-provided description of a "blocks stored" KV cache event.
///
/// `dynamo_kv_event_publish_stored` builds it from its FFI arguments, and
/// `kv_event_create_stored_from_parts` consumes it. The pointer fields are not owned.
/// They borrow caller memory that must stay valid while the event is being built.
#[derive(Debug, Clone, Copy)]
pub struct DynamoKvStoredEventParams {
    /// Identifier of the event.
    pub event_id: u64,
    /// Tokens of all blocks, laid out block after block.
    pub token_ids: *const u32,
    /// Token count of each block (`num_blocks` entries).
    pub num_block_tokens: *const usize,
    /// Block id/hash of each block (`num_blocks` entries).
    pub block_ids: *const u64,
    /// Number of blocks described by the arrays above.
    pub num_blocks: usize,
    /// Parent hash of the stored blocks, or `None` when there is none.
    pub parent_hash: Option<u64>,
    /// LoRA id supplied by the caller.
    pub lora_id: u64,
}

256
257
258
259
/// # Safety
/// parent_hash is passed as pointer to indicate whether the blocks
/// has a parent hash or not. nullptr is used to represent no parent hash
#[no_mangle]
Neelay Shah's avatar
Neelay Shah committed
260
pub unsafe extern "C" fn dynamo_kv_event_publish_stored(
261
262
263
264
265
266
267
    event_id: u64,
    token_ids: *const u32,
    num_block_tokens: *const usize,
    block_ids: *const u64,
    num_blocks: usize,
    parent_hash: *const u64,
    lora_id: u64,
Neelay Shah's avatar
Neelay Shah committed
268
) -> DynamoLlmResult {
269
270
271
272
273
274
275
    let parent_hash = {
        if parent_hash.is_null() {
            None
        } else {
            Some(unsafe { *parent_hash })
        }
    };
276
    let kv_params = DynamoKvStoredEventParams {
277
278
279
280
281
282
283
        event_id,
        token_ids,
        num_block_tokens,
        block_ids,
        num_blocks,
        parent_hash,
        lora_id,
284
285
286
    };
    let publisher = KV_PUB.get().unwrap();
    let event = kv_event_create_stored_from_parts(kv_params, publisher.kv_block_size());
287
    match publisher.publish(event) {
Neelay Shah's avatar
Neelay Shah committed
288
        Ok(_) => DynamoLlmResult::OK,
289
290
        Err(e) => {
            eprintln!("Error publishing stored kv event {:?}", e);
Neelay Shah's avatar
Neelay Shah committed
291
            DynamoLlmResult::ERR
292
293
294
295
296
        }
    }
}

#[no_mangle]
Neelay Shah's avatar
Neelay Shah committed
297
pub extern "C" fn dynamo_kv_event_publish_removed(
298
299
300
    event_id: u64,
    block_ids: *const u64,
    num_blocks: usize,
Neelay Shah's avatar
Neelay Shah committed
301
) -> DynamoLlmResult {
302
303
304
    let publisher = KV_PUB.get().unwrap();
    let event = kv_event_create_removed_from_parts(event_id, block_ids, num_blocks);
    match publisher.publish(event) {
Neelay Shah's avatar
Neelay Shah committed
305
        Ok(_) => DynamoLlmResult::OK,
306
307
        Err(e) => {
            eprintln!("Error publishing removed kv event {:?}", e);
Neelay Shah's avatar
Neelay Shah committed
308
            DynamoLlmResult::ERR
309
310
311
312
        }
    }
}

313
314
315
316
317
// These tests require running etcd and NATS services, so they are commented out.
// #[cfg(test)]
// mod tests {
//     use super::*;
//     use std::ffi::CString;
318

319
320
321
322
323
//     #[test]
//     fn test_dynamo_llm_init() {
//         // Create C-compatible strings
//         let namespace = CString::new("test_namespace").unwrap();
//         let component = CString::new("test_component").unwrap();
324

325
326
327
328
329
330
331
332
333
//         // Call the init function
//         let result = unsafe {
//             dynamo_llm_init(
//                 namespace.as_ptr(),
//                 component.as_ptr(),
//                 1,  // worker_id
//                 32, // kv_block_size
//             )
//         };
334

335
//         assert_eq!(result as u32, DynamoLlmResult::OK as u32);
336

337
//         assert!(WK.get().is_some());
338

339
340
341
//         let shutdown_result = dynamo_llm_shutdown();
//         assert_eq!(shutdown_result as u32, DynamoLlmResult::OK as u32);
//     }
342
// }