db.rs 4.81 KB
Newer Older
1
use crate::InferResponse;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
2
/// This code is massively inspired by Tokio mini-redis
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
3
4
use crate::{GenerateParameters, GenerateRequest};
use parking_lot::Mutex;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
5
6
use std::collections::BTreeMap;
use std::sync::Arc;
7
use text_generation_client::{Batch, ClientError, LogitsWarperParameters, Request};
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
8
use tokio::sync::oneshot::Sender;
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
9
use tokio::time::Instant;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
10

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
11
/// Database entry
Olivier Dehaene's avatar
Olivier Dehaene committed
12
13
#[derive(Debug)]
pub(crate) struct Entry {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
14
    /// Request
Olivier Dehaene's avatar
Olivier Dehaene committed
15
    pub request: GenerateRequest,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
16
    /// Response sender to communicate between the Batcher and the batching_task
17
    pub response_tx: Sender<Result<InferResponse, ClientError>>,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
18
    /// Number of tokens in the input
Olivier Dehaene's avatar
Olivier Dehaene committed
19
    pub input_length: usize,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
20
21
    /// Instant when this entry was created
    pub time: Instant,
22
23
    /// Instant when this entry was added to a batch
    pub batch_time: Option<Instant>,
Olivier Dehaene's avatar
Olivier Dehaene committed
24
25
}

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
26
/// Request Database
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
27
28
29
30
31
#[derive(Debug, Clone)]
pub(crate) struct Db {
    pub shared: Arc<Shared>,
}

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
32
/// Shared state
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
33
34
#[derive(Debug)]
pub struct Shared {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
35
    state: Mutex<State>,
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
36
37
}

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
38
/// Database State
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
39
40
#[derive(Debug)]
struct State {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
41
    /// Database entries organized in a BTreeMap to be able to iterate over them in order
Olivier Dehaene's avatar
Olivier Dehaene committed
42
    entries: BTreeMap<u64, Entry>,
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
43

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
44
    /// Id of the next entry
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
45
46
    next_id: u64,

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
47
    /// Id of the next batch
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
48
49
    next_batch_id: u64,

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
50
    /// Start ID of the next batch. Used to iterate inside the entries BTreeMap
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
51
52
53
    next_batch_start_id: u64,
}

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
54
55
impl State {
    /// Get the next requests
56
    fn next_requests(&self, max_size: usize) -> Option<(Vec<u64>, Vec<Request>)> {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
        // Iterates for max_size over the BTreemap starting from next_batch_start_id
        let mut requests = Vec::new();
        let mut ids = Vec::new();

        for (id, entry) in self
            .entries
            // Start from next_batch_start_id
            .range(self.next_batch_start_id..)
            // Take max_size
            .take(max_size)
        {
            requests.push(Request {
                id: *id,
                inputs: entry.request.inputs.clone(),
                input_length: entry.input_length as u32,
                parameters: Some(LogitsWarperParameters::from(
                    entry.request.parameters.clone(),
                )),
                max_new_tokens: entry.request.parameters.max_new_tokens,
            });

            ids.push(*id);
        }

        if requests.is_empty() {
            None
        } else {
            Some((ids, requests))
        }
    }
}

Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
89
90
impl Db {
    pub(crate) fn new() -> Self {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
91
        // Shared state
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
92
        let shared = Arc::new(Shared {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
93
            state: Mutex::new(State {
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
94
95
96
97
98
99
100
101
102
103
                entries: BTreeMap::new(),
                next_id: 0,
                next_batch_id: 0,
                next_batch_start_id: 0,
            }),
        });

        Self { shared }
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
104
    /// Append an entry to the database
Olivier Dehaene's avatar
Olivier Dehaene committed
105
    pub(crate) fn append(&self, entry: Entry) {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
106
107
        // Acquire lock
        let mut state = self.shared.state.lock();
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
108

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
109
        // Insert entry
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
110
111
        let id = state.next_id;
        state.next_id += 1;
Olivier Dehaene's avatar
Olivier Dehaene committed
112
        state.entries.insert(id, entry);
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
113
114
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
115
    /// Remove an entry from the database if it exists
Olivier Dehaene's avatar
Olivier Dehaene committed
116
    pub(crate) fn remove(&self, id: &u64) -> Option<Entry> {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
117
        let mut state = self.shared.state.lock();
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
118
119
120
        state.entries.remove(id)
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
121
122
123
124
125
126
127
128
129
130
    // Get the next batch
    pub(crate) fn next_batch(
        &self,
        min_size: Option<usize>,
        max_size: usize,
    ) -> Option<(Vec<u64>, Batch)> {
        // Acquire lock
        let mut state = self.shared.state.lock();

        // Get requests from the database
131
        if let Some((ids, requests)) = state.next_requests(max_size) {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
132
133
134
135
136
137
            if let Some(min_size) = min_size {
                // If min_size is set, only return a batch if there are enough requests
                if requests.len() < min_size {
                    return None;
                }
            }
138
139
140
141
            ids.iter().for_each(|id| {
                // Set batch_time for each request
                state.entries.get_mut(id).unwrap().batch_time = Some(Instant::now());
            });
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
142

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
143
            // Batch size
Olivier Dehaene's avatar
Olivier Dehaene committed
144
            let size = requests.len();
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
145
146
147
            let batch = Batch {
                id: state.next_batch_id,
                requests,
Olivier Dehaene's avatar
Olivier Dehaene committed
148
                size: size as u32,
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
149
            };
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
150
151
152
            // Update next_batch_start_id to the last id in the batch + 1
            state.next_batch_start_id = ids.last().unwrap() + 1;
            // Increment batch id
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
153
            state.next_batch_id += 1;
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
154
155

            return Some((ids, batch));
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
156
157
158
        }
        None
    }
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
159
}
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
160

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
161
162
163
164
165
166
167
impl From<GenerateParameters> for LogitsWarperParameters {
    fn from(parameters: GenerateParameters) -> Self {
        Self {
            temperature: parameters.temperature,
            top_k: parameters.top_k as u32,
            top_p: parameters.top_p,
            do_sample: parameters.do_sample,
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
168
169
170
        }
    }
}