batcher.rs
//! Batching and inference logic
use crate::{Db, Entry};
use crate::{ErrorResponse, GenerateRequest};
use axum::http::StatusCode;
use axum::Json;
use nohash_hasher::IntMap;
use std::future::Future;
use std::sync::Arc;
use text_generation_client::{Batch, ClientError, GeneratedText, ShardedClient};
use thiserror::Error;
use tokio::sync::{oneshot, Notify};
use tokio::time::Instant;
use tracing::instrument;

/// Batcher
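///
/// Clonable handle shared by the HTTP handlers: requests are appended to the
/// request [`Db`], and a background task batches them dynamically.
///
/// Construction sketch (illustrative; assumes `client` is an already-connected
/// `ShardedClient` and the size/wait limits come from the server config):
///
/// ```ignore
/// let batcher = Batcher::new(client, 32, 20);
/// ```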
#[derive(Clone)]
pub struct Batcher {
    /// Request database
    db: Db,
    /// Shared state
    shared: Arc<Shared>,
}

/// Batcher shared state
struct Shared {
    /// Batching background Tokio task notifier
    batching_task: Notify,
}

impl Batcher {
    pub(crate) fn new(
        client: ShardedClient,
        max_batch_size: usize,
        max_waiting_tokens: usize,
    ) -> Self {
        // Batcher shared state
        let db = Db::new();
        let shared = Arc::new(Shared {
            batching_task: Notify::new(),
        });

        // Spawn batching background task that contains all the inference logic
        tokio::spawn(batching_task(
            client,
            max_batch_size,
            max_waiting_tokens,
            db.clone(),
            shared.clone(),
        ));

        Self { db, shared }
    }

    /// Add a new request to the database and return a future that will generate the text
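    ///
    /// Call sketch (illustrative; `batcher`, `input_length` and `request` are
    /// assumed to come from the HTTP handler):
    ///
    /// ```ignore
    /// let response: InferResponse = batcher.infer(input_length, request).await?;
    /// ```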
    pub(crate) async fn infer(
        &self,
        input_length: usize,
        request: GenerateRequest,
    ) -> Result<InferResponse, InferError> {
        // One shot channel to communicate with the background batching task
        let (response_tx, response_rx) = oneshot::channel();

        // Try to append the request to the database
        self.db.append(Entry {
            request,
            response_tx,
            input_length,
            time: Instant::now(),
            batch_time: None,
        });

        // Notify the background task that we have a new entry in the database that needs
        // to be batched
        self.shared.batching_task.notify_one();

        // Await on the response from the background task
        // We can safely unwrap as the background task will never drop the sender
        response_rx
            .await
            .unwrap()
            .map_err(|err| InferError::GenerationError(err.to_string()))
    }
}

/// Batching logic
/// Will be launched in a background Tokio task
///
/// Batches requests and sends them to the inference server
#[instrument(skip(client, db, shared))]
async fn batching_task(
    mut client: ShardedClient,
    max_batch_size: usize,
    max_waiting_tokens: usize,
    db: Db,
    shared: Arc<Shared>,
) {
    // Minimum batch size after which we try to add more requests
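    // (Illustrative: with max_batch_size = 32, limit_min_batch_size is 16, so
    // only running batches of size <= 16 are candidates for concatenation.)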
    let limit_min_batch_size = (max_batch_size / 2) as u32;

    // Infinite loop
    loop {
        // Wait for a notification from the Batcher struct
        shared.batching_task.notified().await;

        // Get the next batch from the DB
        // This batch might be smaller than the maximum batch size if there are not enough requests
        // waiting in the DB
        while let Some((mut entries, batch)) = db.next_batch(None, max_batch_size) {
            let mut cached_batch = wrap_future(client.generate(batch), &mut entries).await;
            let mut waiting_tokens = 1;

            // We loop until we do not receive any cached batch from the inference server (== until
            // all requests have met their stopping criteria)
            while let Some(batch) = cached_batch {
                // Get current batch info
                let batch_size = batch.size;
                let mut batches = vec![batch];

                // If the current batch is too small, we try to add more requests to it
                if batch_size <= limit_min_batch_size {
                    let min_size = match waiting_tokens {
                        // If we didn't onboard any new requests since >= max_waiting_tokens, we try
                        // to add a new batch even though its size might be small
                        _ if waiting_tokens >= max_waiting_tokens => None,
                        // Minimum size criteria
                        _ => Some(limit_min_batch_size as usize),
                    };
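                    // Worked example (illustrative): with max_batch_size = 32 and
                    // max_waiting_tokens = 20, a running batch of size 10 asks the
                    // Db for at least 16 new requests (and at most 22) during the
                    // first 19 decode steps, and for any number available from
                    // step 20 onwards.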

                    // Try to get a new batch
                    if let Some((mut new_entries, new_batch)) =
                        db.next_batch(min_size, max_batch_size - batch_size as usize)
                    {
                        // Generate one token for this new batch to have the attention past in cache
                        let new_cached_batch =
                            wrap_future(client.generate(new_batch), &mut new_entries).await;
                        // Reset waiting counter
                        waiting_tokens = 1;
                        // Extend current batch with the new batch
                        if let Some(new_cached_batch) = new_cached_batch {
                            entries.extend(new_entries);
                            batches.push(new_cached_batch);
                        }
                    }
                }

                cached_batch = wrap_future(client.generate_with_cache(batches), &mut entries).await;
                waiting_tokens += 1;
            }
        }
    }
}

/// Wrap a future inside a match statement to handle errors and send the response to the Batcher
async fn wrap_future(
    future: impl Future<Output = Result<(Vec<GeneratedText>, Option<Batch>), ClientError>>,
    entries: &mut IntMap<u64, Entry>,
) -> Option<Batch> {
    match future.await {
        Ok((generated_texts, next_batch)) => {
            send_generated(generated_texts, entries);
            next_batch
        }
        // If we have an error, we discard the whole batch
        Err(err) => {
            send_error(err, entries);
            None
        }
    }
}

/// Send errors to the Batcher for all `entries`
fn send_error(error: ClientError, entries: &mut IntMap<u64, Entry>) {
    entries.drain().for_each(|(_, entry)| {
        // unwrap_or is valid here as we don't care if the receiver is gone.
        entry.response_tx.send(Err(error.clone())).unwrap_or(());
    });
}

/// Send a generated text to the Batcher for each entry in `finished`
fn send_generated(finished: Vec<GeneratedText>, entries: &mut IntMap<u64, Entry>) {
    finished.into_iter().for_each(|output| {
        // We can `expect` here as the request id should always be in the entries
        let entry = entries
            .remove(&output.request.unwrap().id)
            .expect("ID not found in entries. This is a bug.");

        let response = InferResponse {
            output_text: output.output_text,
            generated_tokens: output.generated_tokens,
            token_ids: output.token_ids,
            tokens: output.tokens,
            logprobs: output.logprobs,
            finish_reason: output.finish_reason,
            queued: entry.time,
            start: entry.batch_time.unwrap(), // unwrap is always valid: batch_time is set when the entry is batched
            end: Instant::now(),
        };
        // unwrap_or is valid here as we don't care if the receiver is gone.
        entry.response_tx.send(Ok(response)).unwrap_or(());
    });
}

#[derive(Debug)]
pub(crate) struct InferResponse {
    pub(crate) output_text: String,
    pub(crate) generated_tokens: u32,
    pub(crate) token_ids: Vec<u32>,
    pub(crate) tokens: Vec<String>,
    pub(crate) logprobs: Vec<f32>,
    pub(crate) finish_reason: String,
    pub(crate) queued: Instant,
    pub(crate) start: Instant,
    pub(crate) end: Instant,
}
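
// Timing sketch (illustrative): the three instants above decompose a request's
// end-to-end latency, e.g.
//
//     let time_in_queue = response.start - response.queued;
//     let decode_time = response.end - response.start;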

#[derive(Debug, Error)]
pub enum InferError {
    #[error("Request failed during generation: {0}")]
    GenerationError(String),
}

/// Convert to Axum supported format
impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
    fn from(err: InferError) -> Self {
        match err {
            InferError::GenerationError(_) => (
                StatusCode::FAILED_DEPENDENCY,
                Json(ErrorResponse {
                    error: err.to_string(),
                }),
            ),
        }
    }
}
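
// Usage sketch (illustrative): thanks to the conversion above, a handler can map
// an inference failure straight into an HTTP error tuple, e.g.
//
//     let (status, body): (StatusCode, Json<ErrorResponse>) = infer_error.into();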