/// Single shard Client
use crate::client::{pb, Chunk};
use crate::client::{ClientError, Result, WARMUP_IMAGE_BASE64};
use base64::engine::general_purpose::STANDARD;
use base64::Engine;
use grpc_metadata::InjectTelemetryContext;
use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient;
use pb::generate::v3::*;
use std::cmp::min;
use std::time::Duration;
use tonic::transport::{Channel, Uri};
use tracing::instrument;

/// Text Generation Inference gRPC client
#[derive(Debug, Clone)]
pub struct Client {
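    /// Underlying gRPC stub for the v3 TextGenerationService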
    stub: TextGenerationServiceClient<Channel>,
}

impl Client {
    /// Returns a client connected to the given url
    #[allow(dead_code)]
    pub async fn connect(uri: Uri) -> Result<Self> {
        let channel = Channel::builder(uri).connect().await?;

        Ok(Self {
            stub: TextGenerationServiceClient::new(channel),
        })
    }

    /// Returns a client connected to the given unix socket
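    ///
    /// A minimal usage sketch (the socket path is illustrative; real paths
    /// depend on how the shards were launched):
    ///
    /// ```ignore
    /// let mut client = Client::connect_uds("/tmp/text-generation-server-0".to_string()).await?;
    /// ```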
    pub async fn connect_uds(path: String) -> Result<Self> {
        let channel = Channel::from_shared("http://[::]:50051".to_string())
            .unwrap()
            .connect_with_connector(tower::service_fn(move |_: Uri| {
                tokio::net::UnixStream::connect(path.clone())
            }))
            .await?;

        Ok(Self {
            stub: TextGenerationServiceClient::new(channel),
        })
    }

    /// Returns a list of uris or unix sockets of all shards
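    ///
    /// A usage sketch (the returned urls depend on the deployment):
    ///
    /// ```ignore
    /// for url in client.service_discovery().await? {
    ///     println!("shard at {url}");
    /// }
    /// ```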
    #[instrument(skip(self))]
    pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
        let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context();
        let response = self.stub.service_discovery(request).await.map_err(|_| {
            ClientError::Connection("Server does not support v3 interface".to_string())
        })?;
        let urls = response
            .into_inner()
            .urls
            .into_iter()
            // Remove unix socket prefix
            .map(|url| match url.strip_prefix("unix://") {
                None => url,
                Some(stripped_url) => stripped_url.to_string(),
            })
            .collect();
        Ok(urls)
    }

    /// Get model info
    #[instrument(skip(self))]
    pub async fn info(&mut self) -> Result<InfoResponse> {
        let request = tonic::Request::new(InfoRequest {}).inject_context();
        let response = self.stub.info(request).await?.into_inner();
        Ok(response)
    }

    /// Get model health
    #[instrument(skip(self))]
    pub async fn health(&mut self) -> Result<HealthResponse> {
        let request = tonic::Request::new(HealthRequest {}).inject_context();
        let response = self.stub.health(request).await?.into_inner();
        Ok(response)
    }

    /// Clear the past generations cache
    ///
    /// Passing `None` clears the whole cache
    #[instrument(skip(self))]
    pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
        let request = tonic::Request::new(ClearCacheRequest { id: batch_id }).inject_context();
        self.stub.clear_cache(request).await?;
        Ok(())
    }

    /// Filter a cached batch, keeping only the requests with the given ids
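    ///
    /// A sketch with illustrative ids, keeping only requests 1 and 3 of batch 0:
    ///
    /// ```ignore
    /// let remaining = client.filter_batch(0, vec![1, 3]).await?;
    /// ```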
    #[instrument(skip(self))]
    pub async fn filter_batch(
        &mut self,
        batch_id: u64,
        request_ids: Vec<u64>,
    ) -> Result<Option<CachedBatch>> {
        let request = tonic::Request::new(FilterBatchRequest {
            batch_id,
            request_ids,
        })
        .inject_context();
        let filtered_batch = self.stub.filter_batch(request).await?.into_inner();
        Ok(filtered_batch.batch)
    }

    /// Warmup on a max size batch
    ///
    /// Returns the maximum number of tokens supported by the hardware
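    ///
    /// A sketch of a call (the limit values are illustrative):
    ///
    /// ```ignore
    /// let max_supported_total_tokens = client.warmup(1024, 4096, 2048, None).await?;
    /// ```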
    #[instrument(skip_all)]
    pub async fn warmup(
        &mut self,
        max_input_length: u32,
        max_prefill_tokens: u32,
        max_total_tokens: u32,
        max_batch_size: Option<usize>,
    ) -> Result<Option<u32>> {
        let mut n_tokens = 0;
        let mut requests = Vec::new();
        // Create requests
        while n_tokens < max_prefill_tokens {
            let truncate = min(max_input_length, max_prefill_tokens - n_tokens);

            let mut input_chunks = Vec::new();
            input_chunks
                .push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into());
            if n_tokens == 0 {
                input_chunks.push(
                    Chunk::Image(Image {
                        // Safe unwrap, because we control the data.
                        data: STANDARD.decode(WARMUP_IMAGE_BASE64).unwrap(),
                        mimetype: "image/jpeg;base64".to_string(),
                    })
                    .into(),
                );
            }

            // Send stringly-typed inputs for compatibility with backends that haven't
            // been updated to support chunks.
            let mut inputs = String::new();
            inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
            if n_tokens == 0 {
                // 1 request is enough to test vision heads.
                // Sending images on other requests easily interferes with truncation.
                inputs.push_str(&format!("![](data:image/jpeg;base64,{WARMUP_IMAGE_BASE64})"));
            }

            requests.push(Request {
                id: 0,
                inputs,
                add_special_tokens: true,
                input_chunks: Some(Input {
                    chunks: input_chunks,
                }),
                // We truncate the input on the server side to be sure that it has the correct size
                truncate,
                // Blocks and slots will be set on the server side if we use paged attention
                blocks: vec![],
                slots: vec![],
                cache_len: 0,
                chunk_len: None,
                // Set sampling parameters so that these ops are also taken into account when measuring max memory
                parameters: Some(NextTokenChooserParameters {
                    temperature: 0.9,
                    top_k: 10,
                    top_p: 0.9,
                    typical_p: 0.9,
                    do_sample: false,
                    seed: 0,
                    repetition_penalty: 1.2,
                    frequency_penalty: 0.1,
                    watermark: true,
                    grammar: String::new(),
                    grammar_type: GrammarType::None as i32,
                }),
                stopping_parameters: Some(StoppingCriteriaParameters {
                    max_new_tokens: max_total_tokens - truncate,
                    stop_sequences: vec![],
                    ignore_eos_token: true,
                }),
                prefill_logprobs: true,
                top_n_tokens: 20,
                adapter_id: None,
            });
            n_tokens += max_input_length;

            // Check max_batch_size
            if Some(requests.len()) == max_batch_size {
                break;
            }
        }

        let batch = Batch {
            id: 0,
            size: requests.len() as u32,
            requests,
            max_tokens: max_input_length,
            max_blocks: 0,
        };

        let request = tonic::Request::new(WarmupRequest {
            batch: Some(batch),
            max_input_length,
            max_prefill_tokens,
            max_total_tokens,
        })
        .inject_context();
        let response = self.stub.warmup(request).await?.into_inner();
        Ok(response.max_supported_total_tokens)
    }

    /// Generate one token for each request in the given batch
    ///
    /// Returns a Generation for each request in the batch
    /// and the next cached batch
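    ///
    /// Together with [`Client::decode`], this drives the generation loop; a
    /// sketch (batch construction omitted, names illustrative):
    ///
    /// ```ignore
    /// let (generations, mut cached, _timings) = client.prefill(batch, None).await?;
    /// while let Some(batch) = cached {
    ///     let (generations, next_batch, _timings) = client.decode(vec![batch]).await?;
    ///     cached = next_batch;
    /// }
    /// ```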
    #[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
    pub async fn prefill(
        &mut self,
        batch: Batch,
        cached_batch: Option<CachedBatch>,
    ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
        let request = tonic::Request::new(PrefillRequest {
            batch: Some(batch),
            cached_batch,
        })
        .inject_context();
        let response = self.stub.prefill(request).await?.into_inner();
        Ok((
            response.generations,
            response.batch,
            PrefillTimings::new(
                response.concat_ns,
                response.forward_ns,
                response.decode_ns,
                response.total_ns,
            ),
        ))
    }

    /// Generate one token for each request in the given cached batches
    ///
    /// Returns a Generation for each request in the batches
    /// and the next cached batch
    #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
    pub async fn decode(
        &mut self,
        batches: Vec<CachedBatch>,
    ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
        let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
        let response = self.stub.decode(request).await?.into_inner();
        Ok((
            response.generations,
            response.batch,
            DecodeTimings::new(
                response.concat_ns,
                response.forward_ns,
                response.decode_ns,
                response.total_ns,
            ),
        ))
    }
}

/// Timings returned by the prefill RPC
pub struct PrefillTimings {
    /// Time to concatenate the new batch with the cached batch, if any
    pub concat: Option<Duration>,
    /// Forward pass time
    pub forward: Duration,
    /// Token decoding time
    pub decode: Duration,
    /// Total time spent in the call
    pub total: Duration,
}

impl PrefillTimings {
    fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
        Self {
            concat: concat_ns.map(Duration::from_nanos),
            forward: Duration::from_nanos(forward_ns),
            decode: Duration::from_nanos(decode_ns),
            total: Duration::from_nanos(total_ns),
        }
    }
}

/// Timings returned by the decode RPC
pub struct DecodeTimings {
    /// Time to concatenate the cached batches, if more than one
    pub concat: Option<Duration>,
    /// Forward pass time
    pub forward: Duration,
    /// Token decoding time
    pub decode: Duration,
    /// Total time spent in the call
    pub total: Duration,
}

impl DecodeTimings {
    fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
        Self {
            concat: concat_ns.map(Duration::from_nanos),
            forward: Duration::from_nanos(forward_ns),
            decode: Duration::from_nanos(decode_ns),
            total: Duration::from_nanos(total_ns),
        }
    }
}