validation.rs 29.7 KB
Newer Older
1
use crate::config::Config;
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
2
/// Payload validation logic
3
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
drbh's avatar
drbh committed
4
use crate::{GenerateParameters, GenerateRequest, GrammarType};
5
use jsonschema::{Draft, JSONSchema};
6
use rand::{thread_rng, Rng};
7
use serde_json::Value;
8
use std::io::Cursor;
drbh's avatar
drbh committed
9
10
11
use text_generation_client::{
    GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters,
};
Olivier Dehaene's avatar
Olivier Dehaene committed
12
use thiserror::Error;
Olivier Dehaene's avatar
Olivier Dehaene committed
13
use tokenizers::tokenizer::Tokenizer;
14
15
16
// use tokenizers::TruncationDirection;
use base64::{engine::general_purpose::STANDARD, Engine};
use image::{io::Reader as ImageReader, ImageFormat};
OlivierDehaene's avatar
OlivierDehaene committed
17
use tokio::sync::mpsc;
18
use tokio::sync::oneshot;
19
use tracing::{instrument, Span};
20
use {once_cell::sync::Lazy, regex::Regex};
Olivier Dehaene's avatar
Olivier Dehaene committed
21

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
22
/// Payload validation: holds the server-side limits a request must satisfy
/// and (optionally) a channel to background tokenization workers.
#[derive(Debug, Clone)]
pub struct Validation {
    /// Validation parameters
    // Maximum allowed value for `best_of`
    max_best_of: usize,
    // Maximum number of stop sequences per request
    max_stop_sequences: usize,
    // Maximum allowed value for `top_n_tokens`
    max_top_n_tokens: u32,
    // Maximum number of input tokens
    max_input_length: usize,
    // Maximum of input tokens + generated tokens
    max_total_tokens: usize,
    // When true, requests carrying a `grammar` are rejected
    disable_grammar_support: bool,
    /// Channel to communicate with the background tokenization task
    // `None` when no fast tokenizer is available
    sender: Option<mpsc::UnboundedSender<TokenizerRequest>>,
}

impl Validation {
OlivierDehaene's avatar
OlivierDehaene committed
37
    /// Build a `Validation` instance.
    ///
    /// When a fast `tokenizer` is provided, spawns `workers` blocking
    /// tokenization tasks plus one round-robin dispatcher task and keeps the
    /// sending half of the dispatcher channel; otherwise `sender` is `None`
    /// and inputs are validated without tokenization.
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn new(
        workers: usize,
        tokenizer: Option<Tokenizer>,
        config: Option<Config>,
        max_best_of: usize,
        max_stop_sequences: usize,
        max_top_n_tokens: u32,
        max_input_length: usize,
        max_total_tokens: usize,
        disable_grammar_support: bool,
    ) -> Self {
        // If we have a fast tokenizer
        let sender = if let Some(tokenizer) = tokenizer {
            // Create round robin channel
            let (validation_sender, validation_round_robin_receiver) = mpsc::unbounded_channel();
            let mut senders = Vec::with_capacity(workers);

            // Create workers: each gets its own clone of the tokenizer/config
            // and its own request channel
            for _ in 0..workers {
                let tokenizer_clone = tokenizer.clone();
                let config_clone = config.clone();
                let (tokenizer_sender, tokenizer_receiver) = mpsc::unbounded_channel();
                senders.push(tokenizer_sender);

                // Spawn worker on the blocking pool: tokenization is CPU-bound
                tokio::task::spawn_blocking(move || {
                    tokenizer_worker(tokenizer_clone, config_clone, tokenizer_receiver)
                });
            }

            // Create tokenization round robin task
            tokio::spawn(round_robin_task(validation_round_robin_receiver, senders));

            Some(validation_sender)
        } else {
            None
        };

        Self {
            max_best_of,
            sender,
            max_stop_sequences,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
            disable_grammar_support,
        }
    }
Olivier Dehaene's avatar
Olivier Dehaene committed
86

87
    #[instrument(skip(self, inputs))]
88
    pub async fn tokenize(
89
90
91
        &self,
        inputs: String,
        truncate: Option<usize>,
92
    ) -> Result<Option<(tokenizers::Encoding, String)>, ValidationError> {
93
94
95
96
97
98
99
100
101
102
103
104
        // If we have a fast tokenizer
        if let Some(sender) = &self.sender {
            // Create response channel
            let (response_sender, response_receiver) = oneshot::channel();
            // Send request to the background validation task
            // Unwrap is safe here
            sender
                .send(((inputs, truncate), response_sender, Span::current()))
                .unwrap();

            // Await on response channel
            // Unwrap is safe here
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
            let encoding = response_receiver.await.unwrap()?;
            Ok(Some(encoding))
        } else {
            Ok(None)
        }
    }

