validation.rs 21.7 KB
Newer Older
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
1
/// Payload validation logic
2
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
3
use crate::{GenerateParameters, GenerateRequest};
4
use rand::{thread_rng, Rng};
5
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
Olivier Dehaene's avatar
Olivier Dehaene committed
6
use thiserror::Error;
Olivier Dehaene's avatar
Olivier Dehaene committed
7
use tokenizers::tokenizer::Tokenizer;
8
use tokenizers::TruncationDirection;
OlivierDehaene's avatar
OlivierDehaene committed
9
use tokio::sync::mpsc;
10
use tokio::sync::oneshot;
11
use tracing::{instrument, Span};
Olivier Dehaene's avatar
Olivier Dehaene committed
12

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
13
/// Validation
///
/// Validates incoming generation requests against server-configured limits.
/// Tokenization is offloaded to background blocking workers reached through
/// `sender` when a fast tokenizer is available.
#[derive(Debug, Clone)]
pub struct Validation {
    /// Validation parameters
    // Maximum allowed value for the `best_of` parameter
    max_best_of: usize,
    // Maximum number of stop sequences a request may define
    max_stop_sequences: usize,
    // Maximum allowed value for the `top_n_tokens` parameter
    max_top_n_tokens: u32,
    // Maximum number of input tokens accepted
    max_input_length: usize,
    // Maximum of input tokens + `max_new_tokens` for a single request
    max_total_tokens: usize,
    /// Channel to communicate with the background tokenization task
    // `None` when no fast tokenizer was provided; validation then falls back
    // to the truncate/max_new_tokens heuristics in `validate_input`
    sender: Option<mpsc::UnboundedSender<TokenizerRequest>>,
}

impl Validation {
    /// Build a `Validation` instance.
    ///
    /// When `tokenizer` is `Some`, spawns `workers` blocking tokenization
    /// tasks plus one async round-robin dispatcher that fans requests out to
    /// them; otherwise no background tasks are created and `sender` is `None`.
    pub(crate) fn new(
        workers: usize,
        tokenizer: Option<Tokenizer>,
        max_best_of: usize,
        max_stop_sequences: usize,
        max_top_n_tokens: u32,
        max_input_length: usize,
        max_total_tokens: usize,
    ) -> Self {
        // If we have a fast tokenizer
        let sender = if let Some(tokenizer) = tokenizer {
            // Create round robin channel
            let (validation_sender, validation_round_robin_receiver) = mpsc::unbounded_channel();
            let mut senders = Vec::with_capacity(workers);

            // Create workers
            for _ in 0..workers {
                let tokenizer_clone = tokenizer.clone();
                let (tokenizer_sender, tokenizer_receiver) = mpsc::unbounded_channel();
                senders.push(tokenizer_sender);

                // Spawn worker
                // spawn_blocking: tokenization is CPU-bound and must not block
                // the async runtime threads
                tokio::task::spawn_blocking(move || {
                    tokenizer_worker(tokenizer_clone, tokenizer_receiver)
                });
            }

            // Create tokenization round robin task
            tokio::spawn(round_robin_task(validation_round_robin_receiver, senders));

            Some(validation_sender)
        } else {
            None
        };

        Self {
            max_best_of,
            sender,
            max_stop_sequences,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
        }
    }

    /// Tokenize `inputs` on a background worker.
    ///
    /// Returns `Ok(None)` when no fast tokenizer is configured, otherwise the
    /// encoding plus the (possibly truncated) input text.
    #[instrument(skip(self, inputs))]
    pub async fn tokenize(
        &self,
        inputs: String,
        truncate: Option<usize>,
    ) -> Result<Option<(tokenizers::Encoding, String)>, ValidationError> {
        // If we have a fast tokenizer
        if let Some(sender) = &self.sender {
            // Create response channel
            let (response_sender, response_receiver) = oneshot::channel();
            // Send request to the background validation task
            // Unwrap is safe here
            sender
                .send(((inputs, truncate), response_sender, Span::current()))
                .unwrap();

            // Await on response channel
            // Unwrap is safe here
            let encoding = response_receiver.await.unwrap()?;
            Ok(Some(encoding))
        } else {
            Ok(None)
        }
    }

