sharded_client.rs 8.77 KB
Newer Older
1
use crate::client::Health;
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
2
/// Multi shard Client
3
use crate::client::{ClientError, Result};
OlivierDehaene's avatar
OlivierDehaene committed
4

Nicolas Patry's avatar
Nicolas Patry committed
5
6
7
8
9
10
use crate::client::grpc_client::{DecodeTimings, PrefillTimings};
use crate::client::{
    Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
    NextTokenChooserParameters, Request, StoppingCriteriaParameters,
};
use crate::client::{Chunk, InfoResponse, Input};
OlivierDehaene's avatar
OlivierDehaene committed
11
use async_trait::async_trait;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
12
13
use futures::future::join_all;
use tonic::transport::Uri;
14
use tracing::instrument;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
15

16
#[derive(Debug, Clone)]
/// Text Generation Inference gRPC multi client
///
/// Holds one `Client` per model shard and fans requests out to all of them.
pub struct ShardedClient {
    // One gRPC client per shard, in service-discovery order.
    clients: Vec<Client>,
}

impl ShardedClient {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
23
    fn new(clients: Vec<Client>) -> Self {
24
        Self { clients }
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
25
26
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
27
28
    /// Create a new ShardedClient from a master client. The master client will communicate with
    /// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method.
Olivier Dehaene's avatar
Olivier Dehaene committed
29
    async fn from_master_client(mut master_client: Client) -> Result<Self> {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
30
        // Get all uris/unix sockets from the master client
31
        let uris = master_client.service_discovery().await?;
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
32
        let futures = uris.into_iter().map(Client::connect_uds);
Olivier Dehaene's avatar
Olivier Dehaene committed
33
34
        let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect();
        Ok(Self::new(clients?))
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
35
36
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
37
    /// Returns a client connected to the given uri
Nicolas Patry's avatar
Nicolas Patry committed
38
    #[allow(dead_code)]
Olivier Dehaene's avatar
Olivier Dehaene committed
39
40
    pub async fn connect(uri: Uri) -> Result<Self> {
        let master_client = Client::connect(uri).await?;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
41
42
43
        Self::from_master_client(master_client).await
    }

Olivier Dehaene's avatar
Olivier Dehaene committed
44
45
46
    /// Returns a client connected to the given unix socket
    pub async fn connect_uds(path: String) -> Result<Self> {
        let master_client = Client::connect_uds(path).await?;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
47
48
49
        Self::from_master_client(master_client).await
    }

50
51
    /// Get the model info
    #[instrument(skip(self))]
52
    pub async fn info(&mut self) -> Result<InfoResponse> {
53
54
55
56
57
        let futures: Vec<_> = self
            .clients
            .iter_mut()
            .map(|client| client.info())
            .collect();
58
        join_all(futures).await.pop().unwrap()
59
60
    }

61
62
63
64
65
66
67
68
69
70
71
    /// GRPC health check
    #[instrument(skip(self))]
    pub async fn health(&mut self) -> Result<HealthResponse> {
        let futures: Vec<_> = self
            .clients
            .iter_mut()
            .map(|client| client.health())
            .collect();
        join_all(futures).await.pop().unwrap()
    }

72
    /// Clear the past generations cache
73
    #[instrument(skip(self))]
74
    pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
75
76
77
        let futures: Vec<_> = self
            .clients
            .iter_mut()
78
            .map(|client| client.clear_cache(batch_id))
79
80
81
82
            .collect();
        join_all(futures).await.into_iter().collect()
    }

83
84
85
86
87
    /// Filter a cached batch
    #[instrument(skip(self))]
    pub async fn filter_batch(
        &mut self,
        batch_id: u64,
88
89
        request_ids: Vec<u64>,
    ) -> Result<Option<CachedBatch>> {
90
91
92
        let futures: Vec<_> = self
            .clients
            .iter_mut()
93
            .map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone())))
94
95
96
97
98
            .collect();
        // all shards return the same message
        join_all(futures).await.pop().unwrap()
    }