    /// Validate the input text against the configured length limits.
    ///
    /// Returns `(inputs, input_length, max_new_tokens)` where `inputs` may have
    /// been rewritten by the tokenizer workers (e.g. image markdown expanded).
    /// When no fast tokenizer is available, `input_length` is an upper bound
    /// (`truncate` or `max_input_length`), not a measured token count.
    #[instrument(skip(self, inputs))]
    async fn validate_input(
        &self,
        inputs: String,
        truncate: Option<usize>,
        max_new_tokens: Option<u32>,
    ) -> Result<(String, usize, u32), ValidationError> {
        // If we have a fast tokenizer
        if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
            // Create response channel
            let input_length = encoding.len();

            // Get total tokens
            // Default `max_new_tokens` to the remaining budget when unset
            let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
                max_new_tokens
            } else {
                self.max_total_tokens.saturating_sub(input_length) as u32
            };
            let total_tokens = input_length + max_new_tokens as usize;

            // Validate MaxTotalTokens
            if total_tokens > self.max_total_tokens {
                return Err(ValidationError::MaxTotalTokens(
                    self.max_total_tokens,
                    input_length,
                    max_new_tokens,
                ));
            }

            // Validate InputLength
            if input_length > self.max_input_length {
                return Err(ValidationError::InputLength(
                    self.max_input_length,
                    input_length,
                ));
            }