    /// Validate the input text against the token-count limits.
    ///
    /// Returns the (possibly truncated) input text, its length in tokens
    /// (or a best-effort estimate when no tokenizer is available), and the
    /// resolved `max_new_tokens`.
    #[instrument(skip(self, inputs))]
    async fn validate_input(
        &self,
        inputs: String,
        truncate: Option<usize>,
        max_new_tokens: Option<u32>,
    ) -> Result<(String, usize, u32), ValidationError> {
        // If we have a fast tokenizer
        if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
            // Exact token count from the tokenizer
            let input_length = encoding.len();

            // Get total tokens
            // Default max_new_tokens to whatever budget remains after the input
            let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
                max_new_tokens
            } else {
                self.max_total_tokens.saturating_sub(input_length) as u32
            };
            let total_tokens = input_length + max_new_tokens as usize;

            // Validate MaxTotalTokens
            if total_tokens > self.max_total_tokens {
                return Err(ValidationError::MaxTotalTokens(
                    self.max_total_tokens,
                    input_length,
                    max_new_tokens,
                ));
            }

            // Validate InputLength
            if input_length > self.max_input_length {
                return Err(ValidationError::InputLength(
                    self.max_input_length,
                    input_length,
                ));
            }

            metrics::histogram!("tgi_request_input_length", input_length as f64);
            Ok((inputs, input_length, max_new_tokens))
        }
        // Return inputs without validation
        else {
            // In this case, we don't know the real length in tokens of the inputs
            // However, the inputs will be truncated by the python servers
            // We make sure that truncate + max_new_tokens <= self.max_total_tokens
            let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
                max_new_tokens
            } else if let Some(truncate) = truncate {
                self.max_total_tokens.saturating_sub(truncate) as u32
            } else {
                return Err(ValidationError::UnsetMaxNewTokens);
            };
            // Worst-case estimate: the server-side truncation bound
            let input_length = truncate.unwrap_or(self.max_input_length);

            // Validate MaxNewTokens
            // NOTE(review): `max_total_tokens - max_input_length` underflows if
            // max_input_length > max_total_tokens — presumably ruled out by
            // config validation at startup; TODO confirm
            if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
                return Err(ValidationError::MaxNewTokens(
                    self.max_total_tokens - self.max_input_length,
                    max_new_tokens,
                ));
            }

            Ok((inputs, input_length, max_new_tokens))
        }
    }

    /// Validate a payload and get the number of tokens in the input
    ///
    /// Checks every user-supplied parameter against the configured limits,
    /// fills in defaults, and produces a `ValidGenerateRequest` ready to be
    /// queued.
    #[instrument(skip_all)]
    pub(crate) async fn validate(
        &self,
        request: GenerateRequest,
    ) -> Result<ValidGenerateRequest, ValidationError> {
        let GenerateParameters {
            best_of,
            temperature,
            repetition_penalty,
            frequency_penalty,
            top_k,
            top_p,
            typical_p,
            do_sample,
            max_new_tokens,
            stop: stop_sequences,
            truncate,
            seed,
            watermark,
            decoder_input_details,
            top_n_tokens,
            ..
        } = request.parameters;

        // sampling must be true when best_of > 1
        let best_of = best_of.unwrap_or(1);
        let sampling = do_sample
            || temperature.is_some()
            || top_k.is_some()
            || top_p.is_some()
            || typical_p.is_some();

        if best_of > 1 && !sampling {
            return Err(BestOfSampling);
        }

        let temperature = temperature.unwrap_or(1.0);
        if temperature <= 0.0 {
            return Err(ValidationError::Temperature);
        }

        let repetition_penalty = repetition_penalty.unwrap_or(1.0);
        if repetition_penalty <= 0.0 {
            return Err(ValidationError::RepetitionPenalty);
        }

        let frequency_penalty = frequency_penalty.unwrap_or(0.0);
        if !(-2.0..=2.0).contains(&frequency_penalty) {
            return Err(ValidationError::FrequencyPenalty);
        }

        // Different because the proto default value is not a valid value
        // for the user
        let top_p = top_p
            .map(|value| {
                if value <= 0.0 || value >= 1.0 {
                    return Err(ValidationError::TopP);
                }
                Ok(value)
            })
            .unwrap_or(Ok(1.0))?;

        let typical_p = typical_p
            .map(|value| {
                if value <= 0.0 || value >= 1.0 {
                    return Err(ValidationError::TypicalP);
                }
                Ok(value)
            })
            .unwrap_or(Ok(1.0))?;

        let top_k: u32 = top_k
            .map(|value| {
                if value <= 0 {
                    return Err(ValidationError::TopK);
                }
                Ok(value as u32)
            })
            .unwrap_or(Ok(0))?;

        // `0` would request an empty generation; reject it explicitly
        if max_new_tokens == Some(0) {
            return Err(ValidationError::NegativeMaxNewTokens);
        }

        if stop_sequences.len() > self.max_stop_sequences {
            return Err(ValidationError::StopSequence(
                self.max_stop_sequences,
                stop_sequences.len(),
            ));
        }

        // If seed is None, assign a random one
        let seed = match seed {
            None => thread_rng().gen(),
            Some(seed) => {
                // A fixed seed would make every best_of candidate identical
                if best_of > 1 {
                    return Err(BestOfSeed);
                }
                seed
            }
        };

        let top_n_tokens = top_n_tokens
            .map(|value| {
                if value > self.max_top_n_tokens {
                    return Err(ValidationError::TopNTokens(self.max_top_n_tokens, value));
                }
                Ok(value)
            })
            .unwrap_or(Ok(0))?;

        // Check if inputs is empty
        if request.inputs.is_empty() {
            return Err(EmptyInput);
        }

        // Check if truncate is strictly positive and less than max_input_length
        let truncate = truncate
            .map(|value| {
                if value == 0 || value > self.max_input_length {
                    return Err(ValidationError::Truncate(self.max_input_length, value));
                }
                Ok(Some(value))
            })
            .unwrap_or(Ok(None))?;

        // Validate inputs
        let (inputs, input_length, max_new_tokens) = self
            .validate_input(request.inputs, truncate, max_new_tokens)
            .await?;

        let parameters = NextTokenChooserParameters {
            temperature,
            repetition_penalty,
            frequency_penalty,
            top_k,
            top_p,
            typical_p,
            do_sample,
            seed,
            watermark,
        };
        let stopping_parameters = StoppingCriteriaParameters {
            max_new_tokens,
            stop_sequences,
            ignore_eos_token: false,
        };

        metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64);

        Ok(ValidGenerateRequest {
            inputs,
            decoder_input_details,
            input_length: input_length as u32,
            truncate: truncate.unwrap_or(self.max_input_length) as u32,
            parameters,
            stopping_parameters,
            top_n_tokens,
        })
    }

    /// Validate the best_of parameter
    #[instrument(skip_all)]
    pub(crate) fn validate_best_of(&self, best_of: usize) -> Result<usize, ValidationError> {
        // Distinguish "feature disabled" from "limit exceeded" for the caller
        if self.max_best_of == 1 && best_of != 1 {
            return Err(ValidationError::BestOfDisabled);
        }

        if best_of > self.max_best_of {
            return Err(ValidationError::BestOf(self.max_best_of, best_of));
        }

        Ok(best_of)
    }
}

