/// Batching and inference logic
use crate::GenerateRequest;
use crate::{Db, Entry};
use axum::http::StatusCode;
use bloom_inference_client::{Batch, ClientError, GeneratedText, ShardedClient};
use std::future::Future;
use std::sync::Arc;
use thiserror::Error;
use tokio::sync::{oneshot, Notify};
use tokio::time::Instant;
use tracing::instrument;
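
// Flow: `Batcher::infer` appends an `Entry` to the shared `Db` and notifies the background
// batching task; the task drains the database into batches, drives the sharded inference
// client, and sends each result back through the entry's oneshot channel.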

/// Batcher
#[derive(Clone)]
pub struct Batcher {
    /// Request database
    db: Db,
    /// Shared state
    shared: Arc<Shared>,
}

/// Batcher shared state
struct Shared {
    /// Batching background Tokio task notifier
    batching_task: Notify,
}

impl Batcher {
    pub(crate) fn new(
        client: ShardedClient,
        max_batch_size: usize,
        max_waiting_tokens: usize,
    ) -> Self {
        // Batcher shared state
        let db = Db::new();
        let shared = Arc::new(Shared {
            batching_task: Notify::new(),
        });

        // Spawn batching background task that contains all the inference logic
        tokio::spawn(batching_task(
            max_batch_size,
            max_waiting_tokens,
            client,
            db.clone(),
            shared.clone(),
        ));

        Self { db, shared }
    }

    /// Add a new request to the database and return a future that will generate the text
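    ///
    /// A minimal usage sketch (hypothetical caller; the client, the request, and the
    /// tokenizer-derived `input_length` are assumptions, not part of this file):
    ///
    /// ```ignore
    /// let batcher = Batcher::new(sharded_client, 32, 20);
    /// let response = batcher.infer(input_length, generate_request).await?;
    /// println!("{}", response.output);
    /// ```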
    pub(crate) async fn infer(
        &self,
        input_length: usize,
        request: GenerateRequest,
    ) -> Result<InferResponse, InferError> {
        // One shot channel to communicate with the background batching task
        let (response_tx, response_rx) = oneshot::channel();

        // Try to append the request to the database
        self.db.append(Entry {
            request,
            response_tx,
            input_length,
            time: Instant::now(),
            batch_time: None,
        });

        // Notify the background task that we have a new entry in the database that needs
        // to be batched
        self.shared.batching_task.notify_waiters();

        // Await on the response from the background task
        // We can safely unwrap as the background task will never drop the sender
        match response_rx.await.unwrap() {
            Ok(output) => Ok(output),
            Err(err) => Err(InferError::GenerationError(err.to_string())),
        }
    }
}

/// Batching logic
/// Will be launched in a background Tokio task
///
/// Batches requests and sends them to the inference server
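///
/// Scheduling sketch with an illustrative `max_batch_size` of 32: `limit_min_batch_size`
/// is then 16. While the running batch holds more than 16 requests, no new requests are
/// onboarded; once it shrinks to 16 or fewer, the task tries to pull at least 16 waiting
/// requests from the `Db`, and accepts any number at all once `max_waiting_tokens` decode
/// iterations have passed without onboarding.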
#[instrument(skip(client, db, shared))]
async fn batching_task(
    max_batch_size: usize,
    max_waiting_tokens: usize,
    client: ShardedClient,
    db: Db,
    shared: Arc<Shared>,
) {
    // Minimum batch size after which we try to add more requests
    let limit_min_batch_size = (max_batch_size / 2) as u32;

    // Infinite loop
    loop {
        // Wait for a notification from the Batcher struct
        shared.batching_task.notified().await;

        // Get the next batch from the DB
        // This batch might be smaller than the maximum batch size if there are not enough requests
        // waiting in the DB
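        // Number of decode iterations since requests were last added to the running batch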
        let mut waiting_tokens = 0;
        if let Some((request_ids, batch)) = db.next_batch(None, max_batch_size) {
            let mut cached_batch = wrap_future(client.generate(batch), request_ids, &db).await;
            waiting_tokens += 1;

            // We loop until we do not receive any cached batch from the inference server (== until
            // all requests have met their stopping criteria)
            while let Some(batch) = cached_batch {
                // Get current batch info
                let batch_size = batch.size;
                let mut request_ids: Vec<u64> = batch.requests.iter().map(|req| req.id).collect();
                let mut batches = vec![batch];

                // If the current batch is too small, we try to add more requests to it
                if batch_size <= limit_min_batch_size {
                    let min_size = match waiting_tokens {
                        // If we didn't onboard any new requests since >= max_waiting_tokens, we try
                        // to add a new batch even though its size might be small
                        _ if waiting_tokens >= max_waiting_tokens => None,
                        // Minimum size criteria
                        _ => Some(limit_min_batch_size as usize),
                    };

                    // Try to get a new batch
                    if let Some((new_request_ids, new_batch)) =
                        db.next_batch(min_size, max_batch_size)
                    {
                        // Reset waiting counter
                        waiting_tokens = 0;
                        // Generate one token for this new batch to have the attention past in cache
                        let new_cached_batch =
                            wrap_future(client.generate(new_batch), new_request_ids, &db).await;
                        // Extend current batch with the new batch
                        if let Some(new_cached_batch) = new_cached_batch {
                            request_ids.extend(new_cached_batch.requests.iter().map(|req| req.id));
                            batches.push(new_cached_batch);
                        }
                    }
                }

                // Generate one token for all batches, reusing the attention past already
                // cached on the inference server
                cached_batch =
                    wrap_future(client.generate_with_cache(batches), request_ids, &db).await;
                waiting_tokens += 1;
            }
        }
    }
}

/// Wrap a future inside a match statement to handle errors and send the response to the Batcher
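/// Returns the next `Batch` to decode if some requests are still running, and `None` once
/// every request has finished or the whole batch was discarded because of an error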
async fn wrap_future(
    future: impl Future<Output = Result<(Vec<GeneratedText>, Option<Batch>), ClientError>>,
    request_ids: Vec<u64>,
    db: &Db,
) -> Option<Batch> {
    match future.await {
        Ok((generated_texts, next_batch)) => {
            send_generated(generated_texts, db);
            next_batch
        }
        // If we have an error, we discard the whole batch
        Err(err) => {
            send_error(err, request_ids, db);
            None
        }
    }
}

/// Send errors to the Batcher for all `request_ids`
fn send_error(error: ClientError, request_ids: Vec<u64>, db: &Db) {
    request_ids.into_iter().for_each(|id| {
        // We can `expect` here as the request id should always be in the DB
        let entry = db.remove(&id).expect("ID not found in db. This is a bug.");
        // unwrap_or is valid here as we don't care if the receiver is gone.
        entry.response_tx.send(Err(error.clone())).unwrap_or(());
    });
}

/// Send `generated_text` to the Batcher for all `finished`
fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
    finished.into_iter().for_each(|output| {
        // We can `expect` here as the request id should always be in the DB
        let entry = db
            .remove(&output.request.unwrap().id)
            .expect("ID not found in db. This is a bug.");
        let response = InferResponse {
            output: output.output,
            queued: entry.time,
            start: entry.batch_time.unwrap(), // unwrap is valid: batch_time is set when the entry is first batched
            end: Instant::now(),
        };
        // unwrap_or is valid here as we don't care if the receiver is gone.
        entry.response_tx.send(Ok(response)).unwrap_or(());
    });
}

#[derive(Debug)]
pub(crate) struct InferResponse {
    /// Generated text
    pub(crate) output: String,
    /// Instant when the request entered the database
    pub(crate) queued: Instant,
    /// Instant when the request was first scheduled in a batch
    pub(crate) start: Instant,
    /// Instant when generation finished
    pub(crate) end: Instant,
}

#[derive(Debug, Error)]
pub enum InferError {
    #[error("Request failed during generation: {0}")]
    GenerationError(String),
}

/// Convert to Axum supported format
impl From<InferError> for (StatusCode, String) {
    fn from(err: InferError) -> Self {
        match err {
            InferError::GenerationError(_) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
        }
    }
}
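
// Hedged usage sketch (hypothetical handler; the actual routes live elsewhere in the
// crate): thanks to the `From` impl above, an Axum handler returning
// `Result<_, (StatusCode, String)>` can propagate inference failures with `?`:
//
//     async fn generate(batcher: Batcher, input_length: usize, request: GenerateRequest)
//         -> Result<String, (StatusCode, String)> {
//         let response = batcher.infer(input_length, request).await?;
//         Ok(response.output)
//     }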