lib.rs 2.2 KB
Newer Older
Nicolas Patry's avatar
Nicolas Patry committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
pub use backend::{GenerationContext, TensorRtLlmBackend};

mod backend;
pub mod errors;

// FFI bridge, generated by the `cxx` crate, between this crate and the C++ side of the
// TensorRT-LLM backend. Every C++ item declared here lives in the
// `huggingface::tgi::backends` namespace.
#[cxx::bridge(namespace = "huggingface::tgi::backends")]
mod ffi {

    /// Result of a single decoding iteration.
    ///
    /// This is a cxx shared struct: cxx emits a matching C++ definition from this one.
    /// Values of this type reach Rust through the `cb` callback of `stream_tokens`,
    /// which receives them by value.
    // NOTE(review): C++ code that builds this struct with positional aggregate
    // initialization depends on the field order below; keep it stable.
    pub struct GenerationStep {
        /// Id of the token produced by this step.
        token_id: u32,
        /// Presumably the log-probability associated with `token_id` (confirm in ffi.cpp).
        log_prob: f32,
        /// Presumably set on the last step of a request (confirm in ffi.cpp).
        is_final: bool,
        /// Set when the C++ side reports an error for this step.
        has_error: bool,
        /// Presumably the error description when `has_error` is set (confirm in ffi.cpp).
        error_msg: String,
    }

    extern "Rust" {
        /// Opaque Rust type made visible to C++. Its definition lives in the `backend`
        /// module. Within this bridge it only crosses the boundary as a raw
        /// `*mut GenerationContext`, as used by `stream_tokens`.
        type GenerationContext;
    }

    unsafe extern "C++" {
        // cxx emits this path as an `#include` in the generated C++ glue. It must
        // declare every C++ item listed in this block.
        // NOTE(review): this includes a .cpp file rather than a header. If ffi.cpp is
        // also compiled as its own translation unit, watch for duplicate definitions.
        include!("backends/trtllm/src/ffi.cpp");

        /// Opaque handle to an instance of the underlying TensorRT-LLM backend.
        /// Rust only holds it through `UniquePtr`, `&` or `Pin<&mut _>`.
        type TensorRtLlmBackendImpl;

        /// Create a backend instance owned by a `std::unique_ptr` (`UniquePtr` on the
        /// Rust side), which manages the backend's lifetime.
        ///
        /// # Arguments
        ///
        /// * `engine_folder`: Path to the folder containing all the TRTLLM engines
        /// * `executor_worker`: Path to the TRTLLM executor worker
        ///
        /// # Returns
        ///
        /// A `UniquePtr` owning the new `TensorRtLlmBackendImpl`. Dropping it
        /// destroys the C++ object.
        // NOTE(review): this binding is not declared to return `Result`. Under cxx
        // semantics, a C++ exception thrown here calls std::terminate instead of
        // surfacing as an error. Whether a null pointer can be returned is decided by
        // ffi.cpp; confirm before assuming the result is non-null.
        #[rust_name = "create_tensorrt_llm_backend"]
        fn CreateTensorRtLlmBackend(
            engine_folder: &str,
            executor_worker: &str,
        ) -> UniquePtr<TensorRtLlmBackendImpl>;

        // Disabled binding, kept for reference; it is not exposed to Rust.
        // #[rust_name = "is_ready"]
        // fn IsReady(self: &TensorRtLlmBackendImpl) -> bool;

        /// Read-only query on the backend. Presumably returns how many responses are
        /// currently ready to be consumed (confirm in ffi.cpp).
        #[rust_name = "num_responses_ready"]
        fn NumResponsesReady(self: &TensorRtLlmBackendImpl) -> usize;

        /// Hand the prompt `tokens` and the sampling parameters (`top_k`, `top_p`,
        /// `temperature`, `repetition_penalty`, `frequency_penalty`, `seed`) to the
        /// backend.
        ///
        /// The returned `u64` is presumably the id of the submitted request, i.e. the
        /// value later passed as `request_id` to `stream_tokens` (confirm in ffi.cpp).
        #[rust_name = "submit"]
        fn Submit(
            self: Pin<&mut TensorRtLlmBackendImpl>,
            tokens: &[u32],
            top_k: i32,
            top_p: f32,
            temperature: f32,
            repetition_penalty: f32,
            frequency_penalty: f32,
            seed: u64,
        ) -> u64;

        /// Process the request identified by `request_id`.
        ///
        /// Presumably calls `cb(ctx, step)` for each `GenerationStep` produced. The
        /// meaning of the returned `usize` (likely the number of steps delivered) is
        /// defined in ffi.cpp and should be confirmed there.
        ///
        /// # Safety
        ///
        /// `ctx` crosses the FFI boundary as a raw pointer, and `cb` receives a raw
        /// pointer of the same type. The caller must guarantee two things for as long
        /// as C++ may use `ctx`:
        ///
        /// * `ctx` points to a live `GenerationContext` that is not otherwise
        ///   accessed.
        /// * `cb` is sound to call with that pointer.
        // NOTE(review): confirm in ffi.cpp whether `ctx`/`cb` can be retained past
        // the return of this call; if so, the lifetime requirement above extends too.
        #[rust_name = "stream_tokens"]
        unsafe fn StreamTokens(
            self: Pin<&mut TensorRtLlmBackendImpl>,
            request_id: u64,
            ctx: *mut GenerationContext,
            cb: unsafe fn(*mut GenerationContext, GenerationStep),
        ) -> usize;

        // Disabled binding, kept for reference; it is not exposed to Rust.
        // #[rust_name = "shutdown"]
        // fn Shutdown(self: Pin<&mut TensorRtLlmBackendImpl>);
    }
}