OlivierDehaene's avatar
OlivierDehaene committed
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
/// Round robin tokenization task
///
/// Forwards each incoming request to the worker channels in turn, wrapping
/// back to the first worker after the last. Terminates once every request
/// producer has been dropped (`recv` yields `None`).
async fn round_robin_task(
    mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
    senders: Vec<mpsc::UnboundedSender<TokenizerRequest>>,
) {
    'dispatch: loop {
        // One full pass over the workers: one request forwarded to each
        for worker in senders.iter() {
            let Some(request) = receiver.recv().await else {
                break 'dispatch;
            };
            // A closed worker channel is unrecoverable here; propagate via panic
            worker.send(request).unwrap();
        }
    }
}

355
/// Start tokenization workers
OlivierDehaene's avatar
OlivierDehaene committed
356
fn tokenizer_worker(tokenizer: Tokenizer, mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>) {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
357
    // Loop over requests
OlivierDehaene's avatar
OlivierDehaene committed
358
    while let Some(((inputs, truncate), response_tx, parent_span)) = receiver.blocking_recv() {
359
360
        parent_span.in_scope(|| {
            response_tx
361
                .send(prepare_input(inputs, truncate, &tokenizer))
362
363
                .unwrap_or(())
        })
364
365
    }
}
Olivier Dehaene's avatar
Olivier Dehaene committed
366

367
368
/// Get input length and optionally truncate it
fn prepare_input(
369
    mut inputs: String,
370
    truncate: Option<usize>,
371
    tokenizer: &Tokenizer,
372
) -> Result<(tokenizers::Encoding, String), ValidationError> {
373
374
    // Get the number of tokens in the input
    let mut encoding = tokenizer
375
        .encode(inputs.clone(), true)
376
377
        .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;

378
    // Optionally truncate
379
380
    if let Some(truncate) = truncate {
        if truncate < encoding.len() {
381
            encoding.truncate(truncate, 0, TruncationDirection::Left);
382
            inputs = tokenizer
383
                .decode(encoding.get_ids(), false)
384
385
                .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
        }
386
    }
387

388
    Ok((encoding, inputs))
Olivier Dehaene's avatar
Olivier Dehaene committed
389
}
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
390

