db.rs 5.3 KB
Newer Older
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
1
//! This code is massively inspired by Tokio mini-redis
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
2
use crate::{GenerateParameters, GenerateRequest};
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
3
use bloom_inference_client::{Batch, ClientError, LogitsWarperParameters, Request};
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
4
use parking_lot::Mutex;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
5
6
use std::collections::BTreeMap;
use std::sync::Arc;
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
7
use std::time::Duration;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
8
use tokio::sync::oneshot::Sender;
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
9
use tokio::time::Instant;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
10

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
11
/// Database entry
Olivier Dehaene's avatar
Olivier Dehaene committed
12
13
#[derive(Debug)]
pub(crate) struct Entry {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
14
    /// Request
Olivier Dehaene's avatar
Olivier Dehaene committed
15
    pub request: GenerateRequest,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
16
    /// Response sender to communicate between the Batcher and the batching_task
Olivier Dehaene's avatar
Olivier Dehaene committed
17
    pub response_tx: Sender<Result<String, ClientError>>,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
18
    /// Number of tokens in the input
Olivier Dehaene's avatar
Olivier Dehaene committed
19
    pub input_length: usize,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
20
21
    /// Instant when this entry was created
    pub time: Instant,
Olivier Dehaene's avatar
Olivier Dehaene committed
22
23
}

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
24
/// Request Database
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
25
26
27
28
29
#[derive(Debug, Clone)]
pub(crate) struct Db {
    pub shared: Arc<Shared>,
}

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
30
/// Shared state
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
31
32
#[derive(Debug)]
pub struct Shared {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
33
    state: Mutex<State>,
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
34
35
}

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
36
/// Database State
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
37
38
#[derive(Debug)]
struct State {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
39
    /// Database entries organized in a BTreeMap to be able to iterate over them in order
Olivier Dehaene's avatar
Olivier Dehaene committed
40
    entries: BTreeMap<u64, Entry>,
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
41

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
42
    /// Id of the next entry
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
43
44
    next_id: u64,

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
45
    /// Id of the next batch
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
46
47
    next_batch_id: u64,

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
48
    /// Start ID of the next batch. Used to iterate inside the entries BTreeMap
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
49
50
51
    next_batch_start_id: u64,
}

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
impl State {
    /// Collect the requests that should make up the next batch.
    ///
    /// Walks the entries whose id is at least `next_batch_start_id`, in ascending id
    /// order, taking at most `max_size` of them. When `min_waiting_time` is set, the walk
    /// stops at the first entry created less than `min_waiting_time` ago.
    ///
    /// Returns the selected entry ids together with their `Request`s (in the same order),
    /// or `None` when no entry was selected.
    fn next_requests(
        &self,
        max_size: usize,
        min_waiting_time: Option<Duration>,
    ) -> Option<(Vec<u64>, Vec<Request>)> {
        let (ids, requests): (Vec<u64>, Vec<Request>) = self
            .entries
            .range(self.next_batch_start_id..)
            .take(max_size)
            // Ids are handed out in insertion order, so (assuming entries are created just
            // before being appended) once one entry is too recent, all later ones are too.
            .take_while(|(_, entry)| match min_waiting_time {
                Some(min_wait) => entry.time.elapsed() >= min_wait,
                None => true,
            })
            .map(|(&id, entry)| {
                let request = Request {
                    id,
                    inputs: entry.request.inputs.clone(),
                    input_length: entry.input_length as u32,
                    parameters: Some(LogitsWarperParameters::from(
                        entry.request.parameters.clone(),
                    )),
                    max_new_tokens: entry.request.parameters.max_new_tokens,
                };
                (id, request)
            })
            .unzip();

        if ids.is_empty() {
            None
        } else {
            Some((ids, requests))
        }
    }
}

Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
100
101
impl Db {
    pub(crate) fn new() -> Self {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
102
        // Shared state
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
103
        let shared = Arc::new(Shared {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
104
            state: Mutex::new(State {
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
105
106
107
108
109
110
111
112
113
114
                entries: BTreeMap::new(),
                next_id: 0,
                next_batch_id: 0,
                next_batch_start_id: 0,
            }),
        });

        Self { shared }
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
115
    /// Append an entry to the database
Olivier Dehaene's avatar
Olivier Dehaene committed
116
    pub(crate) fn append(&self, entry: Entry) {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
117
118
        // Acquire lock
        let mut state = self.shared.state.lock();
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
119

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
120
        // Insert entry
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
121
122
        let id = state.next_id;
        state.next_id += 1;
Olivier Dehaene's avatar
Olivier Dehaene committed
123
        state.entries.insert(id, entry);
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
124
125
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
126
    /// Remove an entry from the database if it exists
Olivier Dehaene's avatar
Olivier Dehaene committed
127
    pub(crate) fn remove(&self, id: &u64) -> Option<Entry> {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
128
        let mut state = self.shared.state.lock();
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
129
130
131
        state.entries.remove(id)
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
    // Get the next batch
    pub(crate) fn next_batch(
        &self,
        min_size: Option<usize>,
        max_size: usize,
        min_waiting_time: Option<Duration>,
    ) -> Option<(Vec<u64>, Batch)> {
        // Acquire lock
        let mut state = self.shared.state.lock();

        // Get requests from the database
        if let Some((ids, requests)) = state.next_requests(max_size, min_waiting_time) {
            if let Some(min_size) = min_size {
                // If min_size is set, only return a batch if there are enough requests
                if requests.len() < min_size {
                    return None;
                }
            }
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
150

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
151
            // Batch size
Olivier Dehaene's avatar
Olivier Dehaene committed
152
            let size = requests.len();
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
153
154
            // Longest input length for all requests in batch size
            // Used for padding inside the inference server
Olivier Dehaene's avatar
Olivier Dehaene committed
155
            let max_sequence_length = requests.iter().map(|r| r.input_length).max().unwrap();
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
156
157
158
            let batch = Batch {
                id: state.next_batch_id,
                requests,
Olivier Dehaene's avatar
Olivier Dehaene committed
159
160
                size: size as u32,
                max_sequence_length,
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
161
            };
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
162
163
164
            // Update next_batch_start_id to the last id in the batch + 1
            state.next_batch_start_id = ids.last().unwrap() + 1;
            // Increment batch id
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
165
            state.next_batch_id += 1;
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
166
167

            return Some((ids, batch));
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
168
169
170
        }
        None
    }
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
171
}
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
172

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
173
174
175
176
177
178
179
impl From<GenerateParameters> for LogitsWarperParameters {
    fn from(parameters: GenerateParameters) -> Self {
        Self {
            temperature: parameters.temperature,
            top_k: parameters.top_k as u32,
            top_p: parameters.top_p,
            do_sample: parameters.do_sample,
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
180
181
182
        }
    }
}