99
100
101
102
103
104
    /// Warmup on a max size batch
    ///
    /// Returns the maximum amount of tokens supported by the hardware
    #[instrument(skip(self))]
    pub async fn warmup(
        &mut self,
105
        max_input_length: Option<u32>,
106
        max_prefill_tokens: u32,
107
        max_total_tokens: Option<u32>,
108
        max_batch_size: Option<usize>,
109
    ) -> Result<(Option<u32>, u32, u32)> {
110
111
112
        let futures: Vec<_> = self
            .clients
            .iter_mut()
OlivierDehaene's avatar
OlivierDehaene committed
113
            .map(|client| {
114
115
116
117
118
119
                Box::pin(client.warmup(
                    max_input_length,
                    max_prefill_tokens,
                    max_total_tokens,
                    max_batch_size,
                ))
OlivierDehaene's avatar
OlivierDehaene committed
120
            })
121
            .collect();
122
123
124
        let results = join_all(futures)
            .await
            .into_iter()
125
126
127
128
129
130
131
132
133
134
            .collect::<Result<Vec<(Option<u32>, u32, u32)>>>()?;

        // Take the minimum value
        // Different shards hold different parts of vocab, might yield
        // different available block size.
        let min = results
            .iter()
            .min()
            .expect("Expect at least 1 warmup result");
        Ok(*min)
135
136
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
137
138
    /// Generate one token for each request in the given batch
    ///
139
    /// Returns Generation for each request in batch
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
140
    /// and the next cached batch
141
    #[instrument(skip_all, fields(id = & batch.id, size = & batch.size))]
142
143
144
    pub async fn prefill(
        &mut self,
        batch: Batch,
145
        cached_batch: Option<CachedBatch>,
146
    ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
147
148
149
        let futures: Vec<_> = self
            .clients
            .iter_mut()
150
            .map(|client| Box::pin(client.prefill(batch.clone(), cached_batch.clone())))
151
            .collect();
152
        #[allow(clippy::type_complexity)]
153
        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
154
            join_all(futures).await.into_iter().collect();
155
156
157
158
159
160
161
162
163
164
165
166
167
168
        let mut results = results?;

        let (mut generations, next_batch, mut timings) =
            results.pop().ok_or(ClientError::EmptyResults)?;

        // Merge generations from different model shards
        for (mut shard_generations, _, shard_timings) in results.into_iter() {
            generations.append(&mut shard_generations);
            // Return the timings of the slowest shard
            if shard_timings.total > timings.total {
                timings = shard_timings;
            }
        }
        Ok((generations, next_batch, timings))
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
169
170
    }

171
    /// Generate one token for each request in the given cached batches
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
172
    ///
173
    /// Returns Generation for each request in batches
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
174
    /// and the next cached batch
175
    #[instrument(skip_all, fields(size = batches.iter().map(| batch | {batch.size}).sum::< u32 > ()))]
176
    pub async fn decode(
177
        &mut self,
178
        batches: Vec<CachedBatch>,
179
    ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
180
181
182
        let futures: Vec<_> = self
            .clients
            .iter_mut()
183
            .map(|client| Box::pin(client.decode(batches.clone())))
184
            .collect();
185
        #[allow(clippy::type_complexity)]
186
        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> =
187
            join_all(futures).await.into_iter().collect();
188
        let mut results = results?;
189

190
191
        let (mut generations, next_batch, mut timings) =
            results.pop().ok_or(ClientError::EmptyResults)?;
192

193
194
195
196
197
198
199
200
201
        // Merge generations from different model shards
        for (mut shard_generations, _, shard_timings) in results.into_iter() {
            generations.append(&mut shard_generations);
            // Return the timings of the slowest shard
            if shard_timings.total > timings.total {
                timings = shard_timings;
            }
        }
        Ok((generations, next_batch, timings))
Olivier Dehaene's avatar
Olivier Dehaene committed
202
    }
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
203
}
OlivierDehaene's avatar
OlivierDehaene committed
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220

#[async_trait]
impl Health for ShardedClient {
    /// Device-level liveness: a plain gRPC health round-trip to every shard.
    async fn device_health(&self) -> Result<()> {
        // `health` requires `&mut self`, so probe on a clone of the client set.
        self.clone().health().await.map(|_| ())
    }

    /// Model-level liveness: run a minimal prefill through the whole stack.
    async fn model_health(&self) -> Result<()> {
        // Deterministic greedy sampling: no randomness, no penalties, no grammar.
        let parameters = NextTokenChooserParameters {
            temperature: 1.0,
            top_k: 0,
            top_p: 1.0,
            typical_p: 1.0,
            do_sample: false,
            seed: 0,
            repetition_penalty: 1.0,
            frequency_penalty: 0.0,
            watermark: false,
            grammar: String::new(),
            grammar_type: GrammarType::None as i32,
        };
        // Stop after exactly one generated token.
        let stopping_parameters = StoppingCriteriaParameters {
            max_new_tokens: 1,
            stop_sequences: vec![],
            ignore_eos_token: false,
        };
        // Dummy batch of 1 token and 1 generated token
        let liveness_request = Request {
            id: u64::MAX,
            inputs: "liveness".to_string(),
            input_chunks: Some(Input {
                chunks: vec![Chunk::Text("liveness".into()).into()],
            }),
            truncate: 10,
            add_special_tokens: true,
            prefill_logprobs: false,
            parameters: Some(parameters),
            stopping_parameters: Some(stopping_parameters),
            top_n_tokens: 0,
            // Block 0 is reserved for health checks
            blocks: vec![0],
            slots: (0..16).collect(),
            cache_len: 0,
            adapter_id: None,
            chunk_len: None,
        };
        let batch = Batch {
            id: u64::MAX,
            requests: vec![liveness_request],
            size: 1,
            max_tokens: 2,
            max_blocks: 1,
        };
        // `prefill` requires `&mut self`; run it on a clone and discard the output.
        self.clone().prefill(batch, None).await.map(|_| ())
    }
}