sharded_client.rs 8.75 KB
Newer Older
Nicolas Patry's avatar
Nicolas Patry committed
1
use crate::client::{ClientError, Result};
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
2
/// Multi shard Client
Nicolas Patry's avatar
Nicolas Patry committed
3
use crate::client::{Health, ShardInfo};
OlivierDehaene's avatar
OlivierDehaene committed
4

Nicolas Patry's avatar
Nicolas Patry committed
5
6
7
8
9
10
use crate::client::grpc_client::{DecodeTimings, PrefillTimings};
use crate::client::{
    Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
    NextTokenChooserParameters, Request, StoppingCriteriaParameters,
};
use crate::client::{Chunk, InfoResponse, Input};
OlivierDehaene's avatar
OlivierDehaene committed
11
use async_trait::async_trait;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
12
13
use futures::future::join_all;
use tonic::transport::Uri;
14
use tracing::instrument;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
15

16
/// Text Generation Inference gRPC multi client
///
/// Holds one gRPC [`Client`] per model shard; the methods on the `impl`
/// below fan each call out to every shard.
#[derive(Debug, Clone)]
pub struct ShardedClient {
    // One connected client per shard, in discovery order.
    clients: Vec<Client>,
}

impl ShardedClient {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
23
    fn new(clients: Vec<Client>) -> Self {
24
        Self { clients }
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
25
26
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
27
28
    /// Create a new ShardedClient from a master client. The master client will communicate with
    /// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method.
Olivier Dehaene's avatar
Olivier Dehaene committed
29
    async fn from_master_client(mut master_client: Client) -> Result<Self> {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
30
        // Get all uris/unix sockets from the master client
31
        let uris = master_client.service_discovery().await?;
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
32
        let futures = uris.into_iter().map(Client::connect_uds);
Olivier Dehaene's avatar
Olivier Dehaene committed
33
34
        let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect();
        Ok(Self::new(clients?))
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
35
36
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
37
    /// Returns a client connected to the given uri
Nicolas Patry's avatar
Nicolas Patry committed
38
    #[allow(dead_code)]
Olivier Dehaene's avatar
Olivier Dehaene committed
39
40
    pub async fn connect(uri: Uri) -> Result<Self> {
        let master_client = Client::connect(uri).await?;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
41
42
43
        Self::from_master_client(master_client).await
    }

Olivier Dehaene's avatar
Olivier Dehaene committed
44
45
46
    /// Returns a client connected to the given unix socket
    pub async fn connect_uds(path: String) -> Result<Self> {
        let master_client = Client::connect_uds(path).await?;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
47
48
49
        Self::from_master_client(master_client).await
    }

50
51
52
53
54
55
56
57
    /// Get the model info
    #[instrument(skip(self))]
    pub async fn info(&mut self) -> Result<ShardInfo> {
        let futures: Vec<_> = self
            .clients
            .iter_mut()
            .map(|client| client.info())
            .collect();
OlivierDehaene's avatar
OlivierDehaene committed
58
        join_all(futures).await.pop().unwrap().map(ShardInfo::from)
59
60
    }

61
62
63
64
65
66
67
68
69
70
71
    /// GRPC health check
    #[instrument(skip(self))]
    pub async fn health(&mut self) -> Result<HealthResponse> {
        let futures: Vec<_> = self
            .clients
            .iter_mut()
            .map(|client| client.health())
            .collect();
        join_all(futures).await.pop().unwrap()
    }

72
    /// Clear the past generations cache
73
    #[instrument(skip(self))]
74
    pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
75
76
77
        let futures: Vec<_> = self
            .clients
            .iter_mut()
78
            .map(|client| client.clear_cache(batch_id))
79
80
81
82
            .collect();
        join_all(futures).await.into_iter().collect()
    }

83
84
85
86
87
    /// Filter a cached batch
    #[instrument(skip(self))]
    pub async fn filter_batch(
        &mut self,
        batch_id: u64,
88
89
        request_ids: Vec<u64>,
    ) -> Result<Option<CachedBatch>> {
90
91
92
        let futures: Vec<_> = self
            .clients
            .iter_mut()
93
            .map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone())))
94
95
96
97
98
            .collect();
        // all shards return the same message
        join_all(futures).await.pop().unwrap()
    }

