validation.rs 21.3 KB
Newer Older
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
1
/// Payload validation logic
2
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
3
use crate::{GenerateParameters, GenerateRequest};
4
use rand::{thread_rng, Rng};
5
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
Olivier Dehaene's avatar
Olivier Dehaene committed
6
use thiserror::Error;
Olivier Dehaene's avatar
Olivier Dehaene committed
7
use tokenizers::tokenizer::Tokenizer;
8
use tokenizers::TruncationDirection;
OlivierDehaene's avatar
OlivierDehaene committed
9
use tokio::sync::mpsc;
10
use tokio::sync::oneshot;
11
use tracing::{instrument, Span};
Olivier Dehaene's avatar
Olivier Dehaene committed
12

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
13
/// Validates incoming generation payloads before they are queued for inference.
#[derive(Debug, Clone)]
pub struct Validation {
    /// Maximum allowed value for the `best_of` request parameter
    max_best_of: usize,
    /// Maximum number of stop sequences a single request may define
    max_stop_sequences: usize,
    /// Maximum allowed value for the `top_n_tokens` request parameter
    max_top_n_tokens: u32,
    /// Maximum number of input tokens accepted per request
    max_input_length: usize,
    /// Maximum of input tokens + generated tokens per request
    max_total_tokens: usize,
    /// Channel to communicate with the background tokenization task;
    /// `None` when no fast tokenizer was provided
    sender: Option<mpsc::UnboundedSender<TokenizerRequest>>,
}

impl Validation {
27
28
    /// Build a `Validation` instance, spawning `workers` blocking tokenizer
    /// tasks plus one round-robin dispatcher when a fast tokenizer is given.
    ///
    /// When `tokenizer` is `None`, no background tasks are spawned and
    /// tokenization-based checks are skipped at validation time.
    pub(crate) fn new(
        workers: usize,
        tokenizer: Option<Tokenizer>,
        max_best_of: usize,
        max_stop_sequences: usize,
        max_top_n_tokens: u32,
        max_input_length: usize,
        max_total_tokens: usize,
    ) -> Self {
        // If we have a fast tokenizer
        let sender = if let Some(tokenizer) = tokenizer {
            // Create round robin channel: requests go to a single dispatcher
            // which fans them out to the per-worker channels below
            let (validation_sender, validation_round_robin_receiver) = mpsc::unbounded_channel();
            let mut senders = Vec::with_capacity(workers);

            // Create workers
            for _ in 0..workers {
                let tokenizer_clone = tokenizer.clone();
                let (tokenizer_sender, tokenizer_receiver) = mpsc::unbounded_channel();
                senders.push(tokenizer_sender);

                // Spawn worker on the blocking pool: tokenization is CPU-bound
                // and uses `blocking_recv`, so it must not run on the async runtime
                tokio::task::spawn_blocking(move || {
                    tokenizer_worker(tokenizer_clone, tokenizer_receiver)
                });
            }

            // Create tokenization round robin task
            tokio::spawn(round_robin_task(validation_round_robin_receiver, senders));

            Some(validation_sender)
        } else {
            None
        };

        Self {
            max_best_of,
            sender,
            max_stop_sequences,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
        }
    }
Olivier Dehaene's avatar
Olivier Dehaene committed
71

72
    /// Tokenize `inputs` through the background tokenizer pool, optionally
    /// truncating to `truncate` tokens (from the left).
    ///
    /// Returns `Ok(None)` when no fast tokenizer is configured; otherwise the
    /// encoding together with the (possibly re-decoded, truncated) input text.
    #[instrument(skip(self, inputs))]
    pub async fn tokenize(
        &self,
        inputs: String,
        truncate: Option<usize>,
    ) -> Result<Option<(tokenizers::Encoding, String)>, ValidationError> {
        // If we have a fast tokenizer
        if let Some(sender) = &self.sender {
            // Create response channel
            let (response_sender, response_receiver) = oneshot::channel();
            // Send request to the background validation task
            // Unwrap is safe here: the worker tasks (and their receivers)
            // live for the lifetime of this struct
            sender
                .send(((inputs, truncate), response_sender, Span::current()))
                .unwrap();

            // Await on response channel
            // Unwrap is safe here: the worker always sends a response before
            // dropping the oneshot sender
            let encoding = response_receiver.await.unwrap()?;
            Ok(Some(encoding))
        } else {
            Ok(None)
        }
    }

