validation.rs 23.6 KB
Newer Older
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
1
/// Payload validation logic
2
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
drbh's avatar
drbh committed
3
use crate::{GenerateParameters, GenerateRequest, GrammarType};
4
use rand::{thread_rng, Rng};
drbh's avatar
drbh committed
5
6
7
use text_generation_client::{
    GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters,
};
Olivier Dehaene's avatar
Olivier Dehaene committed
8
use thiserror::Error;
Olivier Dehaene's avatar
Olivier Dehaene committed
9
use tokenizers::tokenizer::Tokenizer;
10
use tokenizers::TruncationDirection;
OlivierDehaene's avatar
OlivierDehaene committed
11
use tokio::sync::mpsc;
12
use tokio::sync::oneshot;
13
use tracing::{instrument, Span};
Olivier Dehaene's avatar
Olivier Dehaene committed
14

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
15
/// Validation
///
/// Shared request validator: holds the server-configured limits and, when a
/// fast tokenizer is available, the entry channel to the background
/// tokenization workers.
#[derive(Debug, Clone)]
pub struct Validation {
    /// Validation parameters
    /// Maximum allowed `best_of` value
    max_best_of: usize,
    /// Maximum number of stop sequences per request
    max_stop_sequences: usize,
    /// Maximum `top_n_tokens` a request may ask for
    max_top_n_tokens: u32,
    /// Maximum input length in tokens
    max_input_length: usize,
    /// Maximum input + generated tokens per request
    max_total_tokens: usize,
    // When true, requests carrying a grammar are rejected
    disable_grammar_support: bool,
    /// Channel to communicate with the background tokenization task
    /// (`None` when no fast tokenizer was provided)
    sender: Option<mpsc::UnboundedSender<TokenizerRequest>>,
}

impl Validation {
OlivierDehaene's avatar
OlivierDehaene committed
30
    #[allow(clippy::too_many_arguments)]
31
32
    pub(crate) fn new(
        workers: usize,
33
        tokenizer: Option<Tokenizer>,
34
        max_best_of: usize,
35
        max_stop_sequences: usize,
Nicolas Patry's avatar
Nicolas Patry committed
36
        max_top_n_tokens: u32,
37
38
        max_input_length: usize,
        max_total_tokens: usize,
drbh's avatar
drbh committed
39
        disable_grammar_support: bool,
40
    ) -> Self {
41
42
        // If we have a fast tokenizer
        let sender = if let Some(tokenizer) = tokenizer {
OlivierDehaene's avatar
OlivierDehaene committed
43
44
45
            // Create round robin channel
            let (validation_sender, validation_round_robin_receiver) = mpsc::unbounded_channel();
            let mut senders = Vec::with_capacity(workers);
46
47
48
49

            // Create workers
            for _ in 0..workers {
                let tokenizer_clone = tokenizer.clone();
OlivierDehaene's avatar
OlivierDehaene committed
50
51
                let (tokenizer_sender, tokenizer_receiver) = mpsc::unbounded_channel();
                senders.push(tokenizer_sender);
52
53
54

                // Spawn worker
                tokio::task::spawn_blocking(move || {
OlivierDehaene's avatar
OlivierDehaene committed
55
                    tokenizer_worker(tokenizer_clone, tokenizer_receiver)
56
57
                });
            }
OlivierDehaene's avatar
OlivierDehaene committed
58
59
60
61

            // Create tokenization round robin task
            tokio::spawn(round_robin_task(validation_round_robin_receiver, senders));

62
63
64
65
66
67
68
69
            Some(validation_sender)
        } else {
            None
        };

        Self {
            max_best_of,
            sender,
70
            max_stop_sequences,
Nicolas Patry's avatar
Nicolas Patry committed
71
            max_top_n_tokens,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
72
            max_input_length,
73
            max_total_tokens,
drbh's avatar
drbh committed
74
            disable_grammar_support,
75
76
        }
    }
Olivier Dehaene's avatar
Olivier Dehaene committed
77

78
    #[instrument(skip(self, inputs))]
79
    pub async fn tokenize(
80
81
82
        &self,
        inputs: String,
        truncate: Option<usize>,
83
    ) -> Result<Option<(tokenizers::Encoding, String)>, ValidationError> {
84
85
86
87
88
89
90
91
92
93
94
95
        // If we have a fast tokenizer
        if let Some(sender) = &self.sender {
            // Create response channel
            let (response_sender, response_receiver) = oneshot::channel();
            // Send request to the background validation task
            // Unwrap is safe here
            sender
                .send(((inputs, truncate), response_sender, Span::current()))
                .unwrap();

            // Await on response channel
            // Unwrap is safe here
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
            let encoding = response_receiver.await.unwrap()?;
            Ok(Some(encoding))
        } else {
            Ok(None)
        }
    }