391
392
/// Work item sent to a tokenization worker:
/// - `(inputs, truncate)`: raw input text and optional truncation limit
/// - one-shot sender for the tokenization result
/// - the tracing `Span` of the originating request, re-entered by the worker
type TokenizerRequest = (
    (String, Option<usize>),
    oneshot::Sender<Result<(tokenizers::Encoding, String), ValidationError>>,
    Span,
);

397
/// A generation request that has passed validation, with all defaults
/// resolved and limits enforced.
#[derive(Debug, Clone)]
pub(crate) struct ValidGenerateRequest {
    // Input text, possibly truncated by the tokenizer
    pub inputs: String,
    // Number of input tokens (exact, or an upper-bound estimate when no fast
    // tokenizer is configured — see `Validation::validate_input`)
    pub input_length: u32,
    // Truncation limit forwarded to the model servers
    pub truncate: u32,
    // Whether to return logprob details for the prompt tokens
    pub decoder_input_details: bool,
    // Sampling parameters for the next-token chooser
    pub parameters: NextTokenChooserParameters,
    // Stopping criteria (max_new_tokens, stop sequences, ...)
    pub stopping_parameters: StoppingCriteriaParameters,
    // Number of most-likely tokens to return per position (0 = disabled)
    pub top_n_tokens: u32,
}

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
408
409
/// Errors returned when a generation request fails validation.
/// The `#[error]` messages are user-facing and surfaced in API responses.
#[derive(Error, Debug)]
pub enum ValidationError {
    #[error("`best_of` must be > 0 and <= {0}. Given: {1}")]
    BestOf(usize, usize),
    #[error("`best_of` != 1 is not allowed for this endpoint")]
    BestOfDisabled,
    #[error("you must use sampling when `best_of` is > 1")]
    BestOfSampling,
    #[error("`seed` must not be set when `best_of` > 1")]
    BestOfSeed,
    #[error("`best_of` != 1 is not supported when streaming tokens")]
    BestOfStream,
    #[error("`top_n_tokens` must be >= 0 and <= {0}. Given: {1}")]
    TopNTokens(u32, u32),
    #[error("`top_n_tokens` != 0 is not allowed for this endpoint")]
    TopNTokensDisabled,
    #[error("`decoder_input_details` == true is not supported when streaming tokens")]
    PrefillDetailsStream,
    #[error("`temperature` must be strictly positive")]
    Temperature,
    #[error("`repetition_penalty` must be strictly positive")]
    RepetitionPenalty,
    #[error("`frequency_penalty` must be >= -2.0 and <= 2.0")]
    FrequencyPenalty,
    #[error("`top_p` must be > 0.0 and < 1.0")]
    TopP,
    #[error("`top_k` must be strictly positive")]
    TopK,
    #[error("`truncate` must be strictly positive and less than {0}. Given: {1}")]
    Truncate(usize, usize),
    #[error("`typical_p` must be > 0.0 and < 1.0")]
    TypicalP,
    #[error("one of `max_new_tokens` or `truncate` must be set if a fast tokenizer is not in use")]
    UnsetMaxNewTokens,
    // Named "Negative" although the rejected value is 0: u32 cannot be negative
    #[error("`max_new_tokens` must be strictly positive")]
    NegativeMaxNewTokens,
    #[error("`max_new_tokens` must be <= {0}. Given: {1}")]
    MaxNewTokens(usize, u32),
    #[error("`inputs` tokens + `max_new_tokens` must be <= {0}. Given: {1} `inputs` tokens and {2} `max_new_tokens`")]
    MaxTotalTokens(usize, usize, u32),
    #[error("`inputs` must have less than {0} tokens. Given: {1}")]
    InputLength(usize, usize),
    #[error("`inputs` cannot be empty")]
    EmptyInput,
    #[error("`stop` supports up to {0} stop sequences. Given: {1}")]
    StopSequence(usize, usize),
    // Wraps the tokenizers crate's error as a string
    #[error("tokenizer error {0}")]
    Tokenizer(String),
}
457
458

#[cfg(test)]
mod tests {
    use super::*;
    use crate::default_parameters;
    use crate::tests::get_tokenizer;