    /// Validate the input text against the configured token budgets and
    /// resolve the effective `max_new_tokens`.
    ///
    /// Returns `(inputs, input_length, max_new_tokens)` where `input_length`
    /// is exact when a fast tokenizer is available, and an upper bound
    /// (`truncate` or `max_input_length`) otherwise.
    #[instrument(skip(self, inputs))]
    async fn validate_input(
        &self,
        inputs: String,
        truncate: Option<usize>,
        max_new_tokens: Option<u32>,
    ) -> Result<(String, usize, u32), ValidationError> {
        // If we have a fast tokenizer
        if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
            // Exact token count from the encoding
            let input_length = encoding.len();

            // Get total tokens: default max_new_tokens to whatever budget is
            // left after the input
            let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
                max_new_tokens
            } else {
                self.max_total_tokens.saturating_sub(input_length) as u32
            };
            let total_tokens = input_length + max_new_tokens as usize;

            // Validate MaxTotalTokens
            if total_tokens > self.max_total_tokens {
                return Err(ValidationError::MaxTotalTokens(
                    self.max_total_tokens,
                    input_length,
                    max_new_tokens,
                ));
            }

            // Validate InputLength
            if input_length > self.max_input_length {
                return Err(ValidationError::InputLength(
                    self.max_input_length,
                    input_length,
                ));
            }

            metrics::histogram!("tgi_request_input_length", input_length as f64);
            Ok((inputs, input_length, max_new_tokens))
        }
        // Return inputs without validation
        else {
            // In this case, we don't know the real length in tokens of the inputs
            // However, the inputs will be truncated by the python servers
            // We make sure that truncate + max_new_tokens <= self.max_total_tokens
            let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
                max_new_tokens
            } else if let Some(truncate) = truncate {
                self.max_total_tokens.saturating_sub(truncate) as u32
            } else {
                // Without a tokenizer we cannot derive a default budget
                return Err(ValidationError::UnsetMaxNewTokens);
            };
            // Pessimistic input length: the truncation limit if set, else the
            // maximum allowed input length
            let input_length = truncate.unwrap_or(self.max_input_length);

            // Validate MaxNewTokens
            if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
                return Err(ValidationError::MaxNewTokens(
                    self.max_total_tokens - self.max_input_length,
                    max_new_tokens,
                ));
            }

            Ok((inputs, input_length, max_new_tokens))
        }
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
163
    /// Validate a payload and get the number of tokens in the input
    ///
    /// Checks every generation parameter against the configured limits,
    /// fills in defaults (temperature, top_p, seed, ...) and returns a fully
    /// resolved [`ValidGenerateRequest`] ready for the inference backend.
    #[instrument(skip_all)]
    pub(crate) async fn validate(
        &self,
        request: GenerateRequest,
    ) -> Result<ValidGenerateRequest, ValidationError> {
        let GenerateParameters {
            best_of,
            temperature,
            repetition_penalty,
            top_k,
            top_p,
            typical_p,
            do_sample,
            max_new_tokens,
            stop: stop_sequences,
            truncate,
            seed,
            watermark,
            decoder_input_details,
            top_n_tokens,
            ..
        } = request.parameters;

        // sampling must be true when best_of > 1: setting any sampling
        // parameter implicitly enables sampling
        let best_of = best_of.unwrap_or(1);
        let sampling = do_sample
            || temperature.is_some()
            || top_k.is_some()
            || top_p.is_some()
            || typical_p.is_some();

        if best_of > 1 && !sampling {
            return Err(BestOfSampling);
        }

        let temperature = temperature.unwrap_or(1.0);
        if temperature <= 0.0 {
            return Err(ValidationError::Temperature);
        }

        let repetition_penalty = repetition_penalty.unwrap_or(1.0);
        if repetition_penalty <= 0.0 {
            return Err(ValidationError::RepetitionPenalty);
        }

        // Different because the proto default value is not a valid value
        // for the user
        let top_p = top_p
            .map(|value| {
                if value <= 0.0 || value >= 1.0 {
                    return Err(ValidationError::TopP);
                }
                Ok(value)
            })
            .unwrap_or(Ok(1.0))?;

        let typical_p = typical_p
            .map(|value| {
                if value <= 0.0 || value >= 1.0 {
                    return Err(ValidationError::TypicalP);
                }
                Ok(value)
            })
            .unwrap_or(Ok(1.0))?;

        // 0 is the proto sentinel for "disabled", so user-supplied values
        // must be strictly positive
        let top_k: u32 = top_k
            .map(|value| {
                if value <= 0 {
                    return Err(ValidationError::TopK);
                }
                Ok(value as u32)
            })
            .unwrap_or(Ok(0))?;

        if max_new_tokens == Some(0) {
            return Err(ValidationError::NegativeMaxNewTokens);
        }

        if stop_sequences.len() > self.max_stop_sequences {
            return Err(ValidationError::StopSequence(
                self.max_stop_sequences,
                stop_sequences.len(),
            ));
        }

        // If seed is None, assign a random one
        let seed = match seed {
            None => thread_rng().gen(),
            Some(seed) => {
                // A fixed seed would make all best_of candidates identical
                if best_of > 1 {
                    return Err(BestOfSeed);
                }
                seed
            }
        };

        let top_n_tokens = top_n_tokens
            .map(|value| {
                if value > self.max_top_n_tokens {
                    return Err(ValidationError::TopNTokens(self.max_top_n_tokens, value));
                }
                Ok(value)
            })
            .unwrap_or(Ok(0))?;

        // Check if inputs is empty
        if request.inputs.is_empty() {
            return Err(EmptyInput);
        }

        // Check if truncate is strictly positive and less than max_input_length
        let truncate = truncate
            .map(|value| {
                if value == 0 || value > self.max_input_length {
                    return Err(ValidationError::Truncate(self.max_input_length, value));
                }
                Ok(Some(value))
            })
            .unwrap_or(Ok(None))?;

        // Validate inputs
        let (inputs, input_length, max_new_tokens) = self
            .validate_input(request.inputs, truncate, max_new_tokens)
            .await?;

        let parameters = NextTokenChooserParameters {
            temperature,
            repetition_penalty,
            top_k,
            top_p,
            typical_p,
            do_sample,
            seed,
            watermark,
        };
        let stopping_parameters = StoppingCriteriaParameters {
            max_new_tokens,
            stop_sequences,
            ignore_eos_token: false,
        };

        metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64);

        Ok(ValidGenerateRequest {
            inputs,
            decoder_input_details,
            input_length: input_length as u32,
            truncate: truncate.unwrap_or(self.max_input_length) as u32,
            parameters,
            stopping_parameters,
            top_n_tokens,
        })
    }
