/// Payload validation logic
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
use crate::{GenerateParameters, GenerateRequest, GrammarType};
use jsonschema::{Draft, JSONSchema};
use rand::{thread_rng, Rng};
use serde_json::Value;
use text_generation_client::{
    GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters,
};
use thiserror::Error;
use tokenizers::tokenizer::Tokenizer;
use tokenizers::TruncationDirection;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tracing::{instrument, Span};
use {once_cell::sync::Lazy, regex::Regex};

/// Validation
#[derive(Debug, Clone)]
pub struct Validation {
    /// Validation parameters
    max_best_of: usize,
    max_stop_sequences: usize,
    max_top_n_tokens: u32,
    max_input_length: usize,
    max_total_tokens: usize,
    disable_grammar_support: bool,
    /// Channel to communicate with the background tokenization task
    sender: Option<mpsc::UnboundedSender<TokenizerRequest>>,
}

impl Validation {
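    // Construction sketch (illustrative values, mirroring the unit tests below);
    // passing `tokenizer: None` skips spawning the background tokenization workers:
    //
    //     let validation = Validation::new(
    //         1,    // workers
    //         None, // tokenizer
    //         2,    // max_best_of
    //         3,    // max_stop_sequences
    //         4,    // max_top_n_tokens
    //         5,    // max_input_length
    //         6,    // max_total_tokens
    //         true, // disable_grammar_support
    //     );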
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn new(
        workers: usize,
        tokenizer: Option<Tokenizer>,
        max_best_of: usize,
        max_stop_sequences: usize,
        max_top_n_tokens: u32,
        max_input_length: usize,
        max_total_tokens: usize,
        disable_grammar_support: bool,
    ) -> Self {
        // If we have a fast tokenizer
        let sender = if let Some(tokenizer) = tokenizer {
            // Create round robin channel
            let (validation_sender, validation_round_robin_receiver) = mpsc::unbounded_channel();
            let mut senders = Vec::with_capacity(workers);

            // Create workers
            for _ in 0..workers {
                let tokenizer_clone = tokenizer.clone();
                let (tokenizer_sender, tokenizer_receiver) = mpsc::unbounded_channel();
                senders.push(tokenizer_sender);

                // Spawn worker
                tokio::task::spawn_blocking(move || {
                    tokenizer_worker(tokenizer_clone, tokenizer_receiver)
                });
            }

            // Create tokenization round robin task
            tokio::spawn(round_robin_task(validation_round_robin_receiver, senders));

            Some(validation_sender)
        } else {
            None
        };

        Self {
            max_best_of,
            sender,
            max_stop_sequences,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
            disable_grammar_support,
        }
    }

    #[instrument(skip(self, inputs))]
    pub async fn tokenize(
        &self,
        inputs: String,
        truncate: Option<usize>,
    ) -> Result<Option<(tokenizers::Encoding, String)>, ValidationError> {
        // If we have a fast tokenizer
        if let Some(sender) = &self.sender {
            // Create response channel
            let (response_sender, response_receiver) = oneshot::channel();
            // Send request to the background validation task
            // Unwrap is safe here
            sender
                .send(((inputs, truncate), response_sender, Span::current()))
                .unwrap();

            // Await on response channel
            // Unwrap is safe here
            let encoding = response_receiver.await.unwrap()?;
            Ok(Some(encoding))
        } else {
            Ok(None)
        }
    }

    #[instrument(skip(self, inputs))]
    async fn validate_input(
        &self,
        inputs: String,
        truncate: Option<usize>,
        max_new_tokens: Option<u32>,
    ) -> Result<(String, usize, u32), ValidationError> {
        // If we have a fast tokenizer
        if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
            // Create response channel
            let input_length = encoding.len();

            // Get total tokens
            let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
                max_new_tokens
            } else {
                self.max_total_tokens.saturating_sub(input_length) as u32
            };
            let total_tokens = input_length + max_new_tokens as usize;

            // Validate MaxTotalTokens
            if total_tokens > self.max_total_tokens {
                return Err(ValidationError::MaxTotalTokens(
                    self.max_total_tokens,
                    input_length,
                    max_new_tokens,
                ));
            }

            // Validate InputLength
            if input_length > self.max_input_length {
                return Err(ValidationError::InputLength(
                    self.max_input_length,
                    input_length,
                ));
            }

            metrics::histogram!("tgi_request_input_length", input_length as f64);
            Ok((inputs, input_length, max_new_tokens))
        }
        // Return inputs without validation
        else {
            // In this case, we don't know the real length in tokens of the inputs
            // However, the inputs will be truncated by the python servers
            // We make sure that truncate + max_new_tokens <= self.max_total_tokens
            let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
                max_new_tokens
            } else if let Some(truncate) = truncate {
                self.max_total_tokens.saturating_sub(truncate) as u32
            } else {
                return Err(ValidationError::UnsetMaxNewTokens);
            };
            let input_length = truncate.unwrap_or(self.max_input_length);

            // Validate MaxNewTokens
            if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
                return Err(ValidationError::MaxNewTokens(
                    self.max_total_tokens - self.max_input_length,
                    max_new_tokens,
                ));
            }

            Ok((inputs, input_length, max_new_tokens))
        }
    }
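
    // Worked example of the token budget above (illustrative numbers): with
    // max_total_tokens = 6 and a 1-token input, max_new_tokens = 10 is rejected
    // because 1 + 10 > 6, while an unset max_new_tokens resolves to
    // 6.saturating_sub(1) = 5 and passes.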

    /// Validate a payload and get the number of tokens in the input
    #[instrument(skip_all)]
    pub(crate) async fn validate(
        &self,
        request: GenerateRequest,
    ) -> Result<ValidGenerateRequest, ValidationError> {
        let GenerateParameters {
            best_of,
            temperature,
            repetition_penalty,
            frequency_penalty,
            top_k,
            top_p,
            typical_p,
            do_sample,
            max_new_tokens,
            stop: stop_sequences,
            truncate,
            seed,
            watermark,
            decoder_input_details,
            top_n_tokens,
            grammar,
            ..
        } = request.parameters;

        // sampling must be true when best_of > 1
        let best_of = best_of.unwrap_or(1);
        let sampling = do_sample
            || temperature.is_some()
            || top_k.is_some()
            || top_p.is_some()
            || typical_p.is_some();

        if best_of > 1 && !sampling {
            return Err(BestOfSampling);
        }

        let temperature = temperature.unwrap_or(1.0);
        if temperature <= 0.0 {
            return Err(ValidationError::Temperature);
        }

        let repetition_penalty = repetition_penalty.unwrap_or(1.0);
        if repetition_penalty <= 0.0 {
            return Err(ValidationError::RepetitionPenalty);
        }

        let frequency_penalty = frequency_penalty.unwrap_or(0.0);
        if !(-2.0..=2.0).contains(&frequency_penalty) {
            return Err(ValidationError::FrequencyPenalty);
        }

        // Different because the proto default value is not a valid value
        // for the user
        let top_p = top_p
            .map(|value| {
                if value <= 0.0 || value >= 1.0 {
                    return Err(ValidationError::TopP);
                }
                Ok(value)
            })
            .unwrap_or(Ok(1.0))?;

        let typical_p = typical_p
            .map(|value| {
                if value <= 0.0 || value >= 1.0 {
                    return Err(ValidationError::TypicalP);
                }
                Ok(value)
            })
            .unwrap_or(Ok(1.0))?;

        let top_k: u32 = top_k
            .map(|value| {
                if value <= 0 {
                    return Err(ValidationError::TopK);
                }
                Ok(value as u32)
            })
            .unwrap_or(Ok(0))?;

        if max_new_tokens == Some(0) {
            return Err(ValidationError::NegativeMaxNewTokens);
        }

        if stop_sequences.len() > self.max_stop_sequences {
            return Err(ValidationError::StopSequence(
                self.max_stop_sequences,
                stop_sequences.len(),
            ));
        }

        // If seed is None, assign a random one
        let seed = match seed {
            None => thread_rng().gen(),
            Some(seed) => {
                if best_of > 1 {
                    return Err(BestOfSeed);
                }
                seed
            }
        };

        let top_n_tokens = top_n_tokens
            .map(|value| {
                if value > self.max_top_n_tokens {
                    return Err(ValidationError::TopNTokens(self.max_top_n_tokens, value));
                }
                Ok(value)
            })
            .unwrap_or(Ok(0))?;

        // Check if inputs is empty
        if request.inputs.is_empty() {
            return Err(EmptyInput);
        }

        // Check if truncate is strictly positive and less than max_input_length
        let truncate = truncate
            .map(|value| {
                if value == 0 || value > self.max_input_length {
                    return Err(ValidationError::Truncate(self.max_input_length, value));
                }
                Ok(Some(value))
            })
            .unwrap_or(Ok(None))?;

        // Validate inputs
        let (inputs, input_length, max_new_tokens) = self
            .validate_input(request.inputs, truncate, max_new_tokens)
            .await?;

        // TODO: we should build the FSM here and pass the compiled FSM instead of the grammar
        // NOTE: this is currently difficult because we need the tokenizer in Python to build
        // the FSM and we'd have to load a copy of the tokenizer into our Pyo3 instance which
        // may be slow and memory intensive. Best case is to have a Rust implementation of the FSM
        // compiler and use that to build the FSM here.
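
        // Illustrative inputs (a sketch, shapes assumed from the match below):
        // `GrammarType::Json` may carry either a parsed object or a string that
        // itself contains JSON, e.g. with serde_json's `json!` macro:
        //   GrammarType::Json(json!({"type": "object"}))
        //   GrammarType::Json(json!("{\"type\": \"object\"}"))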

        // Validate grammar and unpack the grammar and type for the proto message
        let (grammar, grammar_type) = match grammar {
            Some(grammar) => {
                // Ensure that grammar is not set if it's not supported
                if self.disable_grammar_support {
                    return Err(ValidationError::Grammar);
                }
                match grammar {
                    GrammarType::Json(json) => {
                        let json = match json {
                            // if value is a string, we need to parse it again to make sure its
                            // a valid json
                            Value::String(s) => serde_json::from_str(&s)
                                .map_err(|e| ValidationError::InvalidGrammar(e.to_string())),
                            Value::Object(_) => Ok(json),
                            _ => Err(ValidationError::Grammar),
                        }?;

                        // Check if the json is a valid JSONSchema
                        JSONSchema::options()
                            .with_draft(Draft::Draft202012)
                            .compile(&json)
                            .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;

                        (
                            // Serialize json to string
                            serde_json::to_string(&json)
                                .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?,
                            ProtoGrammarType::Json.into(),
                        )
                    }
                    GrammarType::Regex(regex) => (regex, ProtoGrammarType::Regex.into()),
                }
            }
            None => (String::new(), ProtoGrammarType::None.into()),
        };

        let parameters = NextTokenChooserParameters {
            temperature,
            repetition_penalty,
            frequency_penalty,
            top_k,
            top_p,
            typical_p,
            do_sample,
            seed,
            watermark,
            grammar,
            grammar_type,
        };
        let stopping_parameters = StoppingCriteriaParameters {
            max_new_tokens,
            stop_sequences,
            ignore_eos_token: false,
        };

        metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64);

        Ok(ValidGenerateRequest {
            inputs,
            decoder_input_details,
            input_length: input_length as u32,
            truncate: truncate.unwrap_or(self.max_input_length) as u32,
            parameters,
            stopping_parameters,
            top_n_tokens,
        })
    }

    /// Validate the best_of parameter
    #[instrument(skip_all)]
    pub(crate) fn validate_best_of(&self, best_of: usize) -> Result<usize, ValidationError> {
        if self.max_best_of == 1 && best_of != 1 {
            return Err(ValidationError::BestOfDisabled);
        }

        if best_of > self.max_best_of {
            return Err(ValidationError::BestOf(self.max_best_of, best_of));
        }

        Ok(best_of)
    }
}

/// Round robin tokenization task
async fn round_robin_task(
    mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
    senders: Vec<mpsc::UnboundedSender<TokenizerRequest>>,
) {
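    // Hand one request to each worker sender in turn; `recv()` returning `None`
    // means every `Validation` handle was dropped, so the task exits.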
    loop {
        for sender in &senders {
            match receiver.recv().await {
                None => return,
                Some(request) => sender.send(request).unwrap(),
            };
        }
    }
}

/// Tokenization worker: encodes each request and replies on its oneshot channel
fn tokenizer_worker(tokenizer: Tokenizer, mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>) {
    let is_multimodal = {
        let vocab = tokenizer.get_vocab(true);
        vocab.contains_key("<image>")
    };
    // Loop over requests
    while let Some(((inputs, truncate), response_tx, parent_span)) = receiver.blocking_recv() {
        parent_span.in_scope(|| {
            response_tx
                .send(prepare_input(inputs, truncate, &tokenizer, is_multimodal))
                .unwrap_or(())
        })
    }
}

/// Get input length and optionally truncate it
fn prepare_input(
    mut inputs: String,
    truncate: Option<usize>,
    tokenizer: &Tokenizer,
    is_multimodal: bool,
) -> Result<(tokenizers::Encoding, String), ValidationError> {
    // Replace inline markdown images (`![](...)`) with the `<image>` placeholder
    // so multimodal prompts are tokenized at their real length
    let simplified_query = if is_multimodal {
        static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
        RE.replace_all(&inputs, "<image>").into()
    } else {
        inputs.clone()
    };
    // Get the number of tokens in the input
    let mut encoding = tokenizer
        .encode(simplified_query, true)
        .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;

    // Optionally truncate
    if let Some(truncate) = truncate {
        if truncate < encoding.len() && !is_multimodal {
            encoding.truncate(truncate, 0, TruncationDirection::Left);
            inputs = tokenizer
                .decode(encoding.get_ids(), false)
                .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
        }
    }

    Ok((encoding, inputs))
}

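/// A tokenization request: the raw `(inputs, truncate)` pair, a oneshot sender
/// for the `(encoding, possibly truncated inputs)` result, and the caller's
/// `Span` so worker-side tracing stays attached to the originating request.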
type TokenizerRequest = (
    (String, Option<usize>),
    oneshot::Sender<Result<(tokenizers::Encoding, String), ValidationError>>,
    Span,
);

#[derive(Debug, Clone)]
pub(crate) struct ValidGenerateRequest {
    pub inputs: String,
    pub input_length: u32,
    pub truncate: u32,
    pub decoder_input_details: bool,
    pub parameters: NextTokenChooserParameters,
    pub stopping_parameters: StoppingCriteriaParameters,
    pub top_n_tokens: u32,
}

#[derive(Error, Debug)]
pub enum ValidationError {
    #[error("`best_of` must be > 0 and <= {0}. Given: {1}")]
    BestOf(usize, usize),
    #[error("`best_of` != 1 is not allowed for this endpoint")]
    BestOfDisabled,
    #[error("you must use sampling when `best_of` is > 1")]
    BestOfSampling,
    #[error("`seed` must not be set when `best_of` > 1")]
    BestOfSeed,
    #[error("`best_of` != 1 is not supported when streaming tokens")]
    BestOfStream,
    #[error("`top_n_tokens` must be >= 0 and <= {0}. Given: {1}")]
    TopNTokens(u32, u32),
    #[error("`top_n_tokens` != 0 is not allowed for this endpoint")]
    TopNTokensDisabled,
    #[error("`decoder_input_details` == true is not supported when streaming tokens")]
    PrefillDetailsStream,
    #[error("`temperature` must be strictly positive")]
    Temperature,
    #[error("`repetition_penalty` must be strictly positive")]
    RepetitionPenalty,
    #[error("`frequency_penalty` must be >= -2.0 and <= 2.0")]
    FrequencyPenalty,
    #[error("`top_p` must be > 0.0 and < 1.0")]
    TopP,
    #[error("`top_k` must be strictly positive")]
    TopK,
    #[error("`truncate` must be strictly positive and less than {0}. Given: {1}")]
    Truncate(usize, usize),
    #[error("`typical_p` must be > 0.0 and < 1.0")]
    TypicalP,
    #[error("one of `max_new_tokens` or `truncate` must be set if a fast tokenizer is not in use")]
    UnsetMaxNewTokens,
    #[error("`max_new_tokens` must be strictly positive")]
    NegativeMaxNewTokens,
    #[error("`max_new_tokens` must be <= {0}. Given: {1}")]
    MaxNewTokens(usize, u32),
    #[error("`inputs` tokens + `max_new_tokens` must be <= {0}. Given: {1} `inputs` tokens and {2} `max_new_tokens`")]
    MaxTotalTokens(usize, usize, u32),
    #[error("`inputs` must have less than {0} tokens. Given: {1}")]
    InputLength(usize, usize),
    #[error("`inputs` cannot be empty")]
    EmptyInput,
    #[error("`stop` supports up to {0} stop sequences. Given: {1}")]
    StopSequence(usize, usize),
    #[error("tokenizer error {0}")]
    Tokenizer(String),
    #[error("grammar is not supported")]
    Grammar,
    #[error("grammar is not valid: {0}")]
    InvalidGrammar(String),
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::default_parameters;
    use crate::tests::get_tokenizer;

    #[tokio::test]
    async fn test_validation_max_new_tokens() {
        let tokenizer = None;
        let max_best_of = 2;
        let max_stop_sequence = 3;
        let max_top_n_tokens = 4;
        let max_input_length = 5;
        let max_total_tokens = 6;
        let workers = 1;
        let disable_grammar_support = true;
        let validation = Validation::new(
            workers,
            tokenizer,
            max_best_of,
            max_stop_sequence,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
            disable_grammar_support,
        );

        let max_new_tokens = 10;
        match validation
            .validate_input("Hello".to_string(), None, Some(max_new_tokens))
            .await
        {
            Err(ValidationError::MaxNewTokens(1, 10)) => (),
            _ => panic!("expected a MaxNewTokens error"),
        }
    }
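
    // Added sketch, not part of the original suite: without a fast tokenizer,
    // `validate_input` falls back to `truncate` as the assumed input length and
    // derives an unset `max_new_tokens` as `max_total_tokens - truncate`.
    #[tokio::test]
    async fn test_validation_truncate_fallback() {
        let validation = Validation::new(1, None, 2, 3, 4, 5, 6, true);
        match validation
            .validate_input("Hello".to_string(), Some(3), None)
            .await
        {
            // input_length == truncate == 3, max_new_tokens == 6 - 3 == 3
            Ok((_, 3, 3)) => (),
            _ => panic!("expected the truncate fallback to bound token counts"),
        }
    }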

    #[tokio::test]
    async fn test_validation_input_length() {
        let tokenizer = Some(get_tokenizer().await);
        let max_best_of = 2;
        let max_stop_sequence = 3;
        let max_top_n_tokens = 4;
        let max_input_length = 5;
        let max_total_tokens = 6;
        let disable_grammar_support = true;
        let workers = 1;
        let validation = Validation::new(
            workers,
            tokenizer,
            max_best_of,
            max_stop_sequence,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
            disable_grammar_support,
        );

        let max_new_tokens = 10;
        match validation
            .validate_input("Hello".to_string(), None, Some(max_new_tokens))
            .await
        {
            Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (),
            _ => panic!("expected a MaxTotalTokens error"),
        }
    }

    #[tokio::test]
    async fn test_validation_best_of_sampling() {
        let tokenizer = Some(get_tokenizer().await);
        let max_best_of = 2;
        let max_stop_sequence = 3;
        let max_top_n_tokens = 4;
        let max_input_length = 5;
        let max_total_tokens = 6;
        let workers = 1;
        let disable_grammar_support = true;
        let validation = Validation::new(
            workers,
            tokenizer,
            max_best_of,
            max_stop_sequence,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
            disable_grammar_support,
        );
        match validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    best_of: Some(2),
                    do_sample: false,
                    ..default_parameters()
                },
            })
            .await
        {
            Err(ValidationError::BestOfSampling) => (),
            _ => panic!("expected a BestOfSampling error"),
        }
    }

    #[tokio::test]
    async fn test_validation_top_p() {
        let tokenizer = Some(get_tokenizer().await);
        let max_best_of = 2;
        let max_stop_sequence = 3;
        let max_top_n_tokens = 4;
        let max_input_length = 5;
        let max_total_tokens = 106;
        let workers = 1;
        let disable_grammar_support = true;
        let validation = Validation::new(
            workers,
            tokenizer,
            max_best_of,
            max_stop_sequence,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
            disable_grammar_support,
        );
        match validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_p: Some(1.0),
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
        {
            Err(ValidationError::TopP) => (),
            _ => panic!("expected a TopP error"),
        }

        match validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_p: Some(0.99),
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
        {
            Ok(_) => (),
            _ => panic!("expected top_p = 0.99 to be accepted"),
        }

        let valid_request = validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_p: None,
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
            .unwrap();
        // Users may not request top_p == 1.0 explicitly, but it is the default resolved value.
        assert_eq!(valid_request.parameters.top_p, 1.0);
    }
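
    // Added sketch, not part of the original suite: `frequency_penalty` must lie
    // in [-2.0, 2.0] (assuming the field is an `Option` like the others here).
    #[tokio::test]
    async fn test_validation_frequency_penalty() {
        let tokenizer = Some(get_tokenizer().await);
        let validation = Validation::new(1, tokenizer, 2, 3, 4, 5, 106, true);
        match validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    frequency_penalty: Some(2.5),
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
        {
            Err(ValidationError::FrequencyPenalty) => (),
            _ => panic!("expected a FrequencyPenalty error"),
        }
    }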

    #[tokio::test]
    async fn test_validation_top_n_tokens() {
        let tokenizer = Some(get_tokenizer().await);
        let max_best_of = 2;
        let max_stop_sequences = 3;
        let max_top_n_tokens = 4;
        let max_input_length = 5;
        let max_total_tokens = 106;
        let workers = 1;
        let disable_grammar_support = true;
        let validation = Validation::new(
            workers,
            tokenizer,
            max_best_of,
            max_stop_sequences,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
            disable_grammar_support,
        );
        match validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_n_tokens: Some(5),
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
        {
            Err(ValidationError::TopNTokens(4, 5)) => (),
            _ => panic!("expected a TopNTokens error"),
        }

        validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_n_tokens: Some(4),
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
            .unwrap();

        validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_n_tokens: Some(0),
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
            .unwrap();

        let valid_request = validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_n_tokens: None,
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
            .unwrap();

        assert_eq!(valid_request.top_n_tokens, 0);
    }
}