batcher.rs

use crate::server::GenerateRequest;
use crate::{Db, Entry};
use axum::http::StatusCode;
use bloom_inference_client::{Batch, ClientError, GeneratedText, ShardedClient};
use std::future::Future;
use std::sync::Arc;
use thiserror::Error;
use tokio::sync::{oneshot, Notify};

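/// Maximum number of entries allowed in the queue before `infer` rejects new
/// requests with `InferError::Overloaded`.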
const MAX_LENGTH: usize = 128;

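/// Errors returned to callers of [`Batcher::infer`], convertible into an HTTP
/// status code and message for the server layer.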
#[derive(Debug, Error)]
pub enum InferError {
    #[error("Request failed during generation: {0}")]
    GenerationError(String),
    #[error("Model is overloaded")]
    Overloaded,
}

impl From<InferError> for (StatusCode, String) {
    fn from(err: InferError) -> Self {
        match err {
            InferError::GenerationError(_) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
            InferError::Overloaded => (StatusCode::TOO_MANY_REQUESTS, err.to_string()),
        }
    }
}

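/// Cloneable handle that queues incoming requests in the shared [`Db`] and
/// notifies the background batching task that drives the sharded client.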
#[derive(Clone)]
pub(crate) struct Batcher {
    db: Db,
    shared: Arc<Shared>,
}

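/// State shared between [`Batcher`] clones and the background batching task.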
struct Shared {
    batching_task: Notify,
}

impl Batcher {
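    /// Creates a new `Batcher` and spawns the background task that batches
    /// queued requests and forwards them to the sharded client.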
    pub(crate) fn new(client: ShardedClient) -> Self {
        let db = Db::new();
        let shared = Arc::new(Shared {
            batching_task: Notify::new(),
        });

        tokio::spawn(batching_task(client, db.clone(), shared.clone()));

        Self { db, shared }
    }

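    /// Queues a request and waits for the batching task to send back either
    /// the generated text or an error. Fails fast with `Overloaded` when the
    /// queue is already full.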
    pub(crate) async fn infer(
        &self,
        input_length: usize,
        request: GenerateRequest,
    ) -> Result<String, InferError> {
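        // Reject the request immediately if too many entries are already queued.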
        if self.db.len() > MAX_LENGTH {
            return Err(InferError::Overloaded);
        }
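        // One-shot channel over which the batching task sends back the result.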
        let (request_tx, request_rx) = oneshot::channel();
        self.db.append(Entry {
            request,
            response_tx: request_tx,
            input_length,
        });
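        // Wake the batching task so it can pick up the new entry.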
        self.shared.batching_task.notify_waiters();
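        // The batching task answers every entry it removes from the db, so the
        // sender is not dropped before a response is sent.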
        match request_rx.await.unwrap() {
            Ok(output) => Ok(output),
            Err(err) => Err(InferError::GenerationError(err.to_string())),
        }
    }
}

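/// Background loop that waits for queue notifications, assembles batches of up
/// to 32 requests, and keeps decoding cached batches until they are finished,
/// merging newly queued requests into small running batches along the way.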
async fn batching_task(client: ShardedClient, db: Db, shared: Arc<Shared>) {
    loop {
        shared.batching_task.notified().await;

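        // Try to assemble a first batch of at most 32 queued requests.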
        if let Some(batch) = db.next_batch(32) {
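            // Pick the client call based on batch size: batches of more than 16
            // requests use `generate_until_finished`, smaller ones `generate`.
            // Either call may hand back a cached batch of unfinished requests.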
            let request_ids = batch.requests.iter().map(|req| req.id).collect();
            let mut cached_batch = match batch.size {
                size if size > 16 => {
                    wrap_future(client.generate_until_finished(batch), request_ids, &db).await
                }
                _ => wrap_future(client.generate(batch), request_ids, &db).await,
            };

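            // Keep decoding as long as the shards return unfinished requests.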
            while let Some(batch) = cached_batch {
                let batch_size = batch.size;
                let mut request_ids: Vec<u64> = batch.requests.iter().map(|req| req.id).collect();
                let mut batches = vec![batch];

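                // If the running batch is small, try to pull a new batch from
                // the queue, run a first generation pass over it, and merge the
                // resulting cached batch into the running set.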
                if batch_size <= 16 {
                    if let Some(new_batch) = db.next_batch_minimum_size(16, 48) {
                        let new_batch_request_ids =
                            new_batch.requests.iter().map(|req| req.id).collect();
                        let new_cached_batch =
                            wrap_future(client.generate(new_batch), new_batch_request_ids, &db)
                                .await;
                        if let Some(new_cached_batch) = new_cached_batch {
                            request_ids.extend(new_cached_batch.requests.iter().map(|req| req.id));
                            batches.push(new_cached_batch);
                        }
                    }
                }

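                // Continue decoding all running batches with their cached state,
                // again choosing the call based on the original batch size.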
                cached_batch = match batch_size {
                    size if size > 16 => {
                        wrap_future(
                            client.generate_until_finished_with_cache(batches),
                            request_ids,
                            &db,
                        )
                        .await
                    }
                    _ => wrap_future(client.generate_with_cache(batches), request_ids, &db).await,
                };
            }
        }
    }
}

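/// Awaits a client call, forwards generated texts to their waiting requests,
/// and reports any client error to every request in the batch.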
async fn wrap_future(
    future: impl Future<Output = Result<(Vec<GeneratedText>, Option<Batch>), ClientError>>,
    request_ids: Vec<u64>,
    db: &Db,
) -> Option<Batch> {
    match future.await {
        Ok((generated_texts, next_batch)) => {
            send_generated(generated_texts, db);
            next_batch
        }
        Err(err) => {
            send_error(err, request_ids, db);
            None
        }
    }
}

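/// Sends the client error to every waiting request in `request_ids` and
/// removes their entries from the db.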
fn send_error(error: ClientError, request_ids: Vec<u64>, db: &Db) {
    request_ids.into_iter().for_each(|id| {
        let entry = db.remove(&id).expect("ID not found in db. This is a bug.");
        // unwrap_or is valid here as we don't care if the receiver is gone.
        entry.response_tx.send(Err(error.clone())).unwrap_or(());
    });
}

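/// Sends each finished generation to the request that produced it and removes
/// its entry from the db.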
fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
    finished.into_iter().for_each(|output| {
        let entry = db
            .remove(&output.request.unwrap().id)
            .expect("ID not found in db. This is a bug.");
        // unwrap_or is valid here as we don't care if the receiver is gone.
        entry.response_tx.send(Ok(output.output)).unwrap_or(());
    });
}