317
318
319
320
321
322
323
324
325
326
327
328
329
330

    /// Validate the `best_of` parameter against the configured maximum.
    ///
    /// Returns the value unchanged when acceptable. Any value other than 1 is
    /// rejected outright when `max_best_of` is 1 (the feature is disabled).
    #[instrument(skip_all)]
    pub(crate) fn validate_best_of(&self, best_of: usize) -> Result<usize, ValidationError> {
        if best_of != 1 && self.max_best_of == 1 {
            return Err(ValidationError::BestOfDisabled);
        }

        if best_of > self.max_best_of {
            return Err(ValidationError::BestOf(self.max_best_of, best_of));
        }

        Ok(best_of)
    }
Olivier Dehaene's avatar
Olivier Dehaene committed
331
332
}

OlivierDehaene's avatar
OlivierDehaene committed
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
/// Round robin tokenization task: dispatch each incoming request to the
/// next worker channel in rotation.
async fn round_robin_task(
    mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
    senders: Vec<mpsc::UnboundedSender<TokenizerRequest>>,
) {
    // Cycle through the workers forever; exit once every request sender
    // has been dropped and the channel is drained.
    loop {
        for worker in &senders {
            let request = match receiver.recv().await {
                Some(request) => request,
                None => return,
            };
            // Worker receivers outlive this task, so the send cannot fail
            worker.send(request).unwrap();
        }
    }
}

348
/// Start tokenization workers
OlivierDehaene's avatar
OlivierDehaene committed
349
fn tokenizer_worker(tokenizer: Tokenizer, mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>) {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
350
    // Loop over requests
OlivierDehaene's avatar
OlivierDehaene committed
351
    while let Some(((inputs, truncate), response_tx, parent_span)) = receiver.blocking_recv() {
352
353
        parent_span.in_scope(|| {
            response_tx
354
                .send(prepare_input(inputs, truncate, &tokenizer))
355
356
                .unwrap_or(())
        })
357
358
    }
}
Olivier Dehaene's avatar
Olivier Dehaene committed
359