            metrics::histogram!("tgi_request_input_length", input_length as f64);
            Ok((inputs, input_length, max_new_tokens))
        }
        // Return inputs without validation
        else {
            // In this case, we don't know the real length in tokens of the inputs
            // However, the inputs will be truncated by the python servers
            // We make sure that truncate + max_new_tokens <= self.max_total_tokens
            let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
                max_new_tokens
            } else if let Some(truncate) = truncate {
                self.max_total_tokens.saturating_sub(truncate) as u32
            } else {
                // Without a tokenizer we cannot derive a sensible default
                return Err(ValidationError::UnsetMaxNewTokens);
            };
            // Worst-case estimate of the input length after server-side truncation
            let input_length = truncate.unwrap_or(self.max_input_length);

            // Validate MaxNewTokens
            if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
                return Err(ValidationError::MaxNewTokens(
                    self.max_total_tokens - self.max_input_length,
                    max_new_tokens,
                ));
            }

            Ok((inputs, input_length, max_new_tokens))
        }
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
178
    /// Validate a payload and get the number of tokens in the input
    ///
    /// Checks every generation parameter against the configured limits, fills
    /// in defaults (temperature, penalties, seed, ...), validates the input
    /// text via [`Self::validate_input`] and the optional grammar, and returns
    /// a [`ValidGenerateRequest`] ready to be forwarded to the backend.
    /// The order of the checks determines which error a bad request receives.
    #[instrument(skip_all)]
    pub(crate) async fn validate(
        &self,
        request: GenerateRequest,
    ) -> Result<ValidGenerateRequest, ValidationError> {
        let GenerateParameters {
            best_of,
            temperature,
            repetition_penalty,
            frequency_penalty,
            top_k,
            top_p,
            typical_p,
            do_sample,
            max_new_tokens,
            stop: stop_sequences,
            truncate,
            seed,
            watermark,
            decoder_input_details,
            top_n_tokens,
            grammar,
            ..
        } = request.parameters;

        // sampling must be true when best_of > 1
        let best_of = best_of.unwrap_or(1);
        let sampling = do_sample
            || temperature.is_some()
            || top_k.is_some()
            || top_p.is_some()
            || typical_p.is_some();

        if best_of > 1 && !sampling {
            return Err(BestOfSampling);
        }

        // Temperature must be strictly positive (default 1.0 = no scaling)
        let temperature = temperature.unwrap_or(1.0);
        if temperature <= 0.0 {
            return Err(ValidationError::Temperature);
        }

        // Repetition penalty must be strictly positive (default 1.0 = disabled)
        let repetition_penalty = repetition_penalty.unwrap_or(1.0);
        if repetition_penalty <= 0.0 {
            return Err(ValidationError::RepetitionPenalty);
        }

        // Frequency penalty is bounded to [-2.0, 2.0] (default 0.0 = disabled)
        let frequency_penalty = frequency_penalty.unwrap_or(0.0);
        if !(-2.0..=2.0).contains(&frequency_penalty) {
            return Err(ValidationError::FrequencyPenalty);
        }

        // Different because the proto default value is not a valid value
        // for the user
        let top_p = top_p
            .map(|value| {
                if value <= 0.0 || value >= 1.0 {
                    return Err(ValidationError::TopP);
                }
                Ok(value)
            })
            .unwrap_or(Ok(1.0))?;

        let typical_p = typical_p
            .map(|value| {
                if value <= 0.0 || value >= 1.0 {
                    return Err(ValidationError::TypicalP);
                }
                Ok(value)
            })
            .unwrap_or(Ok(0.0.max(1.0)))?;

        let top_k: u32 = top_k
            .map(|value| {
                if value <= 0 {
                    return Err(ValidationError::TopK);
                }
                Ok(value as u32)
            })
            .unwrap_or(Ok(0))?;

        // `max_new_tokens` == 0 is rejected (must be strictly positive);
        // `None` is allowed and defaulted later in `validate_input`
        if max_new_tokens == Some(0) {
            return Err(ValidationError::NegativeMaxNewTokens);
        }

        if stop_sequences.len() > self.max_stop_sequences {
            return Err(ValidationError::StopSequence(
                self.max_stop_sequences,
                stop_sequences.len(),
            ));
        }

        // If seed is None, assign a random one
        let seed = match seed {
            None => thread_rng().gen(),
            Some(seed) => {
                // A fixed seed would make all `best_of` sequences identical
                if best_of > 1 {
                    return Err(BestOfSeed);
                }
                seed
            }
        };

        let top_n_tokens = top_n_tokens
            .map(|value| {
                if value > self.max_top_n_tokens {
                    return Err(ValidationError::TopNTokens(self.max_top_n_tokens, value));
                }
                Ok(value)
            })
            .unwrap_or(Ok(0))?;

        // Check if inputs is empty
        if request.inputs.is_empty() {
            return Err(EmptyInput);
        }

        // Check if truncate is strictly positive and less than max_input_length
        let truncate = truncate
            .map(|value| {
                if value == 0 || value > self.max_input_length {
                    return Err(ValidationError::Truncate(self.max_input_length, value));
                }
                Ok(Some(value))
            })
            .unwrap_or(Ok(None))?;

        // Validate inputs
        let (inputs, input_length, max_new_tokens) = self
            .validate_input(request.inputs, truncate, max_new_tokens)
            .await?;

        // TODO: we should build the FSM here and pass the compiled FSM instead of the grammar
        // NOTE: this is currently difficult because we need the tokenizer in Python to build
        // the FSM and we'd have to load a copy of the tokenizer into our Pyo3 instance which
        // may be slow and memory intensive. Best case is to have a Rust implementation of the FSM
        // compiler and use that to build the FSM here.

        // Validate grammar and unpack the grammar and type for the proto message
        let (grammar, grammar_type) = match grammar {
            Some(grammar) => {
                // Ensure that grammar is not set if it's not supported
                if self.disable_grammar_support {
                    return Err(ValidationError::Grammar);
                }
                match grammar {
                    GrammarType::Json(json) => {
                        let json = match json {
                            // if value is a string, we need to parse it again to make sure its
                            // a valid json
                            Value::String(s) => serde_json::from_str(&s)
                                .map_err(|e| ValidationError::InvalidGrammar(e.to_string())),
                            Value::Object(_) => Ok(json),
                            _ => Err(ValidationError::Grammar),
                        }?;

                        // Check if the json is a valid JSONSchema
                        JSONSchema::options()
                            .with_draft(Draft::Draft202012)
                            .compile(&json)
                            .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;

                        (
                            // Serialize json to string
                            serde_json::to_string(&json)
                                .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?,
                            ProtoGrammarType::Json.into(),
                        )
                    }
                    GrammarType::Regex(regex) => (regex, ProtoGrammarType::Regex.into()),
                }
            }
            None => (String::new(), ProtoGrammarType::None.into()),
        };

        let parameters = NextTokenChooserParameters {
            temperature,
            repetition_penalty,
            frequency_penalty,
            top_k,
            top_p,
            typical_p,
            do_sample,
            seed,
            watermark,
            grammar,
            grammar_type,
        };
        let stopping_parameters = StoppingCriteriaParameters {
            max_new_tokens,
            stop_sequences,
            ignore_eos_token: false,
        };

        metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64);

        Ok(ValidGenerateRequest {
            inputs,
            decoder_input_details,
            input_length: input_length as u32,
            truncate: truncate.unwrap_or(self.max_input_length) as u32,
            parameters,
            stopping_parameters,
            top_n_tokens,
        })
    }
