batcher.rs 4.04 KB
Newer Older
Olivier Dehaene's avatar
Olivier Dehaene committed
1
use crate::server::GenerateRequest;
Olivier Dehaene's avatar
Olivier Dehaene committed
2
use crate::Db;
Olivier Dehaene's avatar
Olivier Dehaene committed
3
4
use bloom_inference_client::{Batch, ClientError, GeneratedText, ShardedClient};
use std::future::Future;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
5
use std::sync::Arc;
Olivier Dehaene's avatar
Olivier Dehaene committed
6
use tokio::sync::{oneshot, Notify};
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
7
8
9
10
11
12

/// Maximum number of requests allowed to sit in the queue (`Db`) before new
/// `infer` calls are rejected — a simple backpressure guard.
const MAX_LENGTH: usize = 128;

/// Error returned to callers when a request is rejected (queue full) or when
/// inference failed.
///
/// FIX: previously this type implemented nothing at all — it could not be
/// logged (`{:?}`/`{}`), compared, or boxed as `dyn Error`. Deriving `Debug`
/// and implementing `Display` + `std::error::Error` is purely additive and
/// backward-compatible.
#[derive(Debug)]
pub struct InferError {}

impl std::fmt::Display for InferError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("inference failed or request queue is full")
    }
}

impl std::error::Error for InferError {}

/// Cheaply cloneable handle shared between request handlers: queues incoming
/// generate requests in the shared `Db` and wakes the background batching
/// task (spawned in `Batcher::new`).
#[derive(Clone)]
pub(crate) struct Batcher {
    // Request queue/database shared with the background batching task.
    db: Db,
    // State shared with the background task; `Arc` so every cloned handle
    // notifies the same task.
    shared: Arc<Shared>,
}

/// State shared between all `Batcher` handles and the batching task.
struct Shared {
    // Notified by `infer` whenever a new request is appended to the `Db`,
    // so the batching task wakes up and forms a batch.
    batching_task: Notify,
}

Olivier Dehaene's avatar
Olivier Dehaene committed
22
impl Batcher {
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
23
24
25
26
27
28
29
30
31
32
33
    pub(crate) fn new(client: ShardedClient) -> Self {
        let db = Db::new();
        let shared = Arc::new(Shared {
            batching_task: Notify::new(),
        });

        tokio::spawn(batching_task(client, db.clone(), shared.clone()));

        Self { db, shared }
    }

Olivier Dehaene's avatar
Olivier Dehaene committed
34
35
36
37
38
    pub(crate) async fn infer(
        &self,
        input_length: usize,
        request: GenerateRequest,
    ) -> Result<String, InferError> {
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
39
40
41
42
        if self.db.len() > MAX_LENGTH {
            return Err(InferError {});
        }
        let (request_tx, request_rx) = oneshot::channel();
Olivier Dehaene's avatar
Olivier Dehaene committed
43
        self.db.append(input_length, request, request_tx);
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
44
45
46
        self.shared.batching_task.notify_waiters();
        match request_rx.await.unwrap() {
            Ok(output) => Ok(output),
Olivier Dehaene's avatar
Olivier Dehaene committed
47
            Err(_) => Err(InferError {}),
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
48
49
50
51
52
53
54
55
56
        }
    }
}