360
361
/// Get input length and optionally truncate it
fn prepare_input(
362
    mut inputs: String,
363
    truncate: Option<usize>,
364
    tokenizer: &Tokenizer,
365
) -> Result<(tokenizers::Encoding, String), ValidationError> {
366
367
    // Get the number of tokens in the input
    let mut encoding = tokenizer
368
        .encode(inputs.clone(), true)
369
370
        .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;

371
    // Optionally truncate
372
373
    if let Some(truncate) = truncate {
        if truncate < encoding.len() {
374
            encoding.truncate(truncate, 0, TruncationDirection::Left);
375
            inputs = tokenizer
376
                .decode(encoding.get_ids(), false)
377
378
                .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
        }
379
    }
380

381
    Ok((encoding, inputs))
Olivier Dehaene's avatar
Olivier Dehaene committed
382
}
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
383

384
385
/// Request sent to a tokenizer worker: the raw input text with an optional
/// truncation length, a oneshot channel for the result, and the tracing span
/// of the originating request.
type TokenizerRequest = (
    (String, Option<usize>),
    oneshot::Sender<Result<(tokenizers::Encoding, String), ValidationError>>,
    Span,
);

390
/// A generation request that passed every validation check, with all
/// parameters resolved to concrete values.
#[derive(Debug, Clone)]
pub(crate) struct ValidGenerateRequest {
    /// Input text (possibly truncated by the tokenizer)
    pub inputs: String,
    /// Number of input tokens (exact, or an upper bound without a tokenizer)
    pub input_length: u32,
    /// Truncation limit forwarded to the inference backend
    pub truncate: u32,
    /// Whether to return logprobs/details for the prompt tokens
    pub decoder_input_details: bool,
    /// Resolved sampling parameters
    pub parameters: NextTokenChooserParameters,
    /// Resolved stopping criteria
    pub stopping_parameters: StoppingCriteriaParameters,
    /// Number of most likely tokens to return per step (0 = disabled)
    pub top_n_tokens: u32,
}

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
401
402
/// Errors raised while validating a generation payload.
///
/// The `#[error]` strings are user-facing API messages; variants carry the
/// configured limit followed by the offending value where applicable.
#[derive(Error, Debug)]
pub enum ValidationError {
    #[error("`best_of` must be > 0 and <= {0}. Given: {1}")]
    BestOf(usize, usize),
    #[error("`best_of` != 1 is not allowed for this endpoint")]
    BestOfDisabled,
    #[error("you must use sampling when `best_of` is > 1")]
    BestOfSampling,
    #[error("`seed` must not be set when `best_of` > 1")]
    BestOfSeed,
    #[error("`best_of` != 1 is not supported when streaming tokens")]
    BestOfStream,
    #[error("`top_n_tokens` must be >= 0 and <= {0}. Given: {1}")]
    TopNTokens(u32, u32),
    #[error("`top_n_tokens` != 0 is not allowed for this endpoint")]
    TopNTokensDisabled,
    #[error("`decoder_input_details` == true is not supported when streaming tokens")]
    PrefillDetailsStream,
    #[error("`temperature` must be strictly positive")]
    Temperature,
    #[error("`repetition_penalty` must be strictly positive")]
    RepetitionPenalty,
    #[error("`top_p` must be > 0.0 and < 1.0")]
    TopP,
    #[error("`top_k` must be strictly positive")]
    TopK,
    #[error("`truncate` must be strictly positive and less than {0}. Given: {1}")]
    Truncate(usize, usize),
    #[error("`typical_p` must be > 0.0 and < 1.0")]
    TypicalP,
    #[error("one of `max_new_tokens` or `truncate` must be set if a fast tokenizer is not in use")]
    UnsetMaxNewTokens,
    #[error("`max_new_tokens` must be strictly positive")]
    NegativeMaxNewTokens,
    #[error("`max_new_tokens` must be <= {0}. Given: {1}")]
    MaxNewTokens(usize, u32),
    #[error("`inputs` tokens + `max_new_tokens` must be <= {0}. Given: {1} `inputs` tokens and {2} `max_new_tokens`")]
    MaxTotalTokens(usize, usize, u32),
    #[error("`inputs` must have less than {0} tokens. Given: {1}")]
    InputLength(usize, usize),
    #[error("`inputs` cannot be empty")]
    EmptyInput,
    #[error("`stop` supports up to {0} stop sequences. Given: {1}")]
    StopSequence(usize, usize),
    #[error("tokenizer error {0}")]
    Tokenizer(String),
}
448
449

#[cfg(test)]
mod tests {
    use super::*;
    use crate::default_parameters;
    use crate::tests::get_tokenizer;

