use crate::config::Config;
/// Payload validation logic
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
use crate::{GenerateParameters, GenerateRequest, GrammarType};
use jsonschema::{Draft, JSONSchema};
use rand::{thread_rng, Rng};
use serde_json::Value;
use std::io::Cursor;
use text_generation_client::{
    Chunk, GrammarType as ProtoGrammarType, Image, InputChunk, NextTokenChooserParameters,
    StoppingCriteriaParameters,
};
use thiserror::Error;
use tokenizers::tokenizer::Tokenizer;
// use tokenizers::TruncationDirection;
use base64::{engine::general_purpose::STANDARD, Engine};
use image::{io::Reader as ImageReader, ImageFormat};
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tracing::{instrument, Span};
use {once_cell::sync::Lazy, regex::Regex};

/// Validation
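/// When a fast tokenizer is available, tokenization is offloaded to a pool of
/// blocking worker threads behind a round-robin channel.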
#[derive(Debug, Clone)]
pub struct Validation {
    /// Validation parameters
    max_best_of: usize,
    max_stop_sequences: usize,
    max_top_n_tokens: u32,
    max_input_length: usize,
    max_total_tokens: usize,
    disable_grammar_support: bool,
    /// Channel to communicate with the background tokenization task
    sender: Option<mpsc::UnboundedSender<TokenizerRequest>>,
}

impl Validation {
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn new(
        workers: usize,
        tokenizer: Option<Tokenizer>,
        config: Option<Config>,
        max_best_of: usize,
        max_stop_sequences: usize,
        max_top_n_tokens: u32,
        max_input_length: usize,
        max_total_tokens: usize,
        disable_grammar_support: bool,
    ) -> Self {
        // If we have a fast tokenizer
        let sender = if let Some(tokenizer) = tokenizer {
            // Create round robin channel
            let (validation_sender, validation_round_robin_receiver) = mpsc::unbounded_channel();
            let mut senders = Vec::with_capacity(workers);

            // Create workers
            for _ in 0..workers {
                let tokenizer_clone = tokenizer.clone();
                let config_clone = config.clone();
                let (tokenizer_sender, tokenizer_receiver) = mpsc::unbounded_channel();
                senders.push(tokenizer_sender);

                // Spawn worker
                tokio::task::spawn_blocking(move || {
                    tokenizer_worker(tokenizer_clone, config_clone, tokenizer_receiver)
                });
            }

            // Create tokenization round robin task
            tokio::spawn(round_robin_task(validation_round_robin_receiver, senders));

            Some(validation_sender)
        } else {
            None
        };

        Self {
            max_best_of,
            sender,
            max_stop_sequences,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
            disable_grammar_support,
        }
    }

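    /// Tokenize `inputs`, returning the encoding and the parsed input chunks,
    /// or `None` when no tokenizer was configured.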
    #[instrument(skip(self, inputs))]
    pub async fn tokenize(
        &self,
        inputs: String,
        truncate: Option<usize>,
    ) -> Result<Option<(tokenizers::Encoding, Vec<InputChunk>)>, ValidationError> {
        // If we have a fast tokenizer
        if let Some(sender) = &self.sender {
            // Create response channel
            let (response_sender, response_receiver) = oneshot::channel();
            // Send request to the background validation task
            // Unwrap is safe here
            sender
                .send(((inputs, truncate), response_sender, Span::current()))
                .unwrap();

            // Await on response channel
            // Unwrap is safe here
            let encoding = response_receiver.await.unwrap()?;
            Ok(Some(encoding))
        } else {
            Ok(None)
        }
    }

    #[instrument(skip(self, inputs))]
    async fn validate_input(
        &self,
        inputs: String,
        truncate: Option<usize>,
        max_new_tokens: Option<u32>,
    ) -> Result<(Vec<InputChunk>, usize, u32), ValidationError> {
        // If we have a fast tokenizer
        if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
            // Compute the effective input length, honoring `truncate`
            let input_length = if let Some(truncate) = truncate {
                std::cmp::min(encoding.len(), truncate)
            } else {
                encoding.len()
            };

            // Get total tokens
            let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
                max_new_tokens
            } else {
                self.max_total_tokens.saturating_sub(input_length) as u32
            };
            let total_tokens = input_length + max_new_tokens as usize;

            // Validate MaxTotalTokens
            if total_tokens > self.max_total_tokens {
                return Err(ValidationError::MaxTotalTokens(
                    self.max_total_tokens,
                    input_length,
                    max_new_tokens,
                ));
            }

            // Validate InputLength
            if input_length > self.max_input_length {
                return Err(ValidationError::InputLength(
                    self.max_input_length,
                    input_length,
                ));
            }

            metrics::histogram!("tgi_request_input_length", input_length as f64);
            Ok((inputs, input_length, max_new_tokens))
        }
        // Return inputs without validation
        else {
            // In this case, we don't know the real length in tokens of the inputs;
            // they will be truncated by the Python servers instead.
            // We make sure that truncate + max_new_tokens <= self.max_total_tokens
            let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
                max_new_tokens
            } else if let Some(truncate) = truncate {
                self.max_total_tokens.saturating_sub(truncate) as u32
            } else {
                return Err(ValidationError::UnsetMaxNewTokens);
            };
            let mut input_length = truncate.unwrap_or(self.max_input_length);

            // We don't have a tokenizer, so we have no idea how long the query is;
            // let it through and hope for the best.
            // Validate MaxNewTokens
            if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
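                // Shrink the assumed input length so the token budget still fits;
                // the actual truncation happens in the Python server.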
                input_length = input_length.saturating_sub(max_new_tokens as usize);
                // return Err(ValidationError::MaxNewTokens(
                //     self.max_total_tokens - self.max_input_length,
                //     max_new_tokens,
                // ));
            }

            Ok((
                vec![Chunk::Text(inputs).into()],
                input_length,
                max_new_tokens,
            ))
        }
    }

    /// Validate a payload and get the number of tokens in the input
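    /// Checks every sampling and stopping parameter against the configured limits
    /// before converting them into proto message types.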
    #[instrument(skip_all)]
    pub(crate) async fn validate(
        &self,
        request: GenerateRequest,
    ) -> Result<ValidGenerateRequest, ValidationError> {
        let GenerateParameters {
            best_of,
            temperature,
            repetition_penalty,
            frequency_penalty,
            top_k,
            top_p,
            typical_p,
            do_sample,
            max_new_tokens,
            stop: stop_sequences,
            truncate,
            seed,
            watermark,
            decoder_input_details,
            top_n_tokens,
            grammar,
            ..
        } = request.parameters;

        // sampling must be true when best_of > 1
        let best_of = best_of.unwrap_or(1);
        let sampling = do_sample
            || temperature.is_some()
            || top_k.is_some()
            || top_p.is_some()
            || typical_p.is_some();

        if best_of > 1 && !sampling {
            return Err(BestOfSampling);
        }

        let temperature = temperature.unwrap_or(1.0);
        if temperature <= 0.0 {
            return Err(ValidationError::Temperature);
        }

        let repetition_penalty = repetition_penalty.unwrap_or(1.0);
        if repetition_penalty <= 0.0 {
            return Err(ValidationError::RepetitionPenalty);
        }

        let frequency_penalty = frequency_penalty.unwrap_or(0.0);
        if !(-2.0..=2.0).contains(&frequency_penalty) {
            return Err(ValidationError::FrequencyPenalty);
        }

        // Different because the proto default value is not a valid value
        // for the user
        let top_p = top_p
            .map(|value| {
                if value <= 0.0 || value >= 1.0 {
                    return Err(ValidationError::TopP);
                }
                Ok(value)
            })
            .unwrap_or(Ok(1.0))?;

        let typical_p = typical_p
            .map(|value| {
                if value <= 0.0 || value >= 1.0 {
                    return Err(ValidationError::TypicalP);
                }
                Ok(value)
            })
            .unwrap_or(Ok(1.0))?;

        let top_k: u32 = top_k
            .map(|value| {
                if value <= 0 {
                    return Err(ValidationError::TopK);
                }
                Ok(value as u32)
            })
            .unwrap_or(Ok(0))?;

        if max_new_tokens == Some(0) {
            return Err(ValidationError::NegativeMaxNewTokens);
        }

        if stop_sequences.len() > self.max_stop_sequences {
            return Err(ValidationError::StopSequence(
                self.max_stop_sequences,
                stop_sequences.len(),
            ));
        }

        // If seed is None, assign a random one
        let seed = match seed {
            None => thread_rng().gen(),
            Some(seed) => {
                if best_of > 1 {
                    return Err(BestOfSeed);
                }
                seed
            }
        };

        let top_n_tokens = top_n_tokens
            .map(|value| {
                if value > self.max_top_n_tokens {
                    return Err(ValidationError::TopNTokens(self.max_top_n_tokens, value));
                }
                Ok(value)
            })
            .unwrap_or(Ok(0))?;

        // Check if inputs is empty
        if request.inputs.is_empty() {
            return Err(EmptyInput);
        }

        // Check if truncate is strictly positive and less than max_input_length
        let truncate = truncate
            .map(|value| {
                if value == 0 || value > self.max_input_length {
                    return Err(ValidationError::Truncate(self.max_input_length, value));
                }
                Ok(Some(value))
            })
            .unwrap_or(Ok(None))?;

        // Validate inputs
        let (inputs, input_length, max_new_tokens) = self
            .validate_input(request.inputs, truncate, max_new_tokens)
            .await?;

        // TODO: we should build the FSM here and pass the compiled FSM instead of the grammar
        // NOTE: this is currently difficult because we need the tokenizer in Python to build
        // the FSM and we'd have to load a copy of the tokenizer into our Pyo3 instance which
        // may be slow and memory intensive. Best case is to have a Rust implementation of the FSM
        // compiler and use that to build the FSM here.

        // Validate grammar and unpack the grammar and type for the proto message
        let (grammar, grammar_type) = match grammar {
            Some(grammar) => {
                // Ensure that grammar is not set if it's not supported
                if self.disable_grammar_support {
                    return Err(ValidationError::Grammar);
                }
                match grammar {
                    GrammarType::Json(json) => {
                        let json = match json {
                            // If the value is a string, parse it again to make sure it is
                            // valid JSON
                            Value::String(s) => serde_json::from_str(&s)
                                .map_err(|e| ValidationError::InvalidGrammar(e.to_string())),
                            Value::Object(_) => Ok(json),
                            _ => Err(ValidationError::Grammar),
                        }?;

                        // Check if the json is a valid JSONSchema
                        JSONSchema::options()
                            .with_draft(Draft::Draft202012)
                            .compile(&json)
                            .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;

                        (
                            // Serialize json to string
                            serde_json::to_string(&json)
                                .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?,
                            ProtoGrammarType::Json.into(),
                        )
                    }
                    GrammarType::Regex(regex) => (regex, ProtoGrammarType::Regex.into()),
                }
            }
            None => (String::new(), ProtoGrammarType::None.into()),
        };

        let parameters = NextTokenChooserParameters {
            temperature,
            repetition_penalty,
            frequency_penalty,
            top_k,
            top_p,
            typical_p,
            do_sample,
            seed,
            watermark,
            grammar,
            grammar_type,
        };
        let stopping_parameters = StoppingCriteriaParameters {
            max_new_tokens,
            stop_sequences,
            ignore_eos_token: false,
        };

        metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64);

        Ok(ValidGenerateRequest {
            inputs,
            decoder_input_details,
            input_length: input_length as u32,
            truncate: truncate.unwrap_or(self.max_input_length) as u32,
            parameters,
            stopping_parameters,
            top_n_tokens,
        })
    }

    /// Validate the best_of parameter
    #[instrument(skip_all)]
    pub(crate) fn validate_best_of(&self, best_of: usize) -> Result<usize, ValidationError> {
        if self.max_best_of == 1 && best_of != 1 {
            return Err(ValidationError::BestOfDisabled);
        }

        if best_of > self.max_best_of {
            return Err(ValidationError::BestOf(self.max_best_of, best_of));
        }

        Ok(best_of)
    }
}

/// Round robin tokenization task
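/// Forwards each incoming request to the next worker channel in turn.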
async fn round_robin_task(
    mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
    senders: Vec<mpsc::UnboundedSender<TokenizerRequest>>,
) {
    loop {
        for sender in &senders {
            match receiver.recv().await {
                None => return,
                Some(request) => sender.send(request).unwrap(),
            };
        }
    }
}

/// Tokenization worker: processes requests from its channel until it closes
fn tokenizer_worker(
    tokenizer: Tokenizer,
    config: Option<Config>,
    mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
) {
    // Loop over requests
    while let Some(((inputs, truncate), response_tx, parent_span)) = receiver.blocking_recv() {
        parent_span.in_scope(|| {
            response_tx
                .send(prepare_input(inputs, truncate, &tokenizer, &config))
                .unwrap_or(())
        })
    }
}

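/// Map a mimetype to an `ImageFormat` for the image types currently accepted.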
fn format_from_mimetype(mimetype: &str) -> Option<ImageFormat> {
    match mimetype {
        "image/png" => Some(ImageFormat::Png),
        "image/jpeg" => Some(ImageFormat::Jpeg),
        "image/jpg" => Some(ImageFormat::Jpeg),
        "image/gif" => Some(ImageFormat::Gif),
        "image/webp" => Some(ImageFormat::WebP),
        "image/tiff" => Some(ImageFormat::Tiff),
        // "image/pnm"=>Some(ImageFormat::Pnm),
        // "image/tga"=>Some(ImageFormat::Tga),
        // "image/dds"=>Some(ImageFormat::Dds),
        // "image/bmp"=>Some(ImageFormat::Bmp),
        // "image/ico"=>Some(ImageFormat::Ico),
        // "image/x-exr"=>Some(ImageFormat::OpenExr),
        _ => None,
    }
}
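/// Map an `ImageFormat` back to a mimetype, defaulting to `application/octet-stream`.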
fn format_to_mimetype(format: ImageFormat) -> String {
    match format {
        ImageFormat::Png => "image/png",
        ImageFormat::Jpeg => "image/jpeg",
        ImageFormat::Gif => "image/gif",
        ImageFormat::WebP => "image/webp",
        ImageFormat::Tiff => "image/tiff",
        _ => "application/octet-stream",
    }
    .to_string()
}

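/// Fetch and decode an image referenced with markdown image syntax.
/// Two input shapes are accepted (the values below are illustrative):
/// - a remote URL: `![](https://example.com/image.png)`
/// - inline base64 data: `![](data:image/png;base64,<encoded bytes>)`
///
/// Returns the raw bytes, the mimetype, and the decoded height and width.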
fn fetch_image(input: &str) -> Result<(Vec<u8>, String, usize, usize), ValidationError> {
    if input.starts_with("![](http://") || input.starts_with("![](https://") {
        let url = &input["![](".len()..input.len() - 1];
        let data = reqwest::blocking::get(url)?.bytes()?;

        let format = image::guess_format(&data)?;
        // TODO Remove this clone
        let img = ImageReader::with_format(Cursor::new(data.clone()), format).decode()?;
        let height: usize = img.height().try_into()?;
        let width: usize = img.width().try_into()?;
        let mimetype = format_to_mimetype(format);
        Ok((data.to_vec(), mimetype, height, width))
    } else if input.starts_with("![](data:") {
        // Remove ![](....)
        let content = &input["![](data:".len()..input.len() - 1];
        let tokens: Vec<_> = content.split(';').collect();
        if tokens.len() != 2 {
            return Err(ValidationError::InvalidImageContent(content.to_string()));
        }
        let mimetype = tokens[0];
        let content = tokens[1];

        if !content.starts_with("base64,") {
            return Err(ValidationError::InvalidImageContent(content.to_string()));
        }

        let data = STANDARD.decode(content["base64,".len()..].as_bytes())?;
        let img = if let Some(format) = format_from_mimetype(mimetype) {
            ImageReader::with_format(Cursor::new(&data), format).decode()?
        } else {
            ImageReader::new(Cursor::new(&data))
                .with_guessed_format()
                .map_err(|_io_error| ValidationError::InvalidImageContent(content.to_string()))?
                .decode()?
        };

        let height: usize = img.height().try_into()?;
        let width: usize = img.width().try_into()?;
        Ok((data, mimetype.to_string(), height, width))
    } else {
        Err(ValidationError::InvalidImageContent(input.to_string()))
    }
}

/// Get input length and optionally truncate it
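/// For multimodal configs, markdown image references (`![](...)`) are replaced in
/// the tokenizer query by the model's image placeholder tokens and collected as
/// image chunks alongside the surrounding text.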
fn prepare_input(
    inputs: String,
    _truncate: Option<usize>,
    tokenizer: &Tokenizer,
    config: &Option<Config>,
) -> Result<(tokenizers::Encoding, Vec<InputChunk>), ValidationError> {
    static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
    let (tokenizer_query, input_chunks) = match config {
        Some(Config::LlavaNext(config)) => {
            let mut input_chunks = Vec::new();
            let mut tokenizer_query = String::with_capacity(inputs.len());
            let mut start = 0;
            for chunk in RE.find_iter(&inputs) {
                let chunk_start = chunk.start();
                let chunk_end = chunk.end();
                if chunk_start != start {
                    input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
                    tokenizer_query.push_str(&inputs[start..chunk_start]);
                }
                let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
                let slots = config.get_number_of_features(height, width);
                input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
                tokenizer_query.push_str(&"<image>".repeat(slots));
                start = chunk_end;
            }
            if start != inputs.len() {
                input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
                tokenizer_query.push_str(&inputs[start..]);
            }
            (tokenizer_query, input_chunks)
        }
        Some(Config::Paligemma(config)) => {
            let mut input_chunks = Vec::new();
            let mut tokenizer_query = String::with_capacity(inputs.len());
            let mut start = 0;
            for chunk in RE.find_iter(&inputs) {
                let chunk_start = chunk.start();
                let chunk_end = chunk.end();
                if chunk_start != start {
                    input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
                    tokenizer_query.push_str(&inputs[start..chunk_start]);
                }
                let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
                let slots = config.get_number_of_features(height, width);
                input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
                tokenizer_query.push_str(&"<image>".repeat(slots));
                start = chunk_end;
            }
            if start != inputs.len() {
                input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
                tokenizer_query.push_str(&inputs[start..]);
            }
            (tokenizer_query, input_chunks)
        }
        Some(Config::Idefics2(config)) => {
            let mut input_chunks = Vec::new();
            let mut tokenizer_query = String::with_capacity(inputs.len());
            let mut start = 0;
            for chunk in RE.find_iter(&inputs) {
                let chunk_start = chunk.start();
                let chunk_end = chunk.end();
                if chunk_start != start {
                    input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
                    tokenizer_query.push_str(&inputs[start..chunk_start]);
                }
                let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
                let slots = config.get_number_of_features(height, width);
                tokenizer_query.push_str("<fake_token_around_image>");
                tokenizer_query.push_str(&"<image>".repeat(slots));
                tokenizer_query.push_str("<fake_token_around_image>");

                input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
                start = chunk_end;
            }
            if start != inputs.len() {
                input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
                tokenizer_query.push_str(&inputs[start..]);
            }
            (tokenizer_query, input_chunks)
        }
        Some(Config::Idefics) => {
            let mut input_chunks = Vec::new();
            let mut tokenizer_query = String::with_capacity(inputs.len());
            let mut start = 0;
            for chunk in RE.find_iter(&inputs) {
                let chunk_start = chunk.start();
                let chunk_end = chunk.end();
                if chunk_start != start {
                    input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
                    tokenizer_query.push_str(&inputs[start..chunk_start]);
                }
                let (data, mimetype, _height, _width) =
                    fetch_image(&inputs[chunk_start..chunk_end])?;
                let slots = 1;
                tokenizer_query.push_str(&"<image>".repeat(slots));
                input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
                start = chunk_end;
            }
            if start != inputs.len() {
                input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
                tokenizer_query.push_str(&inputs[start..]);
            }
            (tokenizer_query, input_chunks)
        }
        _ => (inputs.clone(), vec![Chunk::Text(inputs).into()]),
    };

    // Get the number of tokens in the input
    let encoding = tokenizer
        .encode(tokenizer_query, true)
        .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;

    Ok((encoding, input_chunks))
}

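/// A tokenization request: the (inputs, truncate) pair, the channel used to send
/// the response back, and the tracing span of the caller.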
type TokenizerRequest = (
    (String, Option<usize>),
    oneshot::Sender<Result<(tokenizers::Encoding, Vec<InputChunk>), ValidationError>>,
    Span,
);

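/// A request that passed validation, ready to be sent to the model server.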
#[derive(Debug, Clone)]
pub(crate) struct ValidGenerateRequest {
    pub inputs: Vec<InputChunk>,
    pub input_length: u32,
    pub truncate: u32,
    pub decoder_input_details: bool,
    pub parameters: NextTokenChooserParameters,
    pub stopping_parameters: StoppingCriteriaParameters,
    pub top_n_tokens: u32,
}

#[derive(Error, Debug)]
pub enum ValidationError {
    #[error("`best_of` must be > 0 and <= {0}. Given: {1}")]
    BestOf(usize, usize),
    #[error("`best_of` != 1 is not allowed for this endpoint")]
    BestOfDisabled,
    #[error("you must use sampling when `best_of` is > 1")]
    BestOfSampling,
    #[error("`seed` must not be set when `best_of` > 1")]
    BestOfSeed,
    #[error("`best_of` != 1 is not supported when streaming tokens")]
    BestOfStream,
    #[error("`top_n_tokens` must be >= 0 and <= {0}. Given: {1}")]
    TopNTokens(u32, u32),
    #[error("`top_n_tokens` != 0 is not allowed for this endpoint")]
    TopNTokensDisabled,
    #[error("`decoder_input_details` == true is not supported when streaming tokens")]
    PrefillDetailsStream,
    #[error("`temperature` must be strictly positive")]
    Temperature,
    #[error("`repetition_penalty` must be strictly positive")]
    RepetitionPenalty,
    #[error("`frequency_penalty` must be >= -2.0 and <= 2.0")]
    FrequencyPenalty,
    #[error("`top_p` must be > 0.0 and < 1.0")]
    TopP,
    #[error("`top_k` must be strictly positive")]
    TopK,
    #[error("`truncate` must be strictly positive and less than {0}. Given: {1}")]
    Truncate(usize, usize),
    #[error("`typical_p` must be > 0.0 and < 1.0")]
    TypicalP,
    #[error("one of `max_new_tokens` or `truncate` must be set if a fast tokenizer is not in use")]
    UnsetMaxNewTokens,
    #[error("`max_new_tokens` must be strictly positive")]
    NegativeMaxNewTokens,
    #[error("`max_new_tokens` must be <= {0}. Given: {1}")]
    MaxNewTokens(usize, u32),
    #[error("`inputs` tokens + `max_new_tokens` must be <= {0}. Given: {1} `inputs` tokens and {2} `max_new_tokens`")]
    MaxTotalTokens(usize, usize, u32),
    #[error("`inputs` must have less than {0} tokens. Given: {1}")]
    InputLength(usize, usize),
    #[error("`inputs` cannot be empty")]
    EmptyInput,
    #[error("`stop` supports up to {0} stop sequences. Given: {1}")]
    StopSequence(usize, usize),
    #[error("tokenizer error: {0}")]
    Tokenizer(String),
    #[error("grammar is not supported")]
    Grammar,
    #[error("grammar is not valid: {0}")]
    InvalidGrammar(String),
    #[error("base64 encoding is invalid: {0}")]
    InvalidBase64(#[from] base64::DecodeError),
    #[error("invalid image: {0}")]
    InvalidImage(#[from] image::ImageError),
    #[error("invalid integer: {0}")]
    InvalidInt(#[from] core::num::TryFromIntError),
    #[error("invalid image content: {0}")]
    InvalidImageContent(String),
    #[error("could not fetch image: {0}")]
    FailedFetchImage(#[from] reqwest::Error),
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{PaliTextConfig, Paligemma};
    use crate::default_parameters;
    use crate::tests::get_tokenizer;

    #[tokio::test]
    async fn test_validation_max_new_tokens() {
        let tokenizer = None;
        let max_best_of = 2;
        let max_stop_sequence = 3;
        let max_top_n_tokens = 4;
        let max_input_length = 5;
        let max_total_tokens = 6;
        let workers = 1;
        let disable_grammar_support = true;
        let config = None;
        let validation = Validation::new(
            workers,
            tokenizer,
            config,
            max_best_of,
            max_stop_sequence,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
            disable_grammar_support,
        );

        let max_new_tokens = 10;
        match validation
            .validate_input("Hello".to_string(), None, Some(max_new_tokens))
            .await
        {
            // Err(ValidationError::MaxNewTokens(1, 10)) => (),
            Ok((_s, 0, 10)) => (),
            r => panic!("Unexpected validation result: {r:?}"),
        }
    }

    #[tokio::test]
    async fn test_validation_input_length() {
        let tokenizer = Some(get_tokenizer().await);
        let max_best_of = 2;
        let max_stop_sequence = 3;
        let max_top_n_tokens = 4;
        let max_input_length = 5;
        let max_total_tokens = 6;
        let disable_grammar_support = true;
        let workers = 1;
        let config = None;
        let validation = Validation::new(
            workers,
            tokenizer,
            config,
            max_best_of,
            max_stop_sequence,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
            disable_grammar_support,
        );

        let max_new_tokens = 10;
        match validation
            .validate_input("Hello".to_string(), None, Some(max_new_tokens))
            .await
        {
            Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (),
            _ => panic!("Expected a `MaxTotalTokens` error"),
        }
    }

    #[tokio::test]
    async fn test_validation_best_of_sampling() {
        let tokenizer = Some(get_tokenizer().await);
        let max_best_of = 2;
        let max_stop_sequence = 3;
        let max_top_n_tokens = 4;
        let max_input_length = 5;
        let max_total_tokens = 6;
        let workers = 1;
        let disable_grammar_support = true;
        let config = None;
        let validation = Validation::new(
            workers,
            tokenizer,
            config,
            max_best_of,
            max_stop_sequence,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
            disable_grammar_support,
        );
        match validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    best_of: Some(2),
                    do_sample: false,
                    ..default_parameters()
                },
            })
            .await
        {
            Err(ValidationError::BestOfSampling) => (),
            _ => panic!("Expected a `BestOfSampling` error"),
        }
    }

    #[tokio::test]
    async fn test_validation_top_p() {
        let tokenizer = Some(get_tokenizer().await);
        let max_best_of = 2;
        let max_stop_sequence = 3;
        let max_top_n_tokens = 4;
        let max_input_length = 5;
        let max_total_tokens = 106;
        let workers = 1;
        let disable_grammar_support = true;
        let config = None;
        let validation = Validation::new(
            workers,
            tokenizer,
            config,
            max_best_of,
            max_stop_sequence,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
            disable_grammar_support,
        );
        match validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_p: Some(1.0),
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
        {
            Err(ValidationError::TopP) => (),
            _ => panic!("Expected a `TopP` error"),
        }

        match validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_p: Some(0.99),
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
        {
            Ok(_) => (),
            _ => panic!("Unexpected `top_p` error"),
        }

        let valid_request = validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_p: None,
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
            .unwrap();
        // top_p == 1.0 is invalid for users to ask for, but it is the default resolved value.
        assert_eq!(valid_request.parameters.top_p, 1.0);
    }

    #[tokio::test]
    async fn test_validation_top_n_tokens() {
        let tokenizer = Some(get_tokenizer().await);
        let max_best_of = 2;
        let max_stop_sequences = 3;
        let max_top_n_tokens = 4;
        let max_input_length = 5;
        let max_total_tokens = 106;
        let workers = 1;
        let disable_grammar_support = true;
        let config = None;
        let validation = Validation::new(
            workers,
            tokenizer,
            config,
            max_best_of,
            max_stop_sequences,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
            disable_grammar_support,
        );
        match validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_n_tokens: Some(5),
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
        {
            Err(ValidationError::TopNTokens(4, 5)) => (),
            _ => panic!("Expected a `TopNTokens` error"),
        }

        validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_n_tokens: Some(4),
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
            .unwrap();

        validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_n_tokens: Some(0),
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
            .unwrap();

        let valid_request = validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_n_tokens: None,
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
            .unwrap();

        assert_eq!(valid_request.top_n_tokens, 0);
    }

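    // A base64-encoded 1x1 pixel GIF, used as a tiny test image.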
    static PIXEL_GIF: &str = "R0lGODdhAQABAIEAAP///wAAAAAAAAAAACwAAAAAAQABAAAIBAABBAQAOw==";

    #[tokio::test]
    async fn test_prepare_input_chunks() {
        let pixel_data = STANDARD.decode(PIXEL_GIF).unwrap();

        let tokenizer = Some(get_tokenizer().await);

        let max_best_of = 2;
        let max_stop_sequence = 3;
        let max_top_n_tokens = 4;
        let max_input_length = 5;
        let max_total_tokens = 6;
        let disable_grammar_support = true;
        let workers = 1;
        let config = Config::Paligemma(Paligemma {
            text_config: PaliTextConfig {
                num_image_tokens: 1,
            },
        });
        let validation = Validation::new(
            workers,
            tokenizer,
            Some(config),
            max_best_of,
            max_stop_sequence,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
            disable_grammar_support,
        );

        let chunks = match validation
            .tokenize(
                format!("test![](data:image/gif;base64,{})", PIXEL_GIF),
                None,
            )
            .await
        {
            Ok(Some((_encoding, chunks))) => chunks,
            _ => panic!("Unexpected tokenization failure"),
        };

        assert!(
            chunks
                == vec![
                    Chunk::Text("test".to_string()).into(),
                    Chunk::Image(Image {
                        data: pixel_data.clone(),
                        mimetype: "image/gif".to_string()
                    })
                    .into()
                ],
            "Failed to process images",
        );
    }
}