/// Background task that drains the request `Db`, forms batches, and drives
/// the sharded model client until every request in a batch has finished.
///
/// Runs forever: it sleeps until `Batcher::infer` notifies it, then processes
/// batches until the queue yields nothing more.
async fn batching_task(client: ShardedClient, db: Db, shared: Arc<Shared>) {
    loop {
        // Wait for a wake-up from `Batcher::infer`.
        shared.batching_task.notified().await;

        // Pull up to 32 queued requests as the first batch; skip the
        // iteration entirely if the queue is empty.
        if let Some(batch) = db.next_batch(32) {
            let request_ids = batch.requests.iter().map(|req| req.id).collect();
            // Small batches (<= 16) use `generate` (presumably one decoding
            // step), larger ones `generate_until_finished` — the exact
            // client semantics are not visible here; TODO confirm against
            // the client crate.
            let mut cached_batch = match batch.size {
                size if size > 16 => {
                    wrap_future(client.generate_until_finished(batch), request_ids, &db).await
                }
                _ => wrap_future(client.generate(batch), request_ids, &db).await,
            };

            // Keep decoding while the client returns a cached batch of
            // still-unfinished requests.
            while let Some(batch) = cached_batch {
                let batch_size = batch.size;
                let mut request_ids: Vec<u64> = batch.requests.iter().map(|req| req.id).collect();
                let mut batches = vec![batch];

                // If the running batch has shrunk to <= 16 requests, try to
                // top it up with newly queued requests (at least 16, up to
                // 48) so the hardware stays utilized.
                if batch_size <= 16 {
                    if let Some(new_batch) = db.next_batch_minimum_size(16, 48) {
                        let new_batch_request_ids =
                            new_batch.requests.iter().map(|req| req.id).collect();
                        // Run a first generation pass on the new requests so
                        // they can be merged with the cached batch below.
                        let new_cached_batch =
                            wrap_future(client.generate(new_batch), new_batch_request_ids, &db)
                                .await;
                        if let Some(new_cached_batch) = new_cached_batch {
                            request_ids.extend(new_cached_batch.requests.iter().map(|req| req.id));
                            batches.push(new_cached_batch);
                        }
                    }
                }

                // Continue decoding the (possibly concatenated) batches.
                // NOTE(review): the threshold uses `batch_size` from before
                // the top-up, not the combined size — verify this is
                // intentional.
                cached_batch = match batch_size {
                    size if size > 16 => {
                        wrap_future(client.generate_until_finished_with_cache(batches), request_ids, &db).await
                    }
                    _ => wrap_future(client.generate_with_cache(batches), request_ids, &db).await,
                };
            }
        }
    }
}

Olivier Dehaene's avatar
Olivier Dehaene committed
95
96
97
async fn wrap_future(
    future: impl Future<Output = Result<(Vec<GeneratedText>, Option<Batch>), ClientError>>,
    request_ids: Vec<u64>,
Olivier Dehaene's avatar
Olivier Dehaene committed
98
    db: &Db,
Olivier Dehaene's avatar
Olivier Dehaene committed
99
100
101
102
103
) -> Option<Batch> {
    match future.await {
        Ok((generated_texts, next_batch)) => {
            send_generated(generated_texts, db);
            next_batch
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
104
105
106
        }
        Err(err) => {
            println!("{:?}", err);
Olivier Dehaene's avatar
Olivier Dehaene committed
107
            send_error(err, request_ids, db);
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
108
109
110
111
112
113
114
115
116
117
118
119
            None
        }
    }
}

/// Remove every request in `request_ids` from the database and send `error`
/// to its waiting caller.
fn send_error(error: ClientError, request_ids: Vec<u64>, db: &Db) {
    for id in request_ids {
        // FIX: the previous `.unwrap()` on `db.remove` panicked the batching
        // task (hanging every queued request) if an entry had already been
        // removed; skip missing entries instead.
        if let Some((_, response_tx)) = db.remove(&id) {
            // The receiver may be gone (caller dropped); ignore send failure.
            response_tx.send(Err(error.clone())).unwrap_or(());
        }
    }
}

Olivier Dehaene's avatar
Olivier Dehaene committed
120
/// Deliver each finished generation to its waiting caller, removing the
/// corresponding entry from the database.
fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
    for output in finished {
        // FIX: the previous code `.unwrap()`ed both the optional `request`
        // field and the db entry, panicking the long-lived batching task on
        // any missing value. Skip gracefully instead: a missing field or
        // entry means there is no caller left to notify.
        if let Some(request) = output.request {
            if let Some((_, response_tx)) = db.remove(&request.id) {
                // The receiver may be gone (caller dropped); ignore failure.
                response_tx.send(Ok(output.output)).unwrap_or(());
            }
        }
    }
}