    /// Without a tokenizer, a too-large `max_new_tokens` must be rejected
    /// against the pessimistic input length bound.
    #[tokio::test]
    async fn test_validation_max_new_tokens() {
        let tokenizer = None;
        let max_best_of = 2;
        let max_stop_sequence = 3;
        let max_top_n_tokens = 4;
        let max_input_length = 5;
        let max_total_tokens = 6;
        let workers = 1;
        let validation = Validation::new(
            workers,
            tokenizer,
            max_best_of,
            max_stop_sequence,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
        );

        let max_new_tokens = 10;
        match validation
            .validate_input("Hello".to_string(), None, Some(max_new_tokens))
            .await
        {
            Err(ValidationError::MaxNewTokens(1, 10)) => (),
            _ => panic!("Unexpected not max new tokens"),
        }
    }

    /// With a tokenizer, the exact token count is used and the request must
    /// fail the total-token budget check instead.
    #[tokio::test]
    async fn test_validation_input_length() {
        let tokenizer = Some(get_tokenizer().await);
        let max_best_of = 2;
        let max_stop_sequence = 3;
        let max_top_n_tokens = 4;
        let max_input_length = 5;
        let max_total_tokens = 6;
        let workers = 1;
        let validation = Validation::new(
            workers,
            tokenizer,
            max_best_of,
            max_stop_sequence,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
        );

        let max_new_tokens = 10;
        match validation
            .validate_input("Hello".to_string(), None, Some(max_new_tokens))
            .await
        {
            Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (),
            _ => panic!("Unexpected not max new tokens"),
        }
    }

    /// `best_of > 1` without any sampling parameter enabled must be rejected.
    #[tokio::test]
    async fn test_validation_best_of_sampling() {
        let tokenizer = Some(get_tokenizer().await);
        let max_best_of = 2;
        let max_stop_sequence = 3;
        let max_top_n_tokens = 4;
        let max_input_length = 5;
        let max_total_tokens = 6;
        let workers = 1;
        let validation = Validation::new(
            workers,
            tokenizer,
            max_best_of,
            max_stop_sequence,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
        );
        match validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    best_of: Some(2),
                    do_sample: false,
                    ..default_parameters()
                },
            })
            .await
        {
            Err(ValidationError::BestOfSampling) => (),
            _ => panic!("Unexpected not best of sampling"),
        }
    }

    /// `top_p` must be strictly inside (0, 1) when user-supplied, while the
    /// unset case resolves to the proto default of 1.0.
    #[tokio::test]
    async fn test_validation_top_p() {
        let tokenizer = Some(get_tokenizer().await);
        let max_best_of = 2;
        let max_stop_sequence = 3;
        let max_top_n_tokens = 4;
        let max_input_length = 5;
        let max_total_tokens = 106;
        let workers = 1;
        let validation = Validation::new(
            workers,
            tokenizer,
            max_best_of,
            max_stop_sequence,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
        );
        match validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_p: Some(1.0),
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
        {
            Err(ValidationError::TopP) => (),
            _ => panic!("Unexpected top_p"),
        }

        match validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_p: Some(0.99),
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
        {
            Ok(_) => (),
            _ => panic!("Unexpected top_p error"),
        }

        let valid_request = validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_p: None,
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
            .unwrap();
        // top_p == 1.0 is invalid for users to ask for but it's the default resolved value.
        assert_eq!(valid_request.parameters.top_p, 1.0);
    }

    /// `top_n_tokens` is capped at the configured maximum; 0 and None both
    /// resolve to 0 (disabled).
    #[tokio::test]
    async fn test_validation_top_n_tokens() {
        let tokenizer = Some(get_tokenizer().await);
        let max_best_of = 2;
        let max_stop_sequences = 3;
        let max_top_n_tokens = 4;
        let max_input_length = 5;
        let max_total_tokens = 106;
        let workers = 1;
        let validation = Validation::new(
            workers,
            tokenizer,
            max_best_of,
            max_stop_sequences,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
        );
        match validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_n_tokens: Some(5),
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
        {
            Err(ValidationError::TopNTokens(4, 5)) => (),
            _ => panic!("Unexpected top_n_tokens"),
        }

        validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_n_tokens: Some(4),
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
            .unwrap();

        validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_n_tokens: Some(0),
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
            .unwrap();

        let valid_request = validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_n_tokens: None,
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
            .unwrap();

        assert_eq!(valid_request.top_n_tokens, 0);
    }
}