    // Without a tokenizer, validate_input falls back to the truncate-based
    // heuristic and reports MaxNewTokens(max_total_tokens - max_input_length, _)
    #[tokio::test]
    async fn test_validation_max_new_tokens() {
        let tokenizer = None;
        let max_best_of = 2;
        let max_stop_sequence = 3;
        let max_top_n_tokens = 4;
        let max_input_length = 5;
        let max_total_tokens = 6;
        let workers = 1;
        let validation = Validation::new(
            workers,
            tokenizer,
            max_best_of,
            max_stop_sequence,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
        );

        let max_new_tokens = 10;
        match validation
            .validate_input("Hello".to_string(), None, Some(max_new_tokens))
            .await
        {
            Err(ValidationError::MaxNewTokens(1, 10)) => (),
            _ => panic!("Unexpected not max new tokens"),
        }
    }

    // With a tokenizer, the exact token count is used and the total-token
    // budget check fires first (1 input token + 10 new > 6 total)
    #[tokio::test]
    async fn test_validation_input_length() {
        let tokenizer = Some(get_tokenizer().await);
        let max_best_of = 2;
        let max_stop_sequence = 3;
        let max_top_n_tokens = 4;
        let max_input_length = 5;
        let max_total_tokens = 6;
        let workers = 1;
        let validation = Validation::new(
            workers,
            tokenizer,
            max_best_of,
            max_stop_sequence,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
        );

        let max_new_tokens = 10;
        match validation
            .validate_input("Hello".to_string(), None, Some(max_new_tokens))
            .await
        {
            Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (),
            _ => panic!("Unexpected not max new tokens"),
        }
    }

    // best_of > 1 without any sampling parameter enabled must be rejected
    #[tokio::test]
    async fn test_validation_best_of_sampling() {
        let tokenizer = Some(get_tokenizer().await);
        let max_best_of = 2;
        let max_stop_sequence = 3;
        let max_top_n_tokens = 4;
        let max_input_length = 5;
        let max_total_tokens = 6;
        let workers = 1;
        let validation = Validation::new(
            workers,
            tokenizer,
            max_best_of,
            max_stop_sequence,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
        );
        match validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    best_of: Some(2),
                    do_sample: false,
                    ..default_parameters()
                },
            })
            .await
        {
            Err(ValidationError::BestOfSampling) => (),
            _ => panic!("Unexpected not best of sampling"),
        }
    }

    // top_p must lie strictly inside (0, 1); None resolves to the proto
    // default of 1.0
    #[tokio::test]
    async fn test_validation_top_p() {
        let tokenizer = Some(get_tokenizer().await);
        let max_best_of = 2;
        let max_stop_sequence = 3;
        let max_top_n_tokens = 4;
        let max_input_length = 5;
        let max_total_tokens = 106;
        let workers = 1;
        let validation = Validation::new(
            workers,
            tokenizer,
            max_best_of,
            max_stop_sequence,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
        );
        // Exactly 1.0 is rejected for user input
        match validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_p: Some(1.0),
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
        {
            Err(ValidationError::TopP) => (),
            _ => panic!("Unexpected top_p"),
        }

        // Strictly inside the open interval is accepted
        match validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_p: Some(0.99),
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
        {
            Ok(_) => (),
            _ => panic!("Unexpected top_p error"),
        }

        let valid_request = validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_p: None,
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
            .unwrap();

        // top_p == 1.0 is invalid for users to ask for but it's the default resolved value.
        assert_eq!(valid_request.parameters.top_p, 1.0);
    }

    // top_n_tokens is capped at max_top_n_tokens; None resolves to 0
    #[tokio::test]
    async fn test_validation_top_n_tokens() {
        let tokenizer = Some(get_tokenizer().await);
        let max_best_of = 2;
        let max_stop_sequences = 3;
        let max_top_n_tokens = 4;
        let max_input_length = 5;
        let max_total_tokens = 106;
        let workers = 1;
        let validation = Validation::new(
            workers,
            tokenizer,
            max_best_of,
            max_stop_sequences,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
        );
        // Above the configured cap -> rejected with both values reported
        match validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_n_tokens: Some(5),
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
        {
            Err(ValidationError::TopNTokens(4, 5)) => (),
            _ => panic!("Unexpected top_n_tokens"),
        }

        // Exactly at the cap -> accepted
        validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_n_tokens: Some(4),
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
            .unwrap();

        // Explicit 0 (disabled) -> accepted
        validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_n_tokens: Some(0),
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
            .unwrap();

        let valid_request = validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_n_tokens: None,
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
            .unwrap();

        assert_eq!(valid_request.top_n_tokens, 0);
    }
}