db.rs 3.66 KB
Newer Older
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
//! This code is heavily inspired by Tokio's mini-redis.
use crate::GenerateRequest;
use bloom_inference_client::{Batch, ClientError, LogitsWarperParameters, Request};
use parking_lot::RwLock;
use std::collections::BTreeMap;
use std::sync::Arc;
use tokio::sync::oneshot::Sender;

/// Handle to the shared request database.
///
/// The handle only holds an `Arc`, so cloning it is a reference-count bump and
/// every clone refers to the same underlying [`Shared`] state.
#[derive(Debug, Clone)]
pub(crate) struct Db {
    /// Reference-counted pointer to the state shared by all clones of this handle.
    pub shared: Arc<Shared>,
}

/// State shared between all [`Db`] handles.
#[derive(Debug)]
pub struct Shared {
    /// The mutable database state, guarded by a `parking_lot` read-write lock.
    state: RwLock<State>,
}

/// Mutable contents of the database; only accessed through the lock in [`Shared`].
#[derive(Debug)]
struct State {
    /// Stored requests keyed by request id, each paired with the oneshot sender
    /// that carries a `Result<String, ClientError>` for that request
    /// (presumably the generated text or the client error -- confirm with callers).
    /// A `BTreeMap` keeps the entries ordered by id.
    entries: BTreeMap<u64, (Request, Sender<Result<String, ClientError>>)>,

    /// Id that will be assigned to the next appended request. Incremented on
    /// every append, so each request gets a unique, increasing id.
    next_id: u64,

    /// Id that will be assigned to the next batch that is produced.
    next_batch_id: u64,

    /// Smallest request id that is still eligible for batching. After a batch is
    /// produced this is set to the id of its last request plus one.
    next_batch_start_id: u64,
}

impl Db {
    /// Creates an empty database with every id counter starting at zero.
    pub(crate) fn new() -> Self {
        let shared = Arc::new(Shared {
            state: RwLock::new(State {
                entries: BTreeMap::new(),
                next_id: 0,
                next_batch_id: 0,
                next_batch_start_id: 0,
            }),
        });

        Self { shared }
    }

    /// Stores `request` under a freshly allocated id, together with `sender`.
    ///
    /// The incoming `GenerateRequest` is converted into a client `Request`
    /// carrying the allocated id, the inputs, the logits-warper parameters and
    /// `max_new_tokens`.
    pub(crate) fn append(
        &self,
        request: GenerateRequest,
        sender: Sender<Result<String, ClientError>>,
    ) {
        // Everything that does not depend on the id is built before locking.
        let parameters = Some(LogitsWarperParameters {
            temperature: request.parameters.temperature,
            top_k: request.parameters.top_k,
            top_p: request.parameters.top_p,
            do_sample: request.parameters.do_sample,
        });

        // Id allocation and insertion share one critical section: if they were
        // split, a later id could be inserted and batched first, moving the
        // batch cursor past this entry.
        let mut state = self.shared.state.write();

        let id = state.next_id;
        state.next_id += 1;

        let request = Request {
            id,
            inputs: request.inputs,
            parameters,
            max_new_tokens: request.parameters.max_new_tokens,
        };
        state.entries.insert(id, (request, sender));
    }

    /// Removes the entry stored under `id` and returns it, or `None` if no
    /// entry with that id exists.
    pub(crate) fn remove(
        &self,
        id: &u64,
    ) -> Option<(Request, Sender<Result<String, ClientError>>)> {
        self.shared.state.write().entries.remove(id)
    }

    /// Returns the number of entries currently stored, batched or not.
    pub(crate) fn len(&self) -> usize {
        self.shared.state.read().entries.len()
    }

    /// Selects the next batch and advances the batch cursor in one step.
    ///
    /// Up to `max_size` requests with an id at or above `next_batch_start_id`
    /// are cloned, in ascending id order. Returns `None`, leaving the state
    /// untouched, when no such request exists or fewer than `min_size` were
    /// found. Entries stay in the map; only the cursor and the batch id move.
    fn claim_batch(&self, min_size: usize, max_size: usize) -> Option<Batch> {
        // A single write lock covers both the selection and the cursor update.
        // Selecting under a read lock and upgrading afterwards would let two
        // concurrent callers pick the same requests and emit them twice.
        let mut state = self.shared.state.write();

        let requests: Vec<Request> = state
            .entries
            .range(state.next_batch_start_id..)
            .take(max_size)
            .map(|(_, (request, _))| request.clone())
            .collect();

        // `?` covers the "nothing to batch" case (including `max_size == 0`).
        let last_id = requests.last()?.id;
        if requests.len() < min_size {
            return None;
        }

        let id = state.next_batch_id;
        state.next_batch_id += 1;
        state.next_batch_start_id = last_id + 1;

        Some(Batch { id, requests })
    }

    /// Returns the next batch of at most `max_size` not-yet-batched requests,
    /// or `None` if there are none.
    pub(crate) fn next_batch(&self, max_size: usize) -> Option<Batch> {
        self.claim_batch(0, max_size)
    }

    /// Like [`Db::next_batch`], but only produces a batch when at least
    /// `min_size` requests are available; otherwise returns `None` and leaves
    /// the pending requests available for a later call.
    pub(crate) fn next_batch_minimum_size(
        &self,
        min_size: usize,
        max_size: usize,
    ) -> Option<Batch> {
        self.claim_batch(min_size, max_size)
    }
}