"rust/vscode:/vscode.git/clone" did not exist on "bbb81c24578cddcde0f6241ffa993ab471b6b214"
validation.rs 36.4 KB
Newer Older
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
1
/// Payload validation logic
use crate::config::Config;
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
use crate::{GenerateParameters, GenerateRequest, GrammarType};
use base64::{engine::general_purpose::STANDARD, Engine};
use image::{io::Reader as ImageReader, ImageFormat};
use jsonschema::{Draft, JSONSchema};
use rand::{thread_rng, Rng};
use serde_json::Value;
use std::io::Cursor;
use text_generation_client::{Chunk, Image, InputChunk};
use thiserror::Error;
use tokenizers::tokenizer::Tokenizer;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tracing::{instrument, Span};
use {once_cell::sync::Lazy, regex::Regex};

/// Validation
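/// Validates request payloads before they are queued. When a fast tokenizer
/// is available, tokenization is offloaded to a pool of blocking workers
/// behind a round-robin channel.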
#[derive(Debug, Clone)]
pub struct Validation {
    /// Validation parameters
    max_best_of: usize,
    max_stop_sequences: usize,
    max_top_n_tokens: u32,
    max_input_length: usize,
    max_total_tokens: usize,
    disable_grammar_support: bool,
    /// Channel to communicate with the background tokenization task
    sender: Option<mpsc::UnboundedSender<TokenizerRequest>>,
}

impl Validation {
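    /// Builds a `Validation` handle. When a tokenizer is provided, this
    /// spawns `workers` blocking tokenization tasks plus a round-robin
    /// dispatcher feeding them; see the unit tests at the bottom of this
    /// file for construction examples.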
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn new(
        workers: usize,
        tokenizer: Option<Tokenizer>,
        config: Option<Config>,
        max_best_of: usize,
        max_stop_sequences: usize,
        max_top_n_tokens: u32,
        max_input_length: usize,
        max_total_tokens: usize,
        disable_grammar_support: bool,
    ) -> Self {
        // If we have a fast tokenizer
        let sender = if let Some(tokenizer) = tokenizer {
            // Create round robin channel
            let (validation_sender, validation_round_robin_receiver) = mpsc::unbounded_channel();
            let mut senders = Vec::with_capacity(workers);

            // Create workers
            for _ in 0..workers {
                let tokenizer_clone = tokenizer.clone();
                let config_clone = config.clone();
                let (tokenizer_sender, tokenizer_receiver) = mpsc::unbounded_channel();
                senders.push(tokenizer_sender);

                // Spawn worker
                tokio::task::spawn_blocking(move || {
                    tokenizer_worker(tokenizer_clone, config_clone, tokenizer_receiver)
                });
            }

            // Create tokenization round robin task
            tokio::spawn(round_robin_task(validation_round_robin_receiver, senders));

            Some(validation_sender)
        } else {
            None
        };

        Self {
            max_best_of,
            sender,
            max_stop_sequences,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
            disable_grammar_support,
        }
    }

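    /// Tokenizes `inputs` on one of the background workers; returns `None`
    /// when no fast tokenizer is configured. A minimal usage sketch
    /// (illustrative only, not compiled as a doctest — the type is
    /// crate-internal):
    ///
    /// ```ignore
    /// if let Some((encoding, chunks)) = validation.tokenize(inputs, None).await? {
    ///     println!("{} tokens in {} chunks", encoding.len(), chunks.len());
    /// }
    /// ```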
    #[instrument(skip(self, inputs))]
    pub async fn tokenize(
        &self,
        inputs: String,
        truncate: Option<usize>,
    ) -> Result<Option<(tokenizers::Encoding, Vec<InputChunk>)>, ValidationError> {
        // If we have a fast tokenizer
        if let Some(sender) = &self.sender {
            // Create response channel
            let (response_sender, response_receiver) = oneshot::channel();
            // Send request to the background validation task
            // Unwrap is safe here
            sender
                .send(((inputs, truncate), response_sender, Span::current()))
                .unwrap();

            // Await on response channel
            // Unwrap is safe here
            let encoding = response_receiver.await.unwrap()?;
            Ok(Some(encoding))
        } else {
            Ok(None)
        }
    }

    #[instrument(skip(self, inputs))]
    async fn validate_input(
        &self,
        inputs: String,
        truncate: Option<usize>,
        max_new_tokens: Option<u32>,
    ) -> Result<(Vec<InputChunk>, usize, u32), ValidationError> {
        // If we have a fast tokenizer
        if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
            // Cap the reported input length at `truncate` if it was set
            let input_length = if let Some(truncate) = truncate {
                std::cmp::min(encoding.len(), truncate)
            } else {
                encoding.len()
            };

            // Get total tokens
            let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
                max_new_tokens
            } else {
                self.max_total_tokens.saturating_sub(input_length) as u32
            };
            let total_tokens = input_length + max_new_tokens as usize;

            // Validate MaxTotalTokens
            if total_tokens > self.max_total_tokens {
                return Err(ValidationError::MaxTotalTokens(
                    self.max_total_tokens,
                    input_length,
                    max_new_tokens,
                ));
            }

            // Validate InputLength
            if input_length > self.max_input_length {
                return Err(ValidationError::InputLength(
                    self.max_input_length,
                    input_length,
                ));
            }

            metrics::histogram!("tgi_request_input_length", input_length as f64);
            Ok((inputs, input_length, max_new_tokens))
        }
        // Return inputs without validation
        else {
            // In this case, we don't know the real length in tokens of the inputs
            // However, the inputs will be truncated by the python servers
            // We make sure that truncate + max_new_tokens <= self.max_total_tokens
            let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
                max_new_tokens
            } else if let Some(truncate) = truncate {
                self.max_total_tokens.saturating_sub(truncate) as u32
            } else {
                return Err(ValidationError::UnsetMaxNewTokens);
            };
            let mut input_length = truncate.unwrap_or(self.max_input_length);

            // We don't have a tokenizer, so we have no idea how long the
            // query is; let it through and hope for the best.
            // Validate MaxNewTokens
            if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
                input_length = input_length.saturating_sub(max_new_tokens as usize);
            }
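            // e.g. with the limits used in the unit tests below
            // (max_total_tokens = 6, max_input_length = 5) and
            // max_new_tokens = 10, input_length saturates to 0 and the
            // request is let through.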

            Ok((
                vec![Chunk::Text(inputs).into()],
                input_length,
                max_new_tokens,
            ))
        }
    }

    /// Validate a payload and get the number of tokens in the input
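    /// Checks best_of/sampling coherence, parameter ranges, stop sequences,
    /// the tokenized input length and, when enabled, the grammar.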
    #[instrument(skip_all)]
    pub(crate) async fn validate(
        &self,
        request: GenerateRequest,
    ) -> Result<ValidGenerateRequest, ValidationError> {
        let GenerateParameters {
            best_of,
            temperature,
            repetition_penalty,
            frequency_penalty,
            top_k,
            top_p,
            typical_p,
            do_sample,
            max_new_tokens,
            stop: stop_sequences,
            truncate,
            seed,
            watermark,
            decoder_input_details,
            top_n_tokens,
            grammar,
            adapter_id,
            ..
        } = request.parameters;

        // sampling must be true when best_of > 1
        let best_of = best_of.unwrap_or(1);
        let sampling = do_sample
            || temperature.is_some()
            || top_k.is_some()
            || top_p.is_some()
            || typical_p.is_some();

        if best_of > 1 && !sampling {
            return Err(BestOfSampling);
        }

        let temperature = temperature.unwrap_or(1.0);
        if temperature <= 0.0 {
            return Err(ValidationError::Temperature);
        }

        let repetition_penalty = repetition_penalty.unwrap_or(1.0);
        if repetition_penalty <= 0.0 {
            return Err(ValidationError::RepetitionPenalty);
        }

        let frequency_penalty = frequency_penalty.unwrap_or(0.0);
        if !(-2.0..=2.0).contains(&frequency_penalty) {
            return Err(ValidationError::FrequencyPenalty);
        }

        // Different because the proto default value is not a valid value
        // for the user
        let top_p = top_p
            .map(|value| {
                if value <= 0.0 || value >= 1.0 {
                    return Err(ValidationError::TopP);
                }
                Ok(value)
            })
            .unwrap_or(Ok(1.0))?;

        let typical_p = typical_p
            .map(|value| {
                if value <= 0.0 || value >= 1.0 {
                    return Err(ValidationError::TypicalP);
                }
                Ok(value)
            })
            .unwrap_or(Ok(1.0))?;

        let top_k: u32 = top_k
            .map(|value| {
                if value <= 0 {
                    return Err(ValidationError::TopK);
                }
                Ok(value as u32)
            })
            .unwrap_or(Ok(0))?;

        if max_new_tokens == Some(0) {
            return Err(ValidationError::NegativeMaxNewTokens);
        }

        if stop_sequences.len() > self.max_stop_sequences {
            return Err(ValidationError::StopSequence(
                self.max_stop_sequences,
                stop_sequences.len(),
            ));
        }

        // If seed is None, assign a random one
        let seed = match seed {
            None => thread_rng().gen(),
            Some(seed) => {
                if best_of > 1 {
                    return Err(BestOfSeed);
                }
                seed
            }
        };

        let top_n_tokens = top_n_tokens
            .map(|value| {
                if value > self.max_top_n_tokens {
                    return Err(ValidationError::TopNTokens(self.max_top_n_tokens, value));
                }
                Ok(value)
            })
            .unwrap_or(Ok(0))?;

        // Check if inputs is empty
        if request.inputs.is_empty() {
            return Err(EmptyInput);
        }

        // Check if truncate is strictly positive and less than max_input_length
        let truncate = truncate
            .map(|value| {
                if value == 0 || value > self.max_input_length {
                    return Err(ValidationError::Truncate(self.max_input_length, value));
                }
                Ok(Some(value))
            })
            .unwrap_or(Ok(None))?;

        // Validate inputs
        let (inputs, input_length, max_new_tokens) = self
            .validate_input(request.inputs, truncate, max_new_tokens)
            .await?;

        // TODO: we should build the FSM here and pass the compiled FSM instead of the grammar
        // NOTE: this is currently difficult because we need the tokenizer in Python to build
        // the FSM and we'd have to load a copy of the tokenizer into our Pyo3 instance which
        // may be slow and memory intensive. Best case is to have a Rust implementation of the FSM
        // compiler and use that to build the FSM here.

        // Validate grammar and unpack the grammar and type for the proto message
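        // A grammar arrives either as a JSON schema or a regex; illustrative
        // request payloads (assuming the serde representation of `GrammarType`):
        //   {"type": "json",  "value": {"type": "object", "properties": {...}}}
        //   {"type": "regex", "value": "[a-z]+"}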
        let grammar = match grammar {
            Some(grammar) => {
                // Ensure that grammar is not set if it's not supported
                if self.disable_grammar_support {
                    return Err(ValidationError::Grammar);
                }
                let valid_grammar = match grammar {
                    GrammarType::Json(json) => {
                        let json = match json {
                            // if value is a string, we need to parse it again to make sure its
                            // a valid json
                            Value::String(s) => serde_json::from_str(&s)
                                .map_err(|e| ValidationError::InvalidGrammar(e.to_string())),
                            Value::Object(_) => Ok(json),
                            _ => Err(ValidationError::Grammar),
                        }?;

                        // Check if the json is a valid JSONSchema
                        JSONSchema::options()
                            .with_draft(Draft::Draft202012)
                            .compile(&json)
                            .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;

                        // Serialize json to string
                        ValidGrammar::Json(
                            serde_json::to_string(&json)
                                .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?,
                        )
                    }
                    GrammarType::Regex(regex) => ValidGrammar::Regex(regex),
                };
                Some(valid_grammar)
            }
            None => None,
        };

        let parameters = ValidParameters {
            temperature,
            repetition_penalty,
            frequency_penalty,
            top_k,
            top_p,
            typical_p,
            do_sample,
            seed,
            watermark,
            grammar,
        };
        let stopping_parameters = ValidStoppingParameters {
            max_new_tokens,
            stop_sequences,
            ignore_eos_token: false,
        };

        metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64);

        Ok(ValidGenerateRequest {
            inputs,
            decoder_input_details,
            input_length: input_length as u32,
            truncate: truncate.unwrap_or(self.max_input_length) as u32,
            parameters,
            stopping_parameters,
            top_n_tokens,
            adapter_id,
        })
    }

    /// Validate the best_of parameter
    #[instrument(skip_all)]
    pub(crate) fn validate_best_of(&self, best_of: usize) -> Result<usize, ValidationError> {
        if self.max_best_of == 1 && best_of != 1 {
            return Err(ValidationError::BestOfDisabled);
        }

        if best_of > self.max_best_of {
            return Err(ValidationError::BestOf(self.max_best_of, best_of));
        }

        Ok(best_of)
    }
}

/// Round robin tokenization task
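/// Forwards each incoming request to the next worker in `senders`, cycling
/// through the workers so load is spread evenly; exits when the channel
/// closes.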
async fn round_robin_task(
    mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
    senders: Vec<mpsc::UnboundedSender<TokenizerRequest>>,
) {
    loop {
        for sender in &senders {
            match receiver.recv().await {
                None => return,
                Some(request) => sender.send(request).unwrap(),
            };
        }
    }
}

/// Tokenization worker: answers requests until its channel closes
fn tokenizer_worker(
    tokenizer: Tokenizer,
    config: Option<Config>,
    mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
) {
    // Loop over requests
    while let Some(((inputs, truncate), response_tx, parent_span)) = receiver.blocking_recv() {
        parent_span.in_scope(|| {
            response_tx
                .send(prepare_input(inputs, truncate, &tokenizer, &config))
                .unwrap_or(())
        })
    }
}

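/// Best-effort mapping from a MIME type to an `image` crate format; returns
/// `None` for unmapped types so the caller can fall back to guessing.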
fn format_from_mimetype(mimetype: &str) -> Option<ImageFormat> {
    match mimetype {
        "image/png" => Some(ImageFormat::Png),
        "image/jpeg" => Some(ImageFormat::Jpeg),
        "image/jpg" => Some(ImageFormat::Jpeg),
        "image/gif" => Some(ImageFormat::Gif),
        "image/webp" => Some(ImageFormat::WebP),
        "image/tiff" => Some(ImageFormat::Tiff),
        // "image/pnm"=>Some(ImageFormat::Pnm),
        // "image/tga"=>Some(ImageFormat::Tga),
        // "image/dds"=>Some(ImageFormat::Dds),
        // "image/bmp"=>Some(ImageFormat::Bmp),
        // "image/ico"=>Some(ImageFormat::Ico),
        // "image/x-exr"=>Some(ImageFormat::OpenExr),
        _ => None,
    }
}

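/// Inverse of `format_from_mimetype`, falling back to
/// `application/octet-stream` for formats without a mapping.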
fn format_to_mimetype(format: ImageFormat) -> String {
    match format {
        ImageFormat::Png => "image/png",
        ImageFormat::Jpeg => "image/jpeg",
        ImageFormat::Gif => "image/gif",
        ImageFormat::WebP => "image/webp",
        ImageFormat::Tiff => "image/tiff",
        _ => "application/octet-stream",
    }
    .to_string()
}

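/// Fetches an image referenced with markdown image syntax and returns its
/// bytes, MIME type and dimensions. Two forms are accepted (URLs here are
/// illustrative):
///
/// ```text
/// ![](https://example.com/cat.png)
/// ![](data:image/gif;base64,R0lGOD...)
/// ```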
fn fetch_image(input: &str) -> Result<(Vec<u8>, String, usize, usize), ValidationError> {
    if input.starts_with("![](http://") || input.starts_with("![](https://") {
        let url = &input["![](".len()..input.len() - 1];
        let data = reqwest::blocking::get(url)?.bytes()?;

        let format = image::guess_format(&data)?;
        // TODO Remove this clone
        let img = ImageReader::with_format(Cursor::new(data.clone()), format).decode()?;
        let height: usize = img.height().try_into()?;
        let width: usize = img.width().try_into()?;
        let mimetype = format_to_mimetype(format);
        Ok((data.to_vec(), mimetype, height, width))
    } else if input.starts_with("![](data:") {
        // Remove ![](....)
        let content = &input["![](data:".len()..input.len() - 1];
        let tokens: Vec<_> = content.split(';').collect();
        if tokens.len() != 2 {
            return Err(ValidationError::InvalidImageContent(content.to_string()));
        }
        let mimetype = tokens[0];
        let content = tokens[1];

        if !content.starts_with("base64,") {
            return Err(ValidationError::InvalidImageContent(content.to_string()));
        }

        let data = STANDARD.decode(content["base64,".len()..].as_bytes())?;
        let img = if let Some(format) = format_from_mimetype(mimetype) {
            ImageReader::with_format(Cursor::new(&data), format).decode()?
        } else {
            ImageReader::new(Cursor::new(&data))
                .with_guessed_format()
                .map_err(|_io_error| ValidationError::InvalidImageContent(content.to_string()))?
                .decode()?
        };

        let height: usize = img.height().try_into()?;
        let width: usize = img.width().try_into()?;
        Ok((data, mimetype.to_string(), height, width))
    } else {
        Err(ValidationError::InvalidImageContent(input.to_string()))
    }
}

/// Tokenize the input, splitting any markdown image references into
/// separate text/image chunks
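/// With a vision model config, each image reference is replaced by the
/// model's image token(s) before tokenization; e.g. (illustrative, mirroring
/// the Paligemma test at the bottom of this file)
/// `"test![](data:image/gif;base64,...)"` yields the chunks
/// `[Text("test"), Image { .. }]` and the tokenizer query `"test<image>"`.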
fn prepare_input(
    inputs: String,
    _truncate: Option<usize>,
    tokenizer: &Tokenizer,
    config: &Option<Config>,
) -> Result<(tokenizers::Encoding, Vec<InputChunk>), ValidationError> {
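    // Matches markdown image syntax: ![](...)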
    static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
    let (tokenizer_query, input_chunks) = match config {
        Some(Config::LlavaNext(config)) => {
            let mut input_chunks = Vec::new();
            let mut tokenizer_query = String::with_capacity(inputs.len());
            let mut start = 0;
            for chunk in RE.find_iter(&inputs) {
                let chunk_start = chunk.start();
                let chunk_end = chunk.end();
                if chunk_start != start {
                    input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
                    tokenizer_query.push_str(&inputs[start..chunk_start]);
                }
                let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
                let slots = config.get_number_of_features(height, width);
                input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
                tokenizer_query.push_str(&"<image>".repeat(slots));
                start = chunk_end;
            }
            if start != inputs.len() {
                input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
                tokenizer_query.push_str(&inputs[start..]);
            }
            (tokenizer_query, input_chunks)
        }
        Some(Config::Paligemma(config)) => {
            let mut input_chunks = Vec::new();
            let mut tokenizer_query = String::with_capacity(inputs.len());
            let mut start = 0;
            for chunk in RE.find_iter(&inputs) {
                let chunk_start = chunk.start();
                let chunk_end = chunk.end();
                if chunk_start != start {
                    input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
                    tokenizer_query.push_str(&inputs[start..chunk_start]);
                }
                let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
                let slots = config.get_number_of_features(height, width);
                input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
                tokenizer_query.push_str(&"<image>".repeat(slots));
                start = chunk_end;
            }
            if start != inputs.len() {
                input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
                tokenizer_query.push_str(&inputs[start..]);
            }
            (tokenizer_query, input_chunks)
        }
        Some(Config::Idefics2(config)) => {
            let mut input_chunks = Vec::new();
            let mut tokenizer_query = String::with_capacity(inputs.len());
            let mut start = 0;
            for chunk in RE.find_iter(&inputs) {
                let chunk_start = chunk.start();
                let chunk_end = chunk.end();
                if chunk_start != start {
                    input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
                    tokenizer_query.push_str(&inputs[start..chunk_start]);
                }
                let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
                let slots = config.get_number_of_features(height, width);
                tokenizer_query.push_str("<fake_token_around_image>");
                tokenizer_query.push_str(&"<image>".repeat(slots));
                tokenizer_query.push_str("<fake_token_around_image>");

                input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
                start = chunk_end;
            }
            if start != inputs.len() {
                input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
                tokenizer_query.push_str(&inputs[start..]);
            }
            (tokenizer_query, input_chunks)
        }
        Some(Config::Idefics) => {
            let mut input_chunks = Vec::new();
            let mut tokenizer_query = String::with_capacity(inputs.len());
            let mut start = 0;
            for chunk in RE.find_iter(&inputs) {
                let chunk_start = chunk.start();
                let chunk_end = chunk.end();
                if chunk_start != start {
                    input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
                    tokenizer_query.push_str(&inputs[start..chunk_start]);
                }
                let (data, mimetype, _height, _width) =
                    fetch_image(&inputs[chunk_start..chunk_end])?;
                let slots = 1;
                tokenizer_query.push_str(&"<image>".repeat(slots));
                input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
                start = chunk_end;
            }
            if start != inputs.len() {
                input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
                tokenizer_query.push_str(&inputs[start..]);
            }
            (tokenizer_query, input_chunks)
        }
        _ => (inputs.clone(), vec![Chunk::Text(inputs).into()]),
    };

    // Get the number of tokens in the input
    let encoding = tokenizer
        .encode(tokenizer_query, true)
        .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;

    Ok((encoding, input_chunks))
}

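/// A tokenization request: the (inputs, truncate) pair, the channel for the
/// response, and the tracing span of the caller.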
type TokenizerRequest = (
    (String, Option<usize>),
    oneshot::Sender<Result<(tokenizers::Encoding, Vec<InputChunk>), ValidationError>>,
    Span,
);

#[derive(Debug, Clone)]
pub(crate) enum ValidGrammar {
    Json(String),
    Regex(String),
}

#[derive(Debug, Clone)]
pub(crate) struct ValidParameters {
    /// Exponential scaling of the output probability distribution
    pub temperature: f32,
    /// Restrict sampling to the k highest-probability tokens
    pub top_k: u32,
    /// Restrict sampling to the smallest set of tokens whose cumulative
    /// probability reaches top_p
    pub top_p: f32,
    /// Restrict sampling to tokens whose typicality reaches typical_p
    pub typical_p: f32,
    /// Apply sampling to the logits
    pub do_sample: bool,
    /// Random seed for sampling
    pub seed: u64,
    /// Repetition penalty
    pub repetition_penalty: f32,
    /// Frequency penalty
    pub frequency_penalty: f32,
    /// Token watermarking using "A Watermark for Large Language Models"
    pub watermark: bool,
    /// Grammar constraint (applied if set)
    pub grammar: Option<ValidGrammar>,
}

#[derive(Debug, Clone)]
pub(crate) struct ValidStoppingParameters {
    /// Maximum number of generated tokens
    pub max_new_tokens: u32,
    /// Optional stopping sequences
    pub stop_sequences: Vec<String>,
    /// Ignore the end-of-sequence token (used for benchmarking)
    pub ignore_eos_token: bool,
}

#[derive(Debug, Clone)]
pub(crate) struct ValidGenerateRequest {
    pub inputs: Vec<InputChunk>,
    pub input_length: u32,
    pub truncate: u32,
    pub decoder_input_details: bool,
    pub parameters: ValidParameters,
    pub stopping_parameters: ValidStoppingParameters,
    pub top_n_tokens: u32,
    pub adapter_id: Option<String>,
}

#[derive(Error, Debug)]
pub enum ValidationError {
    #[error("`best_of` must be > 0 and <= {0}. Given: {1}")]
    BestOf(usize, usize),
    #[error("`best_of` != 1 is not allowed for this endpoint")]
    BestOfDisabled,
    #[error("you must use sampling when `best_of` is > 1")]
    BestOfSampling,
    #[error("`seed` must not be set when `best_of` > 1")]
    BestOfSeed,
    #[error("`best_of` != 1 is not supported when streaming tokens")]
    BestOfStream,
    #[error("`top_n_tokens` must be >= 0 and <= {0}. Given: {1}")]
    TopNTokens(u32, u32),
    #[error("`top_n_tokens` != 0 is not allowed for this endpoint")]
    TopNTokensDisabled,
    #[error("`decoder_input_details` == true is not supported when streaming tokens")]
    PrefillDetailsStream,
    #[error("`temperature` must be strictly positive")]
    Temperature,
    #[error("`repetition_penalty` must be strictly positive")]
    RepetitionPenalty,
    #[error("`frequency_penalty` must be >= -2.0 and <= 2.0")]
    FrequencyPenalty,
    #[error("`top_p` must be > 0.0 and < 1.0")]
    TopP,
    #[error("`top_k` must be strictly positive")]
    TopK,
    #[error("`truncate` must be strictly positive and less than {0}. Given: {1}")]
    Truncate(usize, usize),
    #[error("`typical_p` must be > 0.0 and < 1.0")]
    TypicalP,
    #[error("one of `max_new_tokens` or `truncate` must be set if a fast tokenizer is not in use")]
    UnsetMaxNewTokens,
    #[error("`max_new_tokens` must be strictly positive")]
    NegativeMaxNewTokens,
    #[error("`max_new_tokens` must be <= {0}. Given: {1}")]
    MaxNewTokens(usize, u32),
    #[error("`inputs` tokens + `max_new_tokens` must be <= {0}. Given: {1} `inputs` tokens and {2} `max_new_tokens`")]
    MaxTotalTokens(usize, usize, u32),
    #[error("`inputs` must have less than {0} tokens. Given: {1}")]
    InputLength(usize, usize),
    #[error("`inputs` cannot be empty")]
    EmptyInput,
    #[error("`stop` supports up to {0} stop sequences. Given: {1}")]
    StopSequence(usize, usize),
    #[error("tokenizer error {0}")]
    Tokenizer(String),
    #[error("grammar is not supported")]
    Grammar,
    #[error("grammar is not valid: {0}")]
    InvalidGrammar(String),
    #[error("base64 encoding is invalid: {0}")]
    InvalidBase64(#[from] base64::DecodeError),
    #[error("invalid image: {0}")]
    InvalidImage(#[from] image::ImageError),
    #[error("invalid integer: {0}")]
    InvalidInt(#[from] core::num::TryFromIntError),
    #[error("invalid image content: {0}")]
    InvalidImageContent(String),
    #[error("Could not fetch image: {0}")]
    FailedFetchImage(#[from] reqwest::Error),
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{PaliTextConfig, Paligemma};
    use crate::default_parameters;
    use crate::tests::get_tokenizer;

    #[tokio::test]
    async fn test_validation_max_new_tokens() {
        let tokenizer = None;
        let max_best_of = 2;
        let max_stop_sequence = 3;
        let max_top_n_tokens = 4;
        let max_input_length = 5;
        let max_total_tokens = 6;
        let workers = 1;
        let disable_grammar_support = true;
        let config = None;
        let validation = Validation::new(
            workers,
            tokenizer,
            config,
            max_best_of,
            max_stop_sequence,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
            disable_grammar_support,
        );

        let max_new_tokens = 10;
        match validation
            .validate_input("Hello".to_string(), None, Some(max_new_tokens))
            .await
        {
            // Err(ValidationError::MaxNewTokens(1, 10)) => (),
            Ok((_s, 0, 10)) => (),
            r => panic!("Unexpected not max new tokens: {r:?}"),
        }
    }

    #[tokio::test]
    async fn test_validation_input_length() {
        let tokenizer = Some(get_tokenizer().await);
        let max_best_of = 2;
        let max_stop_sequence = 3;
        let max_top_n_tokens = 4;
        let max_input_length = 5;
        let max_total_tokens = 6;
        let disable_grammar_support = true;
        let workers = 1;
        let config = None;
        let validation = Validation::new(
            workers,
            tokenizer,
            config,
            max_best_of,
            max_stop_sequence,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
            disable_grammar_support,
        );

        let max_new_tokens = 10;
        match validation
            .validate_input("Hello".to_string(), None, Some(max_new_tokens))
            .await
        {
            Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (),
            _ => panic!("Unexpected not max new tokens"),
        }
    }

    #[tokio::test]
    async fn test_validation_best_of_sampling() {
        let tokenizer = Some(get_tokenizer().await);
        let max_best_of = 2;
        let max_stop_sequence = 3;
        let max_top_n_tokens = 4;
        let max_input_length = 5;
        let max_total_tokens = 6;
        let workers = 1;
        let disable_grammar_support = true;
        let config = None;
        let validation = Validation::new(
            workers,
            tokenizer,
            config,
            max_best_of,
            max_stop_sequence,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
            disable_grammar_support,
        );
        match validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    best_of: Some(2),
                    do_sample: false,
                    ..default_parameters()
                },
            })
            .await
        {
            Err(ValidationError::BestOfSampling) => (),
            _ => panic!("Unexpected not best of sampling"),
        }
    }

    #[tokio::test]
    async fn test_validation_top_p() {
        let tokenizer = Some(get_tokenizer().await);
        let max_best_of = 2;
        let max_stop_sequence = 3;
        let max_top_n_tokens = 4;
        let max_input_length = 5;
        let max_total_tokens = 106;
        let workers = 1;
        let disable_grammar_support = true;
        let config = None;
        let validation = Validation::new(
            workers,
            tokenizer,
            config,
            max_best_of,
            max_stop_sequence,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
            disable_grammar_support,
        );
        match validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_p: Some(1.0),
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
        {
            Err(ValidationError::TopP) => (),
            _ => panic!("Unexpected top_p"),
        }

        match validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_p: Some(0.99),
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
        {
            Ok(_) => (),
            _ => panic!("Unexpected top_p error"),
        }

        let valid_request = validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_p: None,
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
            .unwrap();
        // top_p == 1.0 is invalid for users to ask for but it's the default resolved value.
        assert_eq!(valid_request.parameters.top_p, 1.0);
    }

    #[tokio::test]
    async fn test_validation_top_n_tokens() {
        let tokenizer = Some(get_tokenizer().await);
        let max_best_of = 2;
        let max_stop_sequences = 3;
        let max_top_n_tokens = 4;
        let max_input_length = 5;
        let max_total_tokens = 106;
        let workers = 1;
        let disable_grammar_support = true;
        let config = None;
        let validation = Validation::new(
            workers,
            tokenizer,
            config,
            max_best_of,
            max_stop_sequences,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
            disable_grammar_support,
        );
        match validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_n_tokens: Some(5),
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
        {
            Err(ValidationError::TopNTokens(4, 5)) => (),
            _ => panic!("Unexpected top_n_tokens"),
        }

        validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_n_tokens: Some(4),
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
            .unwrap();

        validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_n_tokens: Some(0),
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
            .unwrap();

        let valid_request = validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_n_tokens: None,
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
            .unwrap();

        assert_eq!(valid_request.top_n_tokens, 0);
    }

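    // A base64-encoded 1x1 GIF pixel, decoded below to compare against the
    // parsed image chunk.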
    static PIXEL_GIF: &str = "R0lGODdhAQABAIEAAP///wAAAAAAAAAAACwAAAAAAQABAAAIBAABBAQAOw==";

    #[tokio::test]
    async fn test_prepare_input_chunks() {
        let pixel_data = STANDARD.decode(PIXEL_GIF).unwrap();

        let tokenizer = Some(get_tokenizer().await);

        let max_best_of = 2;
        let max_stop_sequence = 3;
        let max_top_n_tokens = 4;
        let max_input_length = 5;
        let max_total_tokens = 6;
        let disable_grammar_support = true;
        let workers = 1;
        let config = Config::Paligemma(Paligemma {
            text_config: PaliTextConfig {
                num_image_tokens: 1,
            },
        });
        let validation = Validation::new(
            workers,
            tokenizer,
            Some(config),
            max_best_of,
            max_stop_sequence,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
            disable_grammar_support,
        );

        let chunks = match validation
            .tokenize(
                format!("test![](data:image/gif;base64,{})", PIXEL_GIF),
                None,
            )
            .await
        {
            Ok(Some((_encoding, chunks))) => chunks,
            _ => panic!("Unexpected tokenization failure"),
        };

        assert!(
            chunks
                == vec![
                    Chunk::Text("test".to_string()).into(),
                    Chunk::Image(Image {
                        data: pixel_data.clone(),
                        mimetype: "image/gif".to_string()
                    })
                    .into()
                ],
            "Failed to process images",
        );
    }
}