db.rs 5.18 KB
Newer Older
1
use crate::InferResponse;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
2
/// This code is heavily inspired by Tokio's mini-redis
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
3
4
use crate::{GenerateParameters, GenerateRequest};
use parking_lot::Mutex;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
5
6
use std::collections::BTreeMap;
use std::sync::Arc;
7
use text_generation_client::{
OlivierDehaene's avatar
OlivierDehaene committed
8
    Batch, ClientError, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
9
};
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
10
use tokio::sync::oneshot::Sender;
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
11
use tokio::time::Instant;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
12

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
13
/// Database entry
Olivier Dehaene's avatar
Olivier Dehaene committed
14
15
#[derive(Debug)]
pub(crate) struct Entry {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
16
    /// Request
Olivier Dehaene's avatar
Olivier Dehaene committed
17
    pub request: GenerateRequest,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
18
    /// Response sender to communicate between the Batcher and the batching_task
19
    pub response_tx: Sender<Result<InferResponse, ClientError>>,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
20
    /// Number of tokens in the input
Olivier Dehaene's avatar
Olivier Dehaene committed
21
    pub input_length: usize,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
22
23
    /// Instant when this entry was created
    pub time: Instant,
24
25
    /// Instant when this entry was added to a batch
    pub batch_time: Option<Instant>,
Olivier Dehaene's avatar
Olivier Dehaene committed
26
27
}

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
28
/// Request Database
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
29
30
31
32
33
#[derive(Debug, Clone)]
pub(crate) struct Db {
    pub shared: Arc<Shared>,
}

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
34
/// Shared state
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
35
36
#[derive(Debug)]
pub struct Shared {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
37
    state: Mutex<State>,
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
38
39
}

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
40
/// Database State
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
41
42
#[derive(Debug)]
struct State {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
43
    /// Database entries organized in a BTreeMap to be able to iterate over them in order
Olivier Dehaene's avatar
Olivier Dehaene committed
44
    entries: BTreeMap<u64, Entry>,
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
45

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
46
    /// Id of the next entry
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
47
48
    next_id: u64,

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
49
    /// Id of the next batch
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
50
51
    next_batch_id: u64,

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
52
    /// Start ID of the next batch. Used to iterate inside the entries BTreeMap
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
53
54
55
    next_batch_start_id: u64,
}

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
56
57
impl State {
    /// Get the next requests
58
    fn next_requests(&self, max_size: usize) -> Option<(Vec<u64>, Vec<Request>)> {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
        // Iterates for max_size over the BTreemap starting from next_batch_start_id
        let mut requests = Vec::new();
        let mut ids = Vec::new();

        for (id, entry) in self
            .entries
            // Start from next_batch_start_id
            .range(self.next_batch_start_id..)
            // Take max_size
            .take(max_size)
        {
            requests.push(Request {
                id: *id,
                inputs: entry.request.inputs.clone(),
                input_length: entry.input_length as u32,
OlivierDehaene's avatar
OlivierDehaene committed
74
                parameters: Some(NextTokenChooserParameters::from(
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
75
76
                    entry.request.parameters.clone(),
                )),
77
78
79
                stopping_parameters: Some(StoppingCriteriaParameters::from(
                    entry.request.parameters.clone(),
                )),
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
80
81
82
83
84
85
86
87
88
89
90
91
92
            });

            ids.push(*id);
        }

        if requests.is_empty() {
            None
        } else {
            Some((ids, requests))
        }
    }
}

Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
93
94
impl Db {
    pub(crate) fn new() -> Self {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
95
        // Shared state
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
96
        let shared = Arc::new(Shared {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
97
            state: Mutex::new(State {
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
98
99
100
101
102
103
104
105
106
107
                entries: BTreeMap::new(),
                next_id: 0,
                next_batch_id: 0,
                next_batch_start_id: 0,
            }),
        });

        Self { shared }
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
108
    /// Append an entry to the database
Olivier Dehaene's avatar
Olivier Dehaene committed
109
    pub(crate) fn append(&self, entry: Entry) {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
110
111
        // Acquire lock
        let mut state = self.shared.state.lock();
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
112

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
113
        // Insert entry
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
114
115
        let id = state.next_id;
        state.next_id += 1;
Olivier Dehaene's avatar
Olivier Dehaene committed
116
        state.entries.insert(id, entry);
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
117
118
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
119
    /// Remove an entry from the database if it exists
Olivier Dehaene's avatar
Olivier Dehaene committed
120
    pub(crate) fn remove(&self, id: &u64) -> Option<Entry> {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
121
        let mut state = self.shared.state.lock();
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
122
123
124
        state.entries.remove(id)
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
125
126
127
128
129
130
131
132
133
134
    // Get the next batch
    pub(crate) fn next_batch(
        &self,
        min_size: Option<usize>,
        max_size: usize,
    ) -> Option<(Vec<u64>, Batch)> {
        // Acquire lock
        let mut state = self.shared.state.lock();

        // Get requests from the database
135
        if let Some((ids, requests)) = state.next_requests(max_size) {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
136
137
138
139
140
141
            if let Some(min_size) = min_size {
                // If min_size is set, only return a batch if there are enough requests
                if requests.len() < min_size {
                    return None;
                }
            }
142
143
144
145
            ids.iter().for_each(|id| {
                // Set batch_time for each request
                state.entries.get_mut(id).unwrap().batch_time = Some(Instant::now());
            });
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
146

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
147
            // Batch size
Olivier Dehaene's avatar
Olivier Dehaene committed
148
            let size = requests.len();
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
149
150
151
            let batch = Batch {
                id: state.next_batch_id,
                requests,
Olivier Dehaene's avatar
Olivier Dehaene committed
152
                size: size as u32,
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
153
            };
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
154
155
156
            // Update next_batch_start_id to the last id in the batch + 1
            state.next_batch_start_id = ids.last().unwrap() + 1;
            // Increment batch id
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
157
            state.next_batch_id += 1;
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
158
159

            return Some((ids, batch));
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
160
161
162
        }
        None
    }
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
163
}
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
164

OlivierDehaene's avatar
OlivierDehaene committed
165
impl From<GenerateParameters> for NextTokenChooserParameters {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
166
167
168
169
170
171
    fn from(parameters: GenerateParameters) -> Self {
        Self {
            temperature: parameters.temperature,
            top_k: parameters.top_k as u32,
            top_p: parameters.top_p,
            do_sample: parameters.do_sample,
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
172
173
174
        }
    }
}
175
176
177
178
179
180
181
182
183

impl From<GenerateParameters> for StoppingCriteriaParameters {
    fn from(parameters: GenerateParameters) -> Self {
        Self {
            stop_sequences: parameters.stop,
            max_new_tokens: parameters.max_new_tokens,
        }
    }
}