    /// Validate the raw input string and resolve the `max_new_tokens` budget.
    ///
    /// Returns the (possibly truncated) inputs, the input length in tokens
    /// (or a pessimistic estimate when no fast tokenizer is available) and
    /// the resolved `max_new_tokens`.
    #[instrument(skip(self, inputs))]
    async fn validate_input(
        &self,
        inputs: String,
        truncate: Option<usize>,
        max_new_tokens: Option<u32>,
    ) -> Result<(String, usize, u32), ValidationError> {
        // If we have a fast tokenizer
        if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
            // Exact input length in tokens
            let input_length = encoding.len();

            // Default max_new_tokens to whatever budget remains below
            // max_total_tokens when the user did not set it
            let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
                max_new_tokens
            } else {
                self.max_total_tokens.saturating_sub(input_length) as u32
            };
            let total_tokens = input_length + max_new_tokens as usize;

            // Validate MaxTotalTokens
            if total_tokens > self.max_total_tokens {
                return Err(ValidationError::MaxTotalTokens(
                    self.max_total_tokens,
                    input_length,
                    max_new_tokens,
                ));
            }

            // Validate InputLength
            if input_length > self.max_input_length {
                return Err(ValidationError::InputLength(
                    self.max_input_length,
                    input_length,
                ));
            }

            metrics::histogram!("tgi_request_input_length", input_length as f64);
            Ok((inputs, input_length, max_new_tokens))
        }
        // Return inputs without validation
        else {
            // In this case, we don't know the real length in tokens of the inputs
            // However, the inputs will be truncated by the python servers
            // We make sure that truncate + max_new_tokens <= self.max_total_tokens
            let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
                max_new_tokens
            } else if let Some(truncate) = truncate {
                self.max_total_tokens.saturating_sub(truncate) as u32
            } else {
                // Without a tokenizer we cannot derive a sensible default
                return Err(ValidationError::UnsetMaxNewTokens);
            };
            // Pessimistic estimate: assume the input is as long as allowed
            let input_length = truncate.unwrap_or(self.max_input_length);

            // Validate MaxNewTokens
            if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
                return Err(ValidationError::MaxNewTokens(
                    self.max_total_tokens - self.max_input_length,
                    max_new_tokens,
                ));
            }

            Ok((inputs, input_length, max_new_tokens))
        }
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
169
    /// Validate a payload and get the number of tokens in the input
    ///
    /// Checks every user-supplied sampling/stopping parameter against the
    /// configured limits, fills in defaults, and resolves the grammar into
    /// its proto representation. Returns a request ready for the backend.
    #[instrument(skip_all)]
    pub(crate) async fn validate(
        &self,
        request: GenerateRequest,
    ) -> Result<ValidGenerateRequest, ValidationError> {
        let GenerateParameters {
            best_of,
            temperature,
            repetition_penalty,
            frequency_penalty,
            top_k,
            top_p,
            typical_p,
            do_sample,
            max_new_tokens,
            stop: stop_sequences,
            truncate,
            seed,
            watermark,
            decoder_input_details,
            top_n_tokens,
            grammar,
            ..
        } = request.parameters;

        // sampling must be true when best_of > 1: setting any sampling knob
        // (or do_sample) counts as opting into sampling
        let best_of = best_of.unwrap_or(1);
        let sampling = do_sample
            || temperature.is_some()
            || top_k.is_some()
            || top_p.is_some()
            || typical_p.is_some();

        if best_of > 1 && !sampling {
            return Err(BestOfSampling);
        }

        let temperature = temperature.unwrap_or(1.0);
        if temperature <= 0.0 {
            return Err(ValidationError::Temperature);
        }

        let repetition_penalty = repetition_penalty.unwrap_or(1.0);
        if repetition_penalty <= 0.0 {
            return Err(ValidationError::RepetitionPenalty);
        }

        // frequency_penalty must lie within [-2.0, 2.0]
        let frequency_penalty = frequency_penalty.unwrap_or(0.0);
        if !(-2.0..=2.0).contains(&frequency_penalty) {
            return Err(ValidationError::FrequencyPenalty);
        }

        // Different because the proto default value is not a valid value
        // for the user
        let top_p = top_p
            .map(|value| {
                if value <= 0.0 || value >= 1.0 {
                    return Err(ValidationError::TopP);
                }
                Ok(value)
            })
            .unwrap_or(Ok(1.0))?;

        let typical_p = typical_p
            .map(|value| {
                if value <= 0.0 || value >= 1.0 {
                    return Err(ValidationError::TypicalP);
                }
                Ok(value)
            })
            .unwrap_or(Ok(1.0))?;

        // 0 is the proto sentinel for "disabled"
        let top_k: u32 = top_k
            .map(|value| {
                if value <= 0 {
                    return Err(ValidationError::TopK);
                }
                Ok(value as u32)
            })
            .unwrap_or(Ok(0))?;

        // max_new_tokens may be omitted (defaulted later) but never zero
        if max_new_tokens == Some(0) {
            return Err(ValidationError::NegativeMaxNewTokens);
        }

        if stop_sequences.len() > self.max_stop_sequences {
            return Err(ValidationError::StopSequence(
                self.max_stop_sequences,
                stop_sequences.len(),
            ));
        }

        // If seed is None, assign a random one
        let seed = match seed {
            None => thread_rng().gen(),
            Some(seed) => {
                // A fixed seed would make all best_of candidates identical
                if best_of > 1 {
                    return Err(BestOfSeed);
                }
                seed
            }
        };

        let top_n_tokens = top_n_tokens
            .map(|value| {
                if value > self.max_top_n_tokens {
                    return Err(ValidationError::TopNTokens(self.max_top_n_tokens, value));
                }
                Ok(value)
            })
            .unwrap_or(Ok(0))?;

        // Check if inputs is empty
        if request.inputs.is_empty() {
            return Err(EmptyInput);
        }

        // Check if truncate is strictly positive and less than max_input_length
        let truncate = truncate
            .map(|value| {
                if value == 0 || value > self.max_input_length {
                    return Err(ValidationError::Truncate(self.max_input_length, value));
                }
                Ok(Some(value))
            })
            .unwrap_or(Ok(None))?;

        // Validate inputs
        let (inputs, input_length, max_new_tokens) = self
            .validate_input(request.inputs, truncate, max_new_tokens)
            .await?;

        // TODO: we should build the FSM here and pass the compiled FSM instead of the grammar
        // NOTE: this is currently difficult because we need the tokenizer in Python to build
        // the FSM and we'd have to load a copy of the tokenizer into our Pyo3 instance which
        // may be slow and memory intensive. Best case is to have a Rust implementation of the FSM
        // compiler and use that to build the FSM here.

        // Validate grammar and unpack the grammar and type for the proto message
        let (grammar, grammar_type) = match grammar {
            Some(grammar) => {
                // Ensure that grammar is not set if it's not supported
                if self.disable_grammar_support {
                    return Err(ValidationError::Grammar);
                }
                match grammar {
                    // currently both are handled the same way since compilation is done in Python
                    GrammarType::Json(json) => (json, ProtoGrammarType::Json.into()),
                    GrammarType::Regex(regex) => (regex, ProtoGrammarType::Regex.into()),
                }
            }
            None => (String::new(), ProtoGrammarType::None.into()),
        };

        let parameters = NextTokenChooserParameters {
            temperature,
            repetition_penalty,
            frequency_penalty,
            top_k,
            top_p,
            typical_p,
            do_sample,
            seed,
            watermark,
            grammar,
            grammar_type,
        };
        let stopping_parameters = StoppingCriteriaParameters {
            max_new_tokens,
            stop_sequences,
            ignore_eos_token: false,
        };

        metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64);

        Ok(ValidGenerateRequest {
            inputs,
            decoder_input_details,
            input_length: input_length as u32,
            truncate: truncate.unwrap_or(self.max_input_length) as u32,
            parameters,
            stopping_parameters,
            top_n_tokens,
        })
    }
355
356
357
358
359
360
361
362
363
364
365
366
367
368

    /// Check that `best_of` is allowed by the server configuration.
    ///
    /// Returns the value unchanged when valid.
    #[instrument(skip_all)]
    pub(crate) fn validate_best_of(&self, best_of: usize) -> Result<usize, ValidationError> {
        match best_of {
            // best_of != 1 is disabled outright when max_best_of == 1
            b if self.max_best_of == 1 && b != 1 => Err(ValidationError::BestOfDisabled),
            // Otherwise it simply must not exceed the configured maximum
            b if b > self.max_best_of => Err(ValidationError::BestOf(self.max_best_of, b)),
            b => Ok(b),
        }
    }
Olivier Dehaene's avatar
Olivier Dehaene committed
369
370
}

OlivierDehaene's avatar
OlivierDehaene committed
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
/// Round robin tokenization task
///
/// Receives requests on `receiver` and forwards each one to the next worker
/// channel in turn. Terminates once the entry channel is closed and drained.
async fn round_robin_task(
    mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
    senders: Vec<mpsc::UnboundedSender<TokenizerRequest>>,
) {
    // Guard against a mis-configured pool: with zero workers the
    // `loop { for … }` below would spin forever without ever awaiting,
    // pinning an executor thread. Nothing could be forwarded anyway.
    if senders.is_empty() {
        return;
    }
    loop {
        for sender in &senders {
            match receiver.recv().await {
                // Entry channel closed: all Validation clones dropped
                None => return,
                // Unwrap is safe: worker receivers live in spawn_blocking
                // tasks that only exit when this sender is dropped
                Some(request) => sender.send(request).unwrap(),
            };
        }
    }
}

386
/// Start tokenization workers
OlivierDehaene's avatar
OlivierDehaene committed
387
fn tokenizer_worker(tokenizer: Tokenizer, mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>) {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
388
    // Loop over requests
OlivierDehaene's avatar
OlivierDehaene committed
389
    while let Some(((inputs, truncate), response_tx, parent_span)) = receiver.blocking_recv() {
390
391
        parent_span.in_scope(|| {
            response_tx
392
                .send(prepare_input(inputs, truncate, &tokenizer))
393
394
                .unwrap_or(())
        })
395
396
    }
}
Olivier Dehaene's avatar
Olivier Dehaene committed
397

398
399
/// Get input length and optionally truncate it
fn prepare_input(
400
    mut inputs: String,
401
    truncate: Option<usize>,
402
    tokenizer: &Tokenizer,
403
) -> Result<(tokenizers::Encoding, String), ValidationError> {
404
405
    // Get the number of tokens in the input
    let mut encoding = tokenizer
406
        .encode(inputs.clone(), true)
407
408
        .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;

409
    // Optionally truncate
410
411
    if let Some(truncate) = truncate {
        if truncate < encoding.len() {
412
            encoding.truncate(truncate, 0, TruncationDirection::Left);
413
            inputs = tokenizer
414
                .decode(encoding.get_ids(), false)
415
416
                .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
        }
417
    }
418

419
    Ok((encoding, inputs))
Olivier Dehaene's avatar
Olivier Dehaene committed
420
}
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
421

422
423
/// Request sent to a tokenization worker:
/// ((raw input string, optional truncation length in tokens),
///  one-shot channel for the encoding + possibly-truncated string,
///  tracing span of the originating request)
type TokenizerRequest = (
    (String, Option<usize>),
    oneshot::Sender<Result<(tokenizers::Encoding, String), ValidationError>>,
    Span,
);

428
/// A fully validated generation request, ready to be sent to the backend.
#[derive(Debug, Clone)]
pub(crate) struct ValidGenerateRequest {
    /// Possibly-truncated input text
    pub inputs: String,
    /// Input length in tokens (estimate when no fast tokenizer is available)
    pub input_length: u32,
    /// Truncation length forwarded to the server-side tokenization
    pub truncate: u32,
    /// Whether to return details about the prefill tokens
    pub decoder_input_details: bool,
    /// Sampling / logits-processing parameters
    pub parameters: NextTokenChooserParameters,
    /// Stop conditions for generation
    pub stopping_parameters: StoppingCriteriaParameters,
    /// Number of most-likely tokens to return per position (0 = disabled)
    pub top_n_tokens: u32,
}

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
439
440
/// Errors raised while validating a generation request.
///
/// The `#[error]` messages are user-facing: they are returned verbatim in
/// API error responses, so their wording must stay stable.
#[derive(Error, Debug)]
pub enum ValidationError {
    // -- best_of --
    #[error("`best_of` must be > 0 and <= {0}. Given: {1}")]
    BestOf(usize, usize),
    #[error("`best_of` != 1 is not allowed for this endpoint")]
    BestOfDisabled,
    #[error("you must use sampling when `best_of` is > 1")]
    BestOfSampling,
    #[error("`seed` must not be set when `best_of` > 1")]
    BestOfSeed,
    #[error("`best_of` != 1 is not supported when streaming tokens")]
    BestOfStream,
    // -- top_n_tokens --
    #[error("`top_n_tokens` must be >= 0 and <= {0}. Given: {1}")]
    TopNTokens(u32, u32),
    #[error("`top_n_tokens` != 0 is not allowed for this endpoint")]
    TopNTokensDisabled,
    // -- streaming-only restrictions --
    #[error("`decoder_input_details` == true is not supported when streaming tokens")]
    PrefillDetailsStream,
    // -- sampling parameters --
    #[error("`temperature` must be strictly positive")]
    Temperature,
    #[error("`repetition_penalty` must be strictly positive")]
    RepetitionPenalty,
    #[error("`frequency_penalty` must be >= -2.0 and <= 2.0")]
    FrequencyPenalty,
    #[error("`top_p` must be > 0.0 and < 1.0")]
    TopP,
    #[error("`top_k` must be strictly positive")]
    TopK,
    #[error("`truncate` must be strictly positive and less than {0}. Given: {1}")]
    Truncate(usize, usize),
    #[error("`typical_p` must be > 0.0 and < 1.0")]
    TypicalP,
    // -- token budgets --
    #[error("one of `max_new_tokens` or `truncate` must be set if a fast tokenizer is not in use")]
    UnsetMaxNewTokens,
    #[error("`max_new_tokens` must be strictly positive")]
    NegativeMaxNewTokens,
    #[error("`max_new_tokens` must be <= {0}. Given: {1}")]
    MaxNewTokens(usize, u32),
    #[error("`inputs` tokens + `max_new_tokens` must be <= {0}. Given: {1} `inputs` tokens and {2} `max_new_tokens`")]
    MaxTotalTokens(usize, usize, u32),
    #[error("`inputs` must have less than {0} tokens. Given: {1}")]
    InputLength(usize, usize),
    #[error("`inputs` cannot be empty")]
    EmptyInput,
    // -- stop sequences --
    #[error("`stop` supports up to {0} stop sequences. Given: {1}")]
    StopSequence(usize, usize),
    // -- tokenizer / grammar --
    #[error("tokenizer error {0}")]
    Tokenizer(String),
    #[error("grammar is not supported")]
    Grammar,
}
490
491

#[cfg(test)]
mod tests {
    use super::*;
    use crate::default_parameters;
    use crate::tests::get_tokenizer;

    /// Without a fast tokenizer, the truncate-based estimate must reject a
    /// `max_new_tokens` budget that exceeds `max_total_tokens`.
    #[tokio::test]
    async fn test_validation_max_new_tokens() {
        let tokenizer = None;
        let max_best_of = 2;
        let max_stop_sequence = 3;
        let max_top_n_tokens = 4;
        let max_input_length = 5;
        let max_total_tokens = 6;
        let workers = 1;
        let disable_grammar_support = true;
        let validation = Validation::new(
            workers,
            tokenizer,
            max_best_of,
            max_stop_sequence,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
            disable_grammar_support,
        );

        let max_new_tokens = 10;
        match validation
            .validate_input("Hello".to_string(), None, Some(max_new_tokens))
            .await
        {
            // estimated input (5) + 10 > 6 => only 6 - 5 = 1 new token allowed
            Err(ValidationError::MaxNewTokens(1, 10)) => (),
            _ => panic!("Unexpected not max new tokens"),
        }
    }

    /// With a real tokenizer, the exact token count triggers the
    /// MaxTotalTokens check instead.
    #[tokio::test]
    async fn test_validation_input_length() {
        let tokenizer = Some(get_tokenizer().await);
        let max_best_of = 2;
        let max_stop_sequence = 3;
        let max_top_n_tokens = 4;
        let max_input_length = 5;
        let max_total_tokens = 6;
        let disable_grammar_support = true;
        let workers = 1;
        let validation = Validation::new(
            workers,
            tokenizer,
            max_best_of,
            max_stop_sequence,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
            disable_grammar_support,
        );

        let max_new_tokens = 10;
        match validation
            .validate_input("Hello".to_string(), None, Some(max_new_tokens))
            .await
        {
            // "Hello" tokenizes to 1 token; 1 + 10 > 6
            Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (),
            _ => panic!("Unexpected not max new tokens"),
        }
    }

    /// best_of > 1 without any sampling enabled must be rejected.
    #[tokio::test]
    async fn test_validation_best_of_sampling() {
        let tokenizer = Some(get_tokenizer().await);
        let max_best_of = 2;
        let max_stop_sequence = 3;
        let max_top_n_tokens = 4;
        let max_input_length = 5;
        let max_total_tokens = 6;
        let workers = 1;
        let disable_grammar_support = true;
        let validation = Validation::new(
            workers,
            tokenizer,
            max_best_of,
            max_stop_sequence,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
            disable_grammar_support,
        );
        match validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    best_of: Some(2),
                    do_sample: false,
                    ..default_parameters()
                },
            })
            .await
        {
            Err(ValidationError::BestOfSampling) => (),
            _ => panic!("Unexpected not best of sampling"),
        }
    }

    /// top_p must be strictly inside (0, 1) when the user sets it; when
    /// unset it resolves to the proto default 1.0.
    #[tokio::test]
    async fn test_validation_top_p() {
        let tokenizer = Some(get_tokenizer().await);
        let max_best_of = 2;
        let max_stop_sequence = 3;
        let max_top_n_tokens = 4;
        let max_input_length = 5;
        let max_total_tokens = 106;
        let workers = 1;
        let disable_grammar_support = true;
        let validation = Validation::new(
            workers,
            tokenizer,
            max_best_of,
            max_stop_sequence,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
            disable_grammar_support,
        );
        // Boundary value 1.0 is rejected
        match validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_p: Some(1.0),
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
        {
            Err(ValidationError::TopP) => (),
            _ => panic!("Unexpected top_p"),
        }

        // Interior value is accepted
        match validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_p: Some(0.99),
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
        {
            Ok(_) => (),
            _ => panic!("Unexpected top_p error"),
        }

        let valid_request = validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_p: None,
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
            .unwrap();
        // top_p == 1.0 is invalid for users to ask for but it's the default resolved value.
        assert_eq!(valid_request.parameters.top_p, 1.0);
    }

    /// top_n_tokens is capped at max_top_n_tokens; unset resolves to 0.
    #[tokio::test]
    async fn test_validation_top_n_tokens() {
        let tokenizer = Some(get_tokenizer().await);
        let max_best_of = 2;
        let max_stop_sequences = 3;
        let max_top_n_tokens = 4;
        let max_input_length = 5;
        let max_total_tokens = 106;
        let workers = 1;
        let disable_grammar_support = true;
        let validation = Validation::new(
            workers,
            tokenizer,
            max_best_of,
            max_stop_sequences,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
            disable_grammar_support,
        );
        // Above the cap: rejected with the configured maximum
        match validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_n_tokens: Some(5),
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
        {
            Err(ValidationError::TopNTokens(4, 5)) => (),
            _ => panic!("Unexpected top_n_tokens"),
        }

        // Exactly at the cap: accepted
        validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_n_tokens: Some(4),
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
            .unwrap();

        // Explicit zero: accepted (feature disabled)
        validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_n_tokens: Some(0),
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
            .unwrap();

        let valid_request = validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_n_tokens: None,
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
            .unwrap();

        // Unset resolves to the proto default 0
        assert_eq!(valid_request.top_n_tokens, 0);
    }
}