99
100
101
102
103
104
105
106
    /// Warmup on a max size batch
    ///
    /// Returns the maximum amount of tokens supported by the hardware
    #[instrument(skip(self))]
    pub async fn warmup(
        &mut self,
        max_input_length: u32,
        max_prefill_tokens: u32,
OlivierDehaene's avatar
OlivierDehaene committed
107
        max_total_tokens: u32,
108
        max_batch_size: Option<usize>,
109
    ) -> Result<Option<u32>> {
110
111
112
        let futures: Vec<_> = self
            .clients
            .iter_mut()
OlivierDehaene's avatar
OlivierDehaene committed
113
            .map(|client| {
114
115
116
117
118
119
                Box::pin(client.warmup(
                    max_input_length,
                    max_prefill_tokens,
                    max_total_tokens,
                    max_batch_size,
                ))
OlivierDehaene's avatar
OlivierDehaene committed
120
            })
121
            .collect();
122
123
124
125
126
127
        // Take the minimum value
        let results = join_all(futures)
            .await
            .into_iter()
            .collect::<Result<Vec<Option<u32>>>>()?;
        Ok(results.into_iter().flatten().min())
128
129
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
130
131
    /// Generate one token for each request in the given batch
    ///
132
    /// Returns Generation for each request in batch
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
133
    /// and the next cached batch
134
    #[instrument(skip_all, fields(id = & batch.id, size = & batch.size))]
135
136
137
    pub async fn prefill(
        &mut self,
        batch: Batch,
138
    ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
139
140
141
        let futures: Vec<_> = self
            .clients
            .iter_mut()
142
            .map(|client| Box::pin(client.prefill(batch.clone())))
143
            .collect();
144
        #[allow(clippy::type_complexity)]
145
        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
146
            join_all(futures).await.into_iter().collect();
147
148
149
150
151
152
153
154
155
156
157
158
159
160
        let mut results = results?;

        let (mut generations, next_batch, mut timings) =
            results.pop().ok_or(ClientError::EmptyResults)?;

        // Merge generations from different model shards
        for (mut shard_generations, _, shard_timings) in results.into_iter() {
            generations.append(&mut shard_generations);
            // Return the timings of the slowest shard
            if shard_timings.total > timings.total {
                timings = shard_timings;
            }
        }
        Ok((generations, next_batch, timings))
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
161
162
    }

163
    /// Generate one token for each request in the given cached batches
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
164
    ///
165
    /// Returns Generation for each request in batches
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
166
    /// and the next cached batch
167
    #[instrument(skip_all, fields(size = batches.iter().map(| batch | {batch.size}).sum::< u32 > ()))]
168
    pub async fn decode(
169
        &mut self,
170
        batches: Vec<CachedBatch>,
171
    ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
172
173
174
        let futures: Vec<_> = self
            .clients
            .iter_mut()
175
            .map(|client| Box::pin(client.decode(batches.clone())))
176
            .collect();
177
        #[allow(clippy::type_complexity)]
178
        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> =
179
            join_all(futures).await.into_iter().collect();
180
        let mut results = results?;
181

182
183
        let (mut generations, next_batch, mut timings) =
            results.pop().ok_or(ClientError::EmptyResults)?;
184

185
186
187
188
189
190
191
192
193
        // Merge generations from different model shards
        for (mut shard_generations, _, shard_timings) in results.into_iter() {
            generations.append(&mut shard_generations);
            // Return the timings of the slowest shard
            if shard_timings.total > timings.total {
                timings = shard_timings;
            }
        }
        Ok((generations, next_batch, timings))
Olivier Dehaene's avatar
Olivier Dehaene committed
194
    }
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
195
}
OlivierDehaene's avatar
OlivierDehaene committed
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244

impl From<InfoResponse> for ShardInfo {
    fn from(value: InfoResponse) -> Self {
        Self {
            requires_padding: value.requires_padding,
            dtype: value.dtype,
            device_type: value.device_type,
            window_size: value.window_size,
            speculate: value.speculate,
        }
    }
}

#[async_trait]
impl Health for ShardedClient {
    async fn device_health(&self) -> Result<()> {
        // A successful gRPC health check against every shard is taken as
        // proof the devices are reachable.
        self.clone().health().await.map(|_| ())
    }

    async fn model_health(&self) -> Result<()> {
        // Dummy batch of 1 token and 1 generated token: greedy sampling,
        // no penalties, no grammar.
        let parameters = NextTokenChooserParameters {
            temperature: 1.0,
            top_k: 0,
            top_p: 1.0,
            typical_p: 1.0,
            do_sample: false,
            seed: 0,
            repetition_penalty: 1.0,
            frequency_penalty: 0.0,
            watermark: false,
            grammar: String::new(),
            grammar_type: GrammarType::None as i32,
        };
        // Stop after a single generated token.
        let stopping_parameters = StoppingCriteriaParameters {
            max_new_tokens: 1,
            stop_sequences: vec![],
            ignore_eos_token: false,
        };
        let liveness_request = Request {
            id: u64::MAX,
            inputs: "liveness".to_string(),
            input_chunks: Some(Input {
                chunks: vec![Chunk::Text("liveness".into()).into()],
            }),
            truncate: 10,
            prefill_logprobs: false,
            parameters: Some(parameters),
            stopping_parameters: Some(stopping_parameters),
            top_n_tokens: 0,
            // Block 0 is reserved for health checks
            blocks: vec![0],
            slots: (0..16).collect(),
            prefix_len: 0,
            adapter_id: None,
        };
        let batch = Batch {
            id: u64::MAX,
            requests: vec![liveness_request],
            size: 1,
            max_tokens: 2,
            max_blocks: 1,
        };
        // If a full prefill pass succeeds, the model can generate.
        self.clone().prefill(batch).await.map(|_| ())
    }
}