385
386
387
388
389
390
391
392
393
394
395
396
397
398

    /// Check `best_of` against the server configuration; returns it unchanged
    /// on success.
    #[instrument(skip_all)]
    pub(crate) fn validate_best_of(&self, best_of: usize) -> Result<usize, ValidationError> {
        match best_of {
            // `best_of` is disabled entirely when the limit is 1
            _ if self.max_best_of == 1 && best_of != 1 => Err(ValidationError::BestOfDisabled),
            // Above the configured maximum
            _ if best_of > self.max_best_of => {
                Err(ValidationError::BestOf(self.max_best_of, best_of))
            }
            _ => Ok(best_of),
        }
    }
Olivier Dehaene's avatar
Olivier Dehaene committed
399
400
}

OlivierDehaene's avatar
OlivierDehaene committed
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
/// Round robin tokenization task: forwards each incoming request to the next
/// worker channel in turn, and exits once the request channel is closed.
async fn round_robin_task(
    mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
    senders: Vec<mpsc::UnboundedSender<TokenizerRequest>>,
) {
    loop {
        for worker in &senders {
            let request = match receiver.recv().await {
                Some(request) => request,
                // All request senders dropped: shut down
                None => return,
            };
            worker.send(request).unwrap();
        }
    }
}

416
/// Start tokenization workers
417
418
419
420
421
fn tokenizer_worker(
    tokenizer: Tokenizer,
    config: Option<Config>,
    mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
) {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
422
    // Loop over requests
OlivierDehaene's avatar
OlivierDehaene committed
423
    while let Some(((inputs, truncate), response_tx, parent_span)) = receiver.blocking_recv() {
424
425
        parent_span.in_scope(|| {
            response_tx
426
                .send(prepare_input(inputs, truncate, &tokenizer, &config))
427
428
                .unwrap_or(())
        })
429
430
    }
}
Olivier Dehaene's avatar
Olivier Dehaene committed
431

432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
/// Map a MIME type string to the corresponding `ImageFormat`, if supported.
fn format_from_mimetype(mimetype: &str) -> Option<ImageFormat> {
    let format = match mimetype {
        "image/png" => ImageFormat::Png,
        "image/jpeg" | "image/jpg" => ImageFormat::Jpeg,
        "image/gif" => ImageFormat::Gif,
        "image/webp" => ImageFormat::WebP,
        "image/tiff" => ImageFormat::Tiff,
        // Formats the `image` crate could decode but that we do not accept (yet):
        // pnm, tga, dds, bmp, ico, x-exr
        _ => return None,
    };
    Some(format)
}
/// Map an `ImageFormat` back to its canonical MIME type string.
fn format_to_mimetype(format: ImageFormat) -> String {
    let mimetype = match format {
        ImageFormat::Png => "image/png",
        ImageFormat::Jpeg => "image/jpeg",
        ImageFormat::Gif => "image/gif",
        ImageFormat::WebP => "image/webp",
        ImageFormat::Tiff => "image/tiff",
        // Anything else is treated as opaque binary
        _ => "application/octet-stream",
    };
    mimetype.to_string()
}

/// Fetch or decode an inline markdown image (`![](...)`) and return
/// `(rewritten_markdown, height, width)`.
///
/// Two shapes are accepted:
/// * `![](http(s)://...)` — the image is downloaded, measured, and re-emitted
///   as a base64 data URI;
/// * `![](data:<mime>;base64,<payload>)` — the payload is decoded in place to
///   measure it; the original string is returned untouched.
///
/// # Errors
/// Returns `ValidationError` variants for network failures and malformed or
/// undecodable image content.
fn fetch_image(input: &str) -> Result<(String, usize, usize), ValidationError> {
    if input.starts_with("![](http://") || input.starts_with("![](https://") {
        // Strip the markdown wrapper `![](` ... `)`
        let url = &input["![](".len()..input.len() - 1];
        let data = reqwest::blocking::get(url)?.bytes()?;

        let format = image::guess_format(&data)?;
        // Decode from a borrowed slice (`Bytes` derefs to `[u8]`) so the
        // downloaded bytes do not need to be cloned
        let img = ImageReader::with_format(Cursor::new(&*data), format).decode()?;
        let height: usize = img.height().try_into()?;
        let width: usize = img.width().try_into()?;
        let mimetype = format_to_mimetype(format);
        let encoded = STANDARD.encode(&data);
        let data_uri = format!("![](data:{mimetype};base64,{encoded})");
        Ok((data_uri, height, width))
    } else if input.starts_with("![](data:") {
        // Remove ![](....)
        let content = &input["![](data:".len()..input.len() - 1];
        // Expected shape: `<mimetype>;base64,<payload>`
        let tokens: Vec<_> = content.split(';').collect();
        if tokens.len() != 2 {
            return Err(ValidationError::InvalidImageContent(content.to_string()));
        }
        let mimetype = tokens[0];
        let content = tokens[1];

        if !content.starts_with("base64,") {
            return Err(ValidationError::InvalidImageContent(content.to_string()));
        }

        let data = STANDARD.decode(content["base64,".len()..].as_bytes())?;
        // Use the declared mimetype when we recognise it, otherwise sniff the bytes
        let img = if let Some(format) = format_from_mimetype(mimetype) {
            ImageReader::with_format(Cursor::new(data), format).decode()?
        } else {
            ImageReader::new(Cursor::new(data))
                .with_guessed_format()
                .map_err(|_io_error| ValidationError::InvalidImageContent(content.to_string()))?
                .decode()?
        };

        let height: usize = img.height().try_into()?;
        let width: usize = img.width().try_into()?;
        Ok((input.to_string(), height, width))
    } else {
        Err(ValidationError::InvalidImageContent(input.to_string()))
    }
}

507
508
/// Get input length and optionally truncate it
fn prepare_input(
509
    mut inputs: String,
510
    _truncate: Option<usize>,
511
    tokenizer: &Tokenizer,
512
    config: &Option<Config>,
513
) -> Result<(tokenizers::Encoding, String), ValidationError> {
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
    static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
    let tokenizer_query = match config {
        Some(Config::LlavaNext(config)) => {
            let mut modified_inputs = String::with_capacity(inputs.len());
            let mut tokenizer_query = String::with_capacity(inputs.len());
            let mut start = 0;
            for chunk in RE.find_iter(&inputs) {
                let chunk_start = chunk.start();
                let chunk_end = chunk.end();
                if chunk_start != start {
                    modified_inputs.push_str(&inputs[start..chunk_start]);
                    tokenizer_query.push_str(&inputs[start..chunk_start]);
                }
                let (image_uri, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
                let slots = config.get_number_of_features(height, width);
                tokenizer_query.push_str(&"<image>".repeat(slots));
                modified_inputs.push_str(&image_uri);
                start = chunk_end;
            }
            if start != inputs.len() - 1 {
                modified_inputs.push_str(&inputs[start..]);
                tokenizer_query.push_str(&inputs[start..]);
            }
            inputs = modified_inputs;
            tokenizer_query
        }
        Some(Config::Idefics) => RE.replace_all(&inputs, "<image>").into(),
        _ => inputs.clone(),
542
    };
543

544
    // Get the number of tokens in the input
545
546
    let encoding = tokenizer
        .encode(tokenizer_query, true)
547
548
        .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;

549
    Ok((encoding, inputs))
Olivier Dehaene's avatar
Olivier Dehaene committed
550
}
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
551

552
553
/// Message sent to a tokenizer worker:
/// `((inputs, truncate), response channel, tracing span of the caller)`.
type TokenizerRequest = (
    (String, Option<usize>),
    oneshot::Sender<Result<(tokenizers::Encoding, String), ValidationError>>,
    Span,
);

558
/// A request that passed validation and is ready for scheduling.
#[derive(Debug, Clone)]
pub(crate) struct ValidGenerateRequest {
    // Input text, possibly rewritten during tokenization (e.g. image markdown)
    pub inputs: String,
    // Number of input tokens (an upper-bound estimate when no fast tokenizer is in use)
    pub input_length: u32,
    // Truncation the backend should apply to the input
    pub truncate: u32,
    // Forwarded `decoder_input_details` flag from the request
    pub decoder_input_details: bool,
    pub parameters: NextTokenChooserParameters,
    pub stopping_parameters: StoppingCriteriaParameters,
    // How many top tokens to return per generated position
    pub top_n_tokens: u32,
}

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
569
570
/// Errors produced while validating a request payload.
///
/// The `#[error(...)]` strings are user-facing messages returned verbatim to
/// API clients, so changes to them are API-visible.
#[derive(Error, Debug)]
pub enum ValidationError {
    #[error("`best_of` must be > 0 and <= {0}. Given: {1}")]
    BestOf(usize, usize),
    #[error("`best_of` != 1 is not allowed for this endpoint")]
    BestOfDisabled,
    #[error("you must use sampling when `best_of` is > 1")]
    BestOfSampling,
    #[error("`seed` must not be set when `best_of` > 1")]
    BestOfSeed,
    #[error("`best_of` != 1 is not supported when streaming tokens")]
    BestOfStream,
    #[error("`top_n_tokens` must be >= 0 and <= {0}. Given: {1}")]
    TopNTokens(u32, u32),
    #[error("`top_n_tokens` != 0 is not allowed for this endpoint")]
    TopNTokensDisabled,
    #[error("`decoder_input_details` == true is not supported when streaming tokens")]
    PrefillDetailsStream,
    #[error("`temperature` must be strictly positive")]
    Temperature,
    #[error("`repetition_penalty` must be strictly positive")]
    RepetitionPenalty,
    #[error("`frequency_penalty` must be >= -2.0 and <= 2.0")]
    FrequencyPenalty,
    #[error("`top_p` must be > 0.0 and < 1.0")]
    TopP,
    #[error("`top_k` must be strictly positive")]
    TopK,
    #[error("`truncate` must be strictly positive and less than {0}. Given: {1}")]
    Truncate(usize, usize),
    #[error("`typical_p` must be > 0.0 and < 1.0")]
    TypicalP,
    #[error("one of `max_new_tokens` or `truncate` must be set if a fast tokenizer is not in use")]
    UnsetMaxNewTokens,
    #[error("`max_new_tokens` must be strictly positive")]
    NegativeMaxNewTokens,
    #[error("`max_new_tokens` must be <= {0}. Given: {1}")]
    MaxNewTokens(usize, u32),
    #[error("`inputs` tokens + `max_new_tokens` must be <= {0}. Given: {1} `inputs` tokens and {2} `max_new_tokens`")]
    MaxTotalTokens(usize, usize, u32),
    #[error("`inputs` must have less than {0} tokens. Given: {1}")]
    InputLength(usize, usize),
    #[error("`inputs` cannot be empty")]
    EmptyInput,
    #[error("`stop` supports up to {0} stop sequences. Given: {1}")]
    StopSequence(usize, usize),
    #[error("tokenizer error {0}")]
    Tokenizer(String),
    #[error("grammar is not supported")]
    Grammar,
    #[error("grammar is not valid: {0}")]
    InvalidGrammar(String),
    // The `#[from]` variants below let `?` convert image/base64/network errors
    #[error("base64 encoding is invalid: {0}")]
    InvalidBase64(#[from] base64::DecodeError),
    #[error("invalid image: {0}")]
    InvalidImage(#[from] image::ImageError),
    #[error("invalid integer: {0}")]
    InvalidInt(#[from] core::num::TryFromIntError),
    #[error("invalid image content: {0}")]
    InvalidImageContent(String),
    #[error("Could not fetch image: {0}")]
    FailedFetchImage(#[from] reqwest::Error),
}
632
633

#[cfg(test)]
634
mod tests {
635
    use super::*;
636
637
    use crate::default_parameters;
    use crate::tests::get_tokenizer;
638
639

    #[tokio::test]
640
    async fn test_validation_max_new_tokens() {
641
642
643
        let tokenizer = None;
        let max_best_of = 2;
        let max_stop_sequence = 3;
Nicolas Patry's avatar
Nicolas Patry committed
644
645
646
        let max_top_n_tokens = 4;
        let max_input_length = 5;
        let max_total_tokens = 6;
647
        let workers = 1;
drbh's avatar
drbh committed
648
        let disable_grammar_support = true;
649
        let config = None;
650
651
652
        let validation = Validation::new(
            workers,
            tokenizer,
653
            config,
654
655
            max_best_of,
            max_stop_sequence,
Nicolas Patry's avatar
Nicolas Patry committed
656
            max_top_n_tokens,
657
658
            max_input_length,
            max_total_tokens,
drbh's avatar
drbh committed
659
            disable_grammar_support,
660
        );
661
662

        let max_new_tokens = 10;
663
        match validation
664
            .validate_input("Hello".to_string(), None, Some(max_new_tokens))
665
666
            .await
        {
667
            Err(ValidationError::MaxNewTokens(1, 10)) => (),
668
            _ => panic!("Unexpected not max new tokens"),
669
670
671
672
        }
    }

    #[tokio::test]
673
    async fn test_validation_input_length() {
674
675
676
        let tokenizer = Some(get_tokenizer().await);
        let max_best_of = 2;
        let max_stop_sequence = 3;
Nicolas Patry's avatar
Nicolas Patry committed
677
678
679
        let max_top_n_tokens = 4;
        let max_input_length = 5;
        let max_total_tokens = 6;
drbh's avatar
drbh committed
680
        let disable_grammar_support = true;
681
        let workers = 1;
682
        let config = None;
683
684
685
        let validation = Validation::new(
            workers,
            tokenizer,
686
            config,
687
688
            max_best_of,
            max_stop_sequence,
Nicolas Patry's avatar
Nicolas Patry committed
689
            max_top_n_tokens,
690
691
            max_input_length,
            max_total_tokens,
drbh's avatar
drbh committed
692
            disable_grammar_support,
693
        );
694
695

        let max_new_tokens = 10;
696
        match validation
697
            .validate_input("Hello".to_string(), None, Some(max_new_tokens))
698
699
            .await
        {
Nicolas Patry's avatar
Nicolas Patry committed
700
            Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (),
701
            _ => panic!("Unexpected not max new tokens"),
702
703
        }
    }
704
705

    #[tokio::test]
706
    async fn test_validation_best_of_sampling() {
707
708
709
        let tokenizer = Some(get_tokenizer().await);
        let max_best_of = 2;
        let max_stop_sequence = 3;
Nicolas Patry's avatar
Nicolas Patry committed
710
711
712
        let max_top_n_tokens = 4;
        let max_input_length = 5;
        let max_total_tokens = 6;
713
        let workers = 1;
drbh's avatar
drbh committed
714
        let disable_grammar_support = true;
715
        let config = None;
716
717
718
        let validation = Validation::new(
            workers,
            tokenizer,
719
            config,
720
721
            max_best_of,
            max_stop_sequence,
Nicolas Patry's avatar
Nicolas Patry committed
722
            max_top_n_tokens,
723
724
            max_input_length,
            max_total_tokens,
drbh's avatar
drbh committed
725
            disable_grammar_support,
726
727
728
729
730
731
732
733
734
735
736
737
        );
        match validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    best_of: Some(2),
                    do_sample: false,
                    ..default_parameters()
                },
            })
            .await
        {
738
            Err(ValidationError::BestOfSampling) => (),
739
            _ => panic!("Unexpected not best of sampling"),
740
741
742
743
        }
    }

    #[tokio::test]
744
    async fn test_validation_top_p() {
745
746
747
        let tokenizer = Some(get_tokenizer().await);
        let max_best_of = 2;
        let max_stop_sequence = 3;
Nicolas Patry's avatar
Nicolas Patry committed
748
749
        let max_top_n_tokens = 4;
        let max_input_length = 5;
750
        let max_total_tokens = 106;
751
        let workers = 1;
drbh's avatar
drbh committed
752
        let disable_grammar_support = true;
753
        let config = None;
754
755
756
        let validation = Validation::new(
            workers,
            tokenizer,
757
            config,
758
759
            max_best_of,
            max_stop_sequence,
Nicolas Patry's avatar
Nicolas Patry committed
760
            max_top_n_tokens,
761
762
            max_input_length,
            max_total_tokens,
drbh's avatar
drbh committed
763
            disable_grammar_support,
764
765
766
767
768
769
        );
        match validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_p: Some(1.0),
770
                    max_new_tokens: Some(5),
771
772
773
774
775
                    ..default_parameters()
                },
            })
            .await
        {
776
            Err(ValidationError::TopP) => (),
777
            _ => panic!("Unexpected top_p"),
778
779
        }

780
781
782
783
784
        match validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_p: Some(0.99),
785
                    max_new_tokens: Some(5),
786
787
788
789
790
                    ..default_parameters()
                },
            })
            .await
        {
791
            Ok(_) => (),
792
            _ => panic!("Unexpected top_p error"),
793
794
        }

795
796
797
798
799
        let valid_request = validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_p: None,
800
                    max_new_tokens: Some(5),
801
802
803
804
805
                    ..default_parameters()
                },
            })
            .await
            .unwrap();
806
807
808
        // top_p == 1.0 is invalid for users to ask for but it's the default resolved value.
        assert_eq!(valid_request.parameters.top_p, 1.0);
    }
Nicolas Patry's avatar
Nicolas Patry committed
809
810
811
812
813
814
815
816

    #[tokio::test]
    async fn test_validation_top_n_tokens() {
        let tokenizer = Some(get_tokenizer().await);
        let max_best_of = 2;
        let max_stop_sequences = 3;
        let max_top_n_tokens = 4;
        let max_input_length = 5;
817
        let max_total_tokens = 106;
Nicolas Patry's avatar
Nicolas Patry committed
818
        let workers = 1;
drbh's avatar
drbh committed
819
        let disable_grammar_support = true;
820
        let config = None;
Nicolas Patry's avatar
Nicolas Patry committed
821
822
823
        let validation = Validation::new(
            workers,
            tokenizer,
824
            config,
Nicolas Patry's avatar
Nicolas Patry committed
825
826
827
828
829
            max_best_of,
            max_stop_sequences,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
drbh's avatar
drbh committed
830
            disable_grammar_support,
Nicolas Patry's avatar
Nicolas Patry committed
831
832
833
834
835
836
        );
        match validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_n_tokens: Some(5),
837
                    max_new_tokens: Some(5),
Nicolas Patry's avatar
Nicolas Patry committed
838
839
840
841
842
843
844
845
846
847
848
849
850
851
                    ..default_parameters()
                },
            })
            .await
        {
            Err(ValidationError::TopNTokens(4, 5)) => (),
            _ => panic!("Unexpected top_n_tokens"),
        }

        validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_n_tokens: Some(4),
852
                    max_new_tokens: Some(5),
Nicolas Patry's avatar
Nicolas Patry committed
853
854
855
856
857
858
859
860
861
862
863
                    ..default_parameters()
                },
            })
            .await
            .unwrap();

        validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_n_tokens: Some(0),
864
                    max_new_tokens: Some(5),
Nicolas Patry's avatar
Nicolas Patry committed
865
866
867
868
869
870
871
872
873
874
875
                    ..default_parameters()
                },
            })
            .await
            .unwrap();

        let valid_request = validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_n_tokens: None,
876
                    max_new_tokens: Some(5),
Nicolas Patry's avatar
Nicolas Patry committed
877
878
879
880
881
882
883
884
                    ..default_parameters()
                },
            })
            .await
            .unwrap();

        assert_eq!(valid_request.top_n_tokens, 0);
    }
885
}