validation.rs 21 KB
Newer Older
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
1
/// Payload validation logic
2
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
3
use crate::{GenerateParameters, GenerateRequest};
4
use rand::{thread_rng, Rng};
5
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
Olivier Dehaene's avatar
Olivier Dehaene committed
6
use thiserror::Error;
Olivier Dehaene's avatar
Olivier Dehaene committed
7
use tokenizers::tokenizer::Tokenizer;
8
use tokenizers::TruncationDirection;
OlivierDehaene's avatar
OlivierDehaene committed
9
use tokio::sync::mpsc;
10
use tokio::sync::oneshot;
11
use tracing::{instrument, Span};
Olivier Dehaene's avatar
Olivier Dehaene committed
12

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
13
/// Validation
///
/// Validates incoming generation requests against server-configured limits
/// and (when a fast tokenizer is available) computes exact input token counts
/// via a pool of background tokenization workers.
#[derive(Debug, Clone)]
pub struct Validation {
    /// Validation parameters
    // Maximum value accepted for the `best_of` parameter
    max_best_of: usize,
    // Maximum number of stop sequences a request may provide
    max_stop_sequences: usize,
    // Maximum value accepted for the `top_n_tokens` parameter
    max_top_n_tokens: u32,
    // Maximum number of input tokens accepted (after optional truncation)
    max_input_length: usize,
    // Maximum for input tokens + `max_new_tokens` of a single request
    max_total_tokens: usize,
    /// Channel to communicate with the background tokenization task
    // `None` when no fast tokenizer was provided; validation then falls back
    // to worst-case estimates (see `validate_input`)
    sender: Option<mpsc::UnboundedSender<TokenizerRequest>>,
}

impl Validation {
27
28
    pub(crate) fn new(
        workers: usize,
29
        tokenizer: Option<Tokenizer>,
30
        max_best_of: usize,
31
        max_stop_sequences: usize,
Nicolas Patry's avatar
Nicolas Patry committed
32
        max_top_n_tokens: u32,
33
34
35
        max_input_length: usize,
        max_total_tokens: usize,
    ) -> Self {
36
37
        // If we have a fast tokenizer
        let sender = if let Some(tokenizer) = tokenizer {
OlivierDehaene's avatar
OlivierDehaene committed
38
39
40
            // Create round robin channel
            let (validation_sender, validation_round_robin_receiver) = mpsc::unbounded_channel();
            let mut senders = Vec::with_capacity(workers);
41
42
43
44

            // Create workers
            for _ in 0..workers {
                let tokenizer_clone = tokenizer.clone();
OlivierDehaene's avatar
OlivierDehaene committed
45
46
                let (tokenizer_sender, tokenizer_receiver) = mpsc::unbounded_channel();
                senders.push(tokenizer_sender);
47
48
49

                // Spawn worker
                tokio::task::spawn_blocking(move || {
OlivierDehaene's avatar
OlivierDehaene committed
50
                    tokenizer_worker(tokenizer_clone, tokenizer_receiver)
51
52
                });
            }
OlivierDehaene's avatar
OlivierDehaene committed
53
54
55
56

            // Create tokenization round robin task
            tokio::spawn(round_robin_task(validation_round_robin_receiver, senders));

57
58
59
60
61
62
63
64
            Some(validation_sender)
        } else {
            None
        };

        Self {
            max_best_of,
            sender,
65
            max_stop_sequences,
Nicolas Patry's avatar
Nicolas Patry committed
66
            max_top_n_tokens,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
67
            max_input_length,
68
            max_total_tokens,
69
70
        }
    }
Olivier Dehaene's avatar
Olivier Dehaene committed
71

72
    #[instrument(skip(self, inputs))]
73
74
75
76
    async fn validate_input(
        &self,
        inputs: String,
        truncate: Option<usize>,
77
78
        max_new_tokens: Option<u32>,
    ) -> Result<(String, usize, u32), ValidationError> {
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
        // If we have a fast tokenizer
        if let Some(sender) = &self.sender {
            // Create response channel
            let (response_sender, response_receiver) = oneshot::channel();
            // Send request to the background validation task
            // Unwrap is safe here
            sender
                .send(((inputs, truncate), response_sender, Span::current()))
                .unwrap();

            // Await on response channel
            // Unwrap is safe here
            let (inputs, input_length) = response_receiver.await.unwrap()?;

            // Get total tokens
94
95
96
97
98
            let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
                max_new_tokens
            } else {
                self.max_total_tokens.saturating_sub(input_length) as u32
            };
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
            let total_tokens = input_length + max_new_tokens as usize;

            // Validate MaxTotalTokens
            if total_tokens > self.max_total_tokens {
                return Err(ValidationError::MaxTotalTokens(
                    self.max_total_tokens,
                    input_length,
                    max_new_tokens,
                ));
            }

            // Validate InputLength
            if input_length > self.max_input_length {
                return Err(ValidationError::InputLength(
                    self.max_input_length,
                    input_length,
                ));
            }

            metrics::histogram!("tgi_request_input_length", input_length as f64);
119
            Ok((inputs, input_length, max_new_tokens))
120
121
122
123
124
125
        }
        // Return inputs without validation
        else {
            // In this case, we don't know the real length in tokens of the inputs
            // However, the inputs will be truncated by the python servers
            // We make sure that truncate + max_new_tokens <= self.max_total_tokens
126
127
            let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
                max_new_tokens
OlivierDehaene's avatar
OlivierDehaene committed
128
129
            } else if let Some(truncate) = truncate {
                self.max_total_tokens.saturating_sub(truncate) as u32
130
            } else {
OlivierDehaene's avatar
OlivierDehaene committed
131
                return Err(ValidationError::UnsetMaxNewTokens);
132
            };
133
            let input_length = truncate.unwrap_or(self.max_input_length);
134
135

            // Validate MaxNewTokens
136
            if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
137
138
139
140
141
142
                return Err(ValidationError::MaxNewTokens(
                    self.max_total_tokens - self.max_input_length,
                    max_new_tokens,
                ));
            }

143
            Ok((inputs, input_length, max_new_tokens))
Olivier Dehaene's avatar
Olivier Dehaene committed
144
145
146
        }
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
147
    /// Validate a payload and get the number of tokens in the input
    ///
    /// Checks every sampling/stopping parameter against the configured limits,
    /// fills in defaults (temperature 1.0, random seed, ...), then delegates
    /// input tokenization and token-budget checks to `validate_input`.
    #[instrument(skip_all)]
    pub(crate) async fn validate(
        &self,
        request: GenerateRequest,
    ) -> Result<ValidGenerateRequest, ValidationError> {
        let GenerateParameters {
            best_of,
            temperature,
            repetition_penalty,
            top_k,
            top_p,
            typical_p,
            do_sample,
            max_new_tokens,
            stop: stop_sequences,
            truncate,
            seed,
            watermark,
            decoder_input_details,
            top_n_tokens,
            ..
        } = request.parameters;

        // sampling must be true when best_of > 1
        // Any explicit sampling parameter counts as "sampling requested"
        let best_of = best_of.unwrap_or(1);
        let sampling = do_sample
            || temperature.is_some()
            || top_k.is_some()
            || top_p.is_some()
            || typical_p.is_some();

        if best_of > 1 && !sampling {
            return Err(BestOfSampling);
        }

        let temperature = temperature.unwrap_or(1.0);
        if temperature <= 0.0 {
            return Err(ValidationError::Temperature);
        }

        let repetition_penalty = repetition_penalty.unwrap_or(1.0);
        if repetition_penalty <= 0.0 {
            return Err(ValidationError::RepetitionPenalty);
        }

        // Different because the proto default value is not a valid value
        // for the user
        let top_p = top_p
            .map(|value| {
                if value <= 0.0 || value >= 1.0 {
                    return Err(ValidationError::TopP);
                }
                Ok(value)
            })
            .unwrap_or(Ok(1.0))?;

        let typical_p = typical_p
            .map(|value| {
                if value <= 0.0 || value >= 1.0 {
                    return Err(ValidationError::TypicalP);
                }
                Ok(value)
            })
            .unwrap_or(Ok(1.0))?;

        // 0 is the proto sentinel for "disabled", hence not user-settable
        let top_k: u32 = top_k
            .map(|value| {
                if value <= 0 {
                    return Err(ValidationError::TopK);
                }
                Ok(value as u32)
            })
            .unwrap_or(Ok(0))?;

        // `None` is allowed here (a default is derived later); only an
        // explicit 0 is rejected
        if max_new_tokens == Some(0) {
            return Err(ValidationError::NegativeMaxNewTokens);
        }

        if stop_sequences.len() > self.max_stop_sequences {
            return Err(ValidationError::StopSequence(
                self.max_stop_sequences,
                stop_sequences.len(),
            ));
        }

        // If seed is None, assign a random one
        // A fixed seed with best_of > 1 would make all candidates identical
        let seed = match seed {
            None => thread_rng().gen(),
            Some(seed) => {
                if best_of > 1 {
                    return Err(BestOfSeed);
                }
                seed
            }
        };

        let top_n_tokens = top_n_tokens
            .map(|value| {
                if value > self.max_top_n_tokens {
                    return Err(ValidationError::TopNTokens(self.max_top_n_tokens, value));
                }
                Ok(value)
            })
            .unwrap_or(Ok(0))?;

        // Check if inputs is empty
        if request.inputs.is_empty() {
            return Err(EmptyInput);
        }

        // Check if truncate is strictly positive and less than max_input_length
        let truncate = truncate
            .map(|value| {
                if value == 0 || value > self.max_input_length {
                    return Err(ValidationError::Truncate(self.max_input_length, value));
                }
                Ok(Some(value))
            })
            .unwrap_or(Ok(None))?;

        // Validate inputs
        let (inputs, input_length, max_new_tokens) = self
            .validate_input(request.inputs, truncate, max_new_tokens)
            .await?;

        let parameters = NextTokenChooserParameters {
            temperature,
            repetition_penalty,
            top_k,
            top_p,
            typical_p,
            do_sample,
            seed,
            watermark,
        };
        let stopping_parameters = StoppingCriteriaParameters {
            max_new_tokens,
            stop_sequences,
            ignore_eos_token: false,
        };

        metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64);

        Ok(ValidGenerateRequest {
            inputs,
            decoder_input_details,
            input_length: input_length as u32,
            truncate: truncate.unwrap_or(self.max_input_length) as u32,
            parameters,
            stopping_parameters,
            top_n_tokens,
        })
    }
301
302
303
304
305
306
307
308
309
310
311
312
313
314

    /// Check that `best_of` is allowed by the server configuration.
    ///
    /// Returns the value unchanged when valid; `BestOfDisabled` when the
    /// endpoint only allows `best_of == 1`; `BestOf` when the requested value
    /// exceeds the configured maximum.
    #[instrument(skip_all)]
    pub(crate) fn validate_best_of(&self, best_of: usize) -> Result<usize, ValidationError> {
        match best_of {
            // Feature is disabled server-side but the client asked for it
            requested if self.max_best_of == 1 && requested != 1 => {
                Err(ValidationError::BestOfDisabled)
            }
            // Requested value above the configured cap
            requested if requested > self.max_best_of => {
                Err(ValidationError::BestOf(self.max_best_of, requested))
            }
            requested => Ok(requested),
        }
    }
Olivier Dehaene's avatar
Olivier Dehaene committed
315
316
}

OlivierDehaene's avatar
OlivierDehaene committed
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
/// Round robin tokenization task
///
/// Pulls requests off `receiver` and forwards each one to the next worker
/// channel in strict rotation. Exits once the request channel is closed and
/// fully drained.
async fn round_robin_task(
    mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
    senders: Vec<mpsc::UnboundedSender<TokenizerRequest>>,
) {
    loop {
        // One full pass over the workers, one request each
        for worker in &senders {
            match receiver.recv().await {
                Some(request) => worker.send(request).unwrap(),
                // All request senders were dropped: shut down
                None => return,
            };
        }
    }
}

332
/// Start tokenization workers
OlivierDehaene's avatar
OlivierDehaene committed
333
fn tokenizer_worker(tokenizer: Tokenizer, mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>) {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
334
    // Loop over requests
OlivierDehaene's avatar
OlivierDehaene committed
335
    while let Some(((inputs, truncate), response_tx, parent_span)) = receiver.blocking_recv() {
336
337
        parent_span.in_scope(|| {
            response_tx
338
                .send(prepare_input(inputs, truncate, &tokenizer))
339
340
                .unwrap_or(())
        })
341
342
    }
}
Olivier Dehaene's avatar
Olivier Dehaene committed
343

344
345
346
347
/// Get input length and optionally truncate it
fn prepare_input(
    inputs: String,
    truncate: Option<usize>,
348
    tokenizer: &Tokenizer,
349
) -> Result<(String, usize), ValidationError> {
350
351
    // Get the number of tokens in the input
    let mut encoding = tokenizer
352
        .encode(inputs.clone(), true)
353
354
        .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;

355
356
    // Optionally truncate
    let (inputs, input_length) = match truncate {
357
358
        // Truncate is some and < encoding length
        Some(truncate) if truncate < encoding.len() => {
359
360
361
            // truncate encoding and decode new inputs
            encoding.truncate(truncate, 0, TruncationDirection::Left);
            let inputs = tokenizer
362
                .decode(encoding.get_ids(), false)
363
364
365
366
367
                .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
            (inputs, encoding.len())
        }
        // Nothing to do
        _ => (inputs, encoding.len()),
368
369
    };

370
    Ok((inputs, input_length))
Olivier Dehaene's avatar
Olivier Dehaene committed
371
}
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
372

373
374
375
// A request for the background tokenization workers:
// - (input text, optional truncation limit in tokens)
// - oneshot channel carrying back the (text, token count) result
// - tracing span of the originating request, re-entered by the worker
type TokenizerRequest = (
    (String, Option<usize>),
    oneshot::Sender<Result<(String, usize), ValidationError>>,
    Span,
);

379
380
381
/// A fully validated generation request, ready to be scheduled.
#[derive(Debug)]
pub(crate) struct ValidGenerateRequest {
    // Input text, possibly truncated during validation
    pub inputs: String,
    // Number of tokens in `inputs` (exact when a fast tokenizer was used,
    // otherwise a worst-case estimate)
    pub input_length: u32,
    // Truncation limit forwarded to the model servers
    pub truncate: u32,
    // NOTE(review): presumably asks for token details of the prompt
    // (prefill) as well — confirm against the callers of this struct
    pub decoder_input_details: bool,
    pub parameters: NextTokenChooserParameters,
    pub stopping_parameters: StoppingCriteriaParameters,
    // Number of most-likely tokens to return per position (0 = disabled)
    pub top_n_tokens: u32,
}

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
390
391
/// Errors raised while validating a generation request payload.
/// Messages are user-facing: they are returned verbatim to API clients.
#[derive(Error, Debug)]
pub enum ValidationError {
    #[error("`best_of` must be > 0 and <= {0}. Given: {1}")]
    BestOf(usize, usize),
    #[error("`best_of` != 1 is not allowed for this endpoint")]
    BestOfDisabled,
    #[error("you must use sampling when `best_of` is > 1")]
    BestOfSampling,
    #[error("`seed` must not be set when `best_of` > 1")]
    BestOfSeed,
    #[error("`best_of` != 1 is not supported when streaming tokens")]
    BestOfStream,
    #[error("`top_n_tokens` must be >= 0 and <= {0}. Given: {1}")]
    TopNTokens(u32, u32),
    #[error("`top_n_tokens` != 0 is not allowed for this endpoint")]
    TopNTokensDisabled,
    #[error("`decoder_input_details` == true is not supported when streaming tokens")]
    PrefillDetailsStream,
    #[error("`temperature` must be strictly positive")]
    Temperature,
    #[error("`repetition_penalty` must be strictly positive")]
    RepetitionPenalty,
    #[error("`top_p` must be > 0.0 and < 1.0")]
    TopP,
    #[error("`top_k` must be strictly positive")]
    TopK,
    #[error("`truncate` must be strictly positive and less than {0}. Given: {1}")]
    Truncate(usize, usize),
    #[error("`typical_p` must be > 0.0 and < 1.0")]
    TypicalP,
    #[error("one of `max_new_tokens` or `truncate` must be set if a fast tokenizer is not in use")]
    UnsetMaxNewTokens,
    #[error("`max_new_tokens` must be strictly positive")]
    NegativeMaxNewTokens,
    #[error("`max_new_tokens` must be <= {0}. Given: {1}")]
    MaxNewTokens(usize, u32),
    #[error("`inputs` tokens + `max_new_tokens` must be <= {0}. Given: {1} `inputs` tokens and {2} `max_new_tokens`")]
    MaxTotalTokens(usize, usize, u32),
    #[error("`inputs` must have less than {0} tokens. Given: {1}")]
    InputLength(usize, usize),
    #[error("`inputs` cannot be empty")]
    EmptyInput,
    #[error("`stop` supports up to {0} stop sequences. Given: {1}")]
    StopSequence(usize, usize),
    #[error("tokenizer error {0}")]
    Tokenizer(String),
}
437
438

#[cfg(test)]
mod tests {
    use super::*;
    use crate::default_parameters;
    use crate::tests::get_tokenizer;

    // Without a tokenizer and without `truncate`, validation falls back to
    // the worst-case estimate and rejects an oversized `max_new_tokens`.
    #[tokio::test]
    async fn test_validation_max_new_tokens() {
        let tokenizer = None;
        let max_best_of = 2;
        let max_stop_sequence = 3;
        let max_top_n_tokens = 4;
        let max_input_length = 5;
        let max_total_tokens = 6;
        let workers = 1;
        let validation = Validation::new(
            workers,
            tokenizer,
            max_best_of,
            max_stop_sequence,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
        );

        let max_new_tokens = 10;
        match validation
            .validate_input("Hello".to_string(), None, Some(max_new_tokens))
            .await
        {
            // Budget reported is max_total_tokens - max_input_length = 1
            Err(ValidationError::MaxNewTokens(1, 10)) => (),
            _ => panic!("Unexpected not max new tokens"),
        }
    }

    // With a real tokenizer, the exact token count (1 for "Hello") plus
    // max_new_tokens exceeds max_total_tokens.
    #[tokio::test]
    async fn test_validation_input_length() {
        let tokenizer = Some(get_tokenizer().await);
        let max_best_of = 2;
        let max_stop_sequence = 3;
        let max_top_n_tokens = 4;
        let max_input_length = 5;
        let max_total_tokens = 6;
        let workers = 1;
        let validation = Validation::new(
            workers,
            tokenizer,
            max_best_of,
            max_stop_sequence,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
        );

        let max_new_tokens = 10;
        match validation
            .validate_input("Hello".to_string(), None, Some(max_new_tokens))
            .await
        {
            Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (),
            _ => panic!("Unexpected not max new tokens"),
        }
    }

    // best_of > 1 without any sampling parameter must be rejected.
    #[tokio::test]
    async fn test_validation_best_of_sampling() {
        let tokenizer = Some(get_tokenizer().await);
        let max_best_of = 2;
        let max_stop_sequence = 3;
        let max_top_n_tokens = 4;
        let max_input_length = 5;
        let max_total_tokens = 6;
        let workers = 1;
        let validation = Validation::new(
            workers,
            tokenizer,
            max_best_of,
            max_stop_sequence,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
        );
        match validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    best_of: Some(2),
                    do_sample: false,
                    ..default_parameters()
                },
            })
            .await
        {
            Err(ValidationError::BestOfSampling) => (),
            _ => panic!("Unexpected not best of sampling"),
        }
    }

    // top_p must lie strictly inside (0, 1) when the user sets it; an unset
    // top_p resolves to the proto default 1.0.
    #[tokio::test]
    async fn test_validation_top_p() {
        let tokenizer = Some(get_tokenizer().await);
        let max_best_of = 2;
        let max_stop_sequence = 3;
        let max_top_n_tokens = 4;
        let max_input_length = 5;
        let max_total_tokens = 106;
        let workers = 1;
        let validation = Validation::new(
            workers,
            tokenizer,
            max_best_of,
            max_stop_sequence,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
        );
        match validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_p: Some(1.0),
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
        {
            Err(ValidationError::TopP) => (),
            _ => panic!("Unexpected top_p"),
        }

        match validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_p: Some(0.99),
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
        {
            Ok(_) => (),
            _ => panic!("Unexpected top_p error"),
        }

        let valid_request = validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_p: None,
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
            .unwrap();
        // top_p == 1.0 is invalid for users to ask for but it's the default resolved value.
        assert_eq!(valid_request.parameters.top_p, 1.0);
    }

    // top_n_tokens must not exceed the configured cap; unset resolves to 0.
    #[tokio::test]
    async fn test_validation_top_n_tokens() {
        let tokenizer = Some(get_tokenizer().await);
        let max_best_of = 2;
        let max_stop_sequences = 3;
        let max_top_n_tokens = 4;
        let max_input_length = 5;
        let max_total_tokens = 106;
        let workers = 1;
        let validation = Validation::new(
            workers,
            tokenizer,
            max_best_of,
            max_stop_sequences,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
        );
        match validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_n_tokens: Some(5),
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
        {
            Err(ValidationError::TopNTokens(4, 5)) => (),
            _ => panic!("Unexpected top_n_tokens"),
        }

        // Exactly at the cap: accepted
        validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_n_tokens: Some(4),
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
            .unwrap();

        // Explicit 0 (disabled): accepted
        validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_n_tokens: Some(0),
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
            .unwrap();

        let valid_request = validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_n_tokens: None,
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
            .unwrap();

        assert_eq!(valid_request.top_n_tokens, 0);
    }
}