// validation.rs
/// Payload validation logic
2
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
drbh's avatar
drbh committed
3
use crate::{GenerateParameters, GenerateRequest, GrammarType};
4
use jsonschema::{Draft, JSONSchema};
5
use rand::{thread_rng, Rng};
6
use serde_json::Value;
drbh's avatar
drbh committed
7
8
9
use text_generation_client::{
    GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters,
};
Olivier Dehaene's avatar
Olivier Dehaene committed
10
use thiserror::Error;
Olivier Dehaene's avatar
Olivier Dehaene committed
11
use tokenizers::tokenizer::Tokenizer;
12
use tokenizers::TruncationDirection;
OlivierDehaene's avatar
OlivierDehaene committed
13
use tokio::sync::mpsc;
14
use tokio::sync::oneshot;
15
use tracing::{instrument, Span};
Olivier Dehaene's avatar
Olivier Dehaene committed
16

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
17
/// Validation
///
/// Validates incoming generate requests against server-wide limits and,
/// when a fast tokenizer is available, measures/truncates inputs through a
/// pool of background tokenization workers.
#[derive(Debug, Clone)]
pub struct Validation {
    /// Validation parameters
    /// Maximum allowed `best_of` value
    max_best_of: usize,
    /// Maximum number of stop sequences per request
    max_stop_sequences: usize,
    /// Maximum allowed `top_n_tokens` value
    max_top_n_tokens: u32,
    /// Maximum input length, in tokens
    max_input_length: usize,
    /// Maximum of input tokens + generated tokens
    max_total_tokens: usize,
    /// When true, requests that set a `grammar` are rejected
    disable_grammar_support: bool,
    /// Channel to communicate with the background tokenization task
    /// `None` when no fast tokenizer was provided
    sender: Option<mpsc::UnboundedSender<TokenizerRequest>>,
}

impl Validation {
    /// Build a new `Validation` instance.
    ///
    /// When a fast `tokenizer` is provided, `workers` blocking tokenization tasks
    /// are spawned and fed through a round-robin dispatcher task; otherwise no
    /// tokenization happens in this process and inputs are checked with the
    /// fallback logic in `validate_input`.
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn new(
        workers: usize,
        tokenizer: Option<Tokenizer>,
        max_best_of: usize,
        max_stop_sequences: usize,
        max_top_n_tokens: u32,
        max_input_length: usize,
        max_total_tokens: usize,
        disable_grammar_support: bool,
    ) -> Self {
        // If we have a fast tokenizer
        let sender = if let Some(tokenizer) = tokenizer {
            // Create round robin channel
            let (validation_sender, validation_round_robin_receiver) = mpsc::unbounded_channel();
            let mut senders = Vec::with_capacity(workers);

            // Create workers
            for _ in 0..workers {
                let tokenizer_clone = tokenizer.clone();
                let (tokenizer_sender, tokenizer_receiver) = mpsc::unbounded_channel();
                senders.push(tokenizer_sender);

                // Spawn worker
                // spawn_blocking: tokenization is CPU-bound and must not run on the
                // async executor threads
                tokio::task::spawn_blocking(move || {
                    tokenizer_worker(tokenizer_clone, tokenizer_receiver)
                });
            }

            // Create tokenization round robin task
            tokio::spawn(round_robin_task(validation_round_robin_receiver, senders));

            Some(validation_sender)
        } else {
            None
        };

        Self {
            max_best_of,
            sender,
            max_stop_sequences,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
            disable_grammar_support,
        }
    }

    /// Tokenize `inputs`, optionally truncating to the last `truncate` tokens.
    ///
    /// Returns `Ok(None)` when no tokenizer was configured; otherwise the encoding
    /// together with the (possibly re-decoded, truncated) input string.
    #[instrument(skip(self, inputs))]
    pub async fn tokenize(
        &self,
        inputs: String,
        truncate: Option<usize>,
    ) -> Result<Option<(tokenizers::Encoding, String)>, ValidationError> {
        // If we have a fast tokenizer
        if let Some(sender) = &self.sender {
            // Create response channel
            let (response_sender, response_receiver) = oneshot::channel();
            // Send request to the background validation task
            // Unwrap is safe here
            sender
                .send(((inputs, truncate), response_sender, Span::current()))
                .unwrap();

            // Await on response channel
            // Unwrap is safe here
            let encoding = response_receiver.await.unwrap()?;
            Ok(Some(encoding))
        } else {
            Ok(None)
        }
    }

    /// Validate the input length and resolve the effective `max_new_tokens`.
    ///
    /// Returns the (possibly truncated) inputs, the input length in tokens — or a
    /// conservative estimate when no tokenizer is available — and the resolved
    /// `max_new_tokens` budget.
    #[instrument(skip(self, inputs))]
    async fn validate_input(
        &self,
        inputs: String,
        truncate: Option<usize>,
        max_new_tokens: Option<u32>,
    ) -> Result<(String, usize, u32), ValidationError> {
        // If we have a fast tokenizer
        if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
            // Create response channel
            let input_length = encoding.len();

            // Get total tokens
            // When unset, default max_new_tokens to the remaining token budget
            let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
                max_new_tokens
            } else {
                self.max_total_tokens.saturating_sub(input_length) as u32
            };
            let total_tokens = input_length + max_new_tokens as usize;

            // Validate MaxTotalTokens
            if total_tokens > self.max_total_tokens {
                return Err(ValidationError::MaxTotalTokens(
                    self.max_total_tokens,
                    input_length,
                    max_new_tokens,
                ));
            }

            // Validate InputLength
            if input_length > self.max_input_length {
                return Err(ValidationError::InputLength(
                    self.max_input_length,
                    input_length,
                ));
            }

            metrics::histogram!("tgi_request_input_length", input_length as f64);
            Ok((inputs, input_length, max_new_tokens))
        }
        // Return inputs without validation
        else {
            // In this case, we don't know the real length in tokens of the inputs
            // However, the inputs will be truncated by the python servers
            // We make sure that truncate + max_new_tokens <= self.max_total_tokens
            let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
                max_new_tokens
            } else if let Some(truncate) = truncate {
                self.max_total_tokens.saturating_sub(truncate) as u32
            } else {
                return Err(ValidationError::UnsetMaxNewTokens);
            };
            // Worst-case input length: the truncation limit if given, else the cap
            let input_length = truncate.unwrap_or(self.max_input_length);

            // Validate MaxNewTokens
            if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
                return Err(ValidationError::MaxNewTokens(
                    self.max_total_tokens - self.max_input_length,
                    max_new_tokens,
                ));
            }

            Ok((inputs, input_length, max_new_tokens))
        }
    }

    /// Validate a payload and get the number of tokens in the input
    ///
    /// Checks every user-supplied parameter against the configured limits,
    /// fills in defaults, validates any grammar, and assembles the proto
    /// parameter structs for the model server.
    #[instrument(skip_all)]
    pub(crate) async fn validate(
        &self,
        request: GenerateRequest,
    ) -> Result<ValidGenerateRequest, ValidationError> {
        let GenerateParameters {
            best_of,
            temperature,
            repetition_penalty,
            frequency_penalty,
            top_k,
            top_p,
            typical_p,
            do_sample,
            max_new_tokens,
            stop: stop_sequences,
            truncate,
            seed,
            watermark,
            decoder_input_details,
            top_n_tokens,
            grammar,
            ..
        } = request.parameters;

        // sampling must be true when best_of > 1
        let best_of = best_of.unwrap_or(1);
        let sampling = do_sample
            || temperature.is_some()
            || top_k.is_some()
            || top_p.is_some()
            || typical_p.is_some();

        if best_of > 1 && !sampling {
            return Err(BestOfSampling);
        }

        let temperature = temperature.unwrap_or(1.0);
        if temperature <= 0.0 {
            return Err(ValidationError::Temperature);
        }

        let repetition_penalty = repetition_penalty.unwrap_or(1.0);
        if repetition_penalty <= 0.0 {
            return Err(ValidationError::RepetitionPenalty);
        }

        let frequency_penalty = frequency_penalty.unwrap_or(0.0);
        if !(-2.0..=2.0).contains(&frequency_penalty) {
            return Err(ValidationError::FrequencyPenalty);
        }

        // Different because the proto default value is not a valid value
        // for the user
        let top_p = top_p
            .map(|value| {
                if value <= 0.0 || value >= 1.0 {
                    return Err(ValidationError::TopP);
                }
                Ok(value)
            })
            .unwrap_or(Ok(1.0))?;

        let typical_p = typical_p
            .map(|value| {
                if value <= 0.0 || value >= 1.0 {
                    return Err(ValidationError::TypicalP);
                }
                Ok(value)
            })
            .unwrap_or(Ok(0.0.max(1.0)))?;

        let top_k: u32 = top_k
            .map(|value| {
                if value <= 0 {
                    return Err(ValidationError::TopK);
                }
                Ok(value as u32)
            })
            .unwrap_or(Ok(0))?;

        if max_new_tokens == Some(0) {
            return Err(ValidationError::NegativeMaxNewTokens);
        }

        if stop_sequences.len() > self.max_stop_sequences {
            return Err(ValidationError::StopSequence(
                self.max_stop_sequences,
                stop_sequences.len(),
            ));
        }

        // If seed is None, assign a random one
        let seed = match seed {
            None => thread_rng().gen(),
            Some(seed) => {
                if best_of > 1 {
                    return Err(BestOfSeed);
                }
                seed
            }
        };

        let top_n_tokens = top_n_tokens
            .map(|value| {
                if value > self.max_top_n_tokens {
                    return Err(ValidationError::TopNTokens(self.max_top_n_tokens, value));
                }
                Ok(value)
            })
            .unwrap_or(Ok(0))?;

        // Check if inputs is empty
        if request.inputs.is_empty() {
            return Err(EmptyInput);
        }

        // Check if truncate is strictly positive and less than max_input_length
        let truncate = truncate
            .map(|value| {
                if value == 0 || value > self.max_input_length {
                    return Err(ValidationError::Truncate(self.max_input_length, value));
                }
                Ok(Some(value))
            })
            .unwrap_or(Ok(None))?;

        // Validate inputs
        let (inputs, input_length, max_new_tokens) = self
            .validate_input(request.inputs, truncate, max_new_tokens)
            .await?;

        // TODO: we should build the FSM here and pass the compiled FSM instead of the grammar
        // NOTE: this is currently difficult because we need the tokenizer in Python to build
        // the FSM and we'd have to load a copy of the tokenizer into our Pyo3 instance which
        // may be slow and memory intensive. Best case is to have a Rust implementation of the FSM
        // compiler and use that to build the FSM here.

        // Validate grammar and unpack the grammar and type for the proto message
        let (grammar, grammar_type) = match grammar {
            Some(grammar) => {
                // Ensure that grammar is not set if it's not supported
                if self.disable_grammar_support {
                    return Err(ValidationError::Grammar);
                }
                match grammar {
                    GrammarType::Json(json) => {
                        let json = match json {
                            // if value is a string, we need to parse it again to make sure its
                            // a valid json
                            Value::String(s) => serde_json::from_str(&s)
                                .map_err(|e| ValidationError::InvalidGrammar(e.to_string())),
                            Value::Object(_) => Ok(json),
                            _ => Err(ValidationError::Grammar),
                        }?;

                        // Check if the json is a valid JSONSchema
                        JSONSchema::options()
                            .with_draft(Draft::Draft202012)
                            .compile(&json)
                            .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;

                        (
                            // Serialize json to string
                            serde_json::to_string(&json)
                                .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?,
                            ProtoGrammarType::Json.into(),
                        )
                    }
                    GrammarType::Regex(regex) => (regex, ProtoGrammarType::Regex.into()),
                }
            }
            None => (String::new(), ProtoGrammarType::None.into()),
        };

        let parameters = NextTokenChooserParameters {
            temperature,
            repetition_penalty,
            frequency_penalty,
            top_k,
            top_p,
            typical_p,
            do_sample,
            seed,
            watermark,
            grammar,
            grammar_type,
        };
        let stopping_parameters = StoppingCriteriaParameters {
            max_new_tokens,
            stop_sequences,
            ignore_eos_token: false,
        };

        metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64);

        Ok(ValidGenerateRequest {
            inputs,
            decoder_input_details,
            input_length: input_length as u32,
            truncate: truncate.unwrap_or(self.max_input_length) as u32,
            parameters,
            stopping_parameters,
            top_n_tokens,
        })
    }

    /// Validate the best_of parameter
    #[instrument(skip_all)]
    pub(crate) fn validate_best_of(&self, best_of: usize) -> Result<usize, ValidationError> {
        if self.max_best_of == 1 && best_of != 1 {
            return Err(ValidationError::BestOfDisabled);
        }

        if best_of > self.max_best_of {
            return Err(ValidationError::BestOf(self.max_best_of, best_of));
        }

        Ok(best_of)
    }
}

OlivierDehaene's avatar
OlivierDehaene committed
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
/// Round robin tokenization task
///
/// Forwards each request received on `receiver` to the worker channels in
/// `senders`, cycling through them in order. Returns once every request
/// sender has been dropped.
async fn round_robin_task(
    mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
    senders: Vec<mpsc::UnboundedSender<TokenizerRequest>>,
) {
    loop {
        for sender in &senders {
            // `recv` yields `None` when the channel is closed: we are done
            if let Some(request) = receiver.recv().await {
                sender.send(request).unwrap();
            } else {
                return;
            }
        }
    }
}

409
/// Start tokenization workers
OlivierDehaene's avatar
OlivierDehaene committed
410
fn tokenizer_worker(tokenizer: Tokenizer, mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>) {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
411
    // Loop over requests
OlivierDehaene's avatar
OlivierDehaene committed
412
    while let Some(((inputs, truncate), response_tx, parent_span)) = receiver.blocking_recv() {
413
414
        parent_span.in_scope(|| {
            response_tx
415
                .send(prepare_input(inputs, truncate, &tokenizer))
416
417
                .unwrap_or(())
        })
418
419
    }
}
/// Get input length and optionally truncate it
fn prepare_input(
423
    mut inputs: String,
424
    truncate: Option<usize>,
425
    tokenizer: &Tokenizer,
426
) -> Result<(tokenizers::Encoding, String), ValidationError> {
427
428
    // Get the number of tokens in the input
    let mut encoding = tokenizer
429
        .encode(inputs.clone(), true)
430
431
        .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;

432
    // Optionally truncate
433
434
    if let Some(truncate) = truncate {
        if truncate < encoding.len() {
435
            encoding.truncate(truncate, 0, TruncationDirection::Left);
436
            inputs = tokenizer
437
                .decode(encoding.get_ids(), false)
438
439
                .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
        }
440
    }
441

442
    Ok((encoding, inputs))
Olivier Dehaene's avatar
Olivier Dehaene committed
443
}
/// Message sent to a tokenization worker:
/// ((inputs, optional truncation limit), response channel, caller's tracing span)
type TokenizerRequest = (
    (String, Option<usize>),
    oneshot::Sender<Result<(tokenizers::Encoding, String), ValidationError>>,
    Span,
);

451
/// A generate request that passed validation and is ready to be scheduled
#[derive(Debug, Clone)]
pub(crate) struct ValidGenerateRequest {
    /// Possibly truncated input string
    pub inputs: String,
    /// Input length in tokens (worst-case estimate when no tokenizer is configured)
    pub input_length: u32,
    /// Truncation limit forwarded to the model server
    pub truncate: u32,
    /// Whether prefill token details were requested
    pub decoder_input_details: bool,
    /// Next-token sampling parameters (proto message)
    pub parameters: NextTokenChooserParameters,
    /// Stopping criteria (proto message)
    pub stopping_parameters: StoppingCriteriaParameters,
    /// Number of most likely tokens to return per position (0 = disabled)
    pub top_n_tokens: u32,
}

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
462
463
/// Errors raised while validating a generate payload; the `#[error]` messages
/// are returned verbatim to API clients.
#[derive(Error, Debug)]
pub enum ValidationError {
    #[error("`best_of` must be > 0 and <= {0}. Given: {1}")]
    BestOf(usize, usize),
    #[error("`best_of` != 1 is not allowed for this endpoint")]
    BestOfDisabled,
    #[error("you must use sampling when `best_of` is > 1")]
    BestOfSampling,
    #[error("`seed` must not be set when `best_of` > 1")]
    BestOfSeed,
    #[error("`best_of` != 1 is not supported when streaming tokens")]
    BestOfStream,
    #[error("`top_n_tokens` must be >= 0 and <= {0}. Given: {1}")]
    TopNTokens(u32, u32),
    #[error("`top_n_tokens` != 0 is not allowed for this endpoint")]
    TopNTokensDisabled,
    #[error("`decoder_input_details` == true is not supported when streaming tokens")]
    PrefillDetailsStream,
    #[error("`temperature` must be strictly positive")]
    Temperature,
    #[error("`repetition_penalty` must be strictly positive")]
    RepetitionPenalty,
    #[error("`frequency_penalty` must be >= -2.0 and <= 2.0")]
    FrequencyPenalty,
    #[error("`top_p` must be > 0.0 and < 1.0")]
    TopP,
    #[error("`top_k` must be strictly positive")]
    TopK,
    #[error("`truncate` must be strictly positive and less than {0}. Given: {1}")]
    Truncate(usize, usize),
    #[error("`typical_p` must be > 0.0 and < 1.0")]
    TypicalP,
    #[error("one of `max_new_tokens` or `truncate` must be set if a fast tokenizer is not in use")]
    UnsetMaxNewTokens,
    #[error("`max_new_tokens` must be strictly positive")]
    NegativeMaxNewTokens,
    #[error("`max_new_tokens` must be <= {0}. Given: {1}")]
    MaxNewTokens(usize, u32),
    #[error("`inputs` tokens + `max_new_tokens` must be <= {0}. Given: {1} `inputs` tokens and {2} `max_new_tokens`")]
    MaxTotalTokens(usize, usize, u32),
    #[error("`inputs` must have less than {0} tokens. Given: {1}")]
    InputLength(usize, usize),
    #[error("`inputs` cannot be empty")]
    EmptyInput,
    #[error("`stop` supports up to {0} stop sequences. Given: {1}")]
    StopSequence(usize, usize),
    #[error("tokenizer error {0}")]
    Tokenizer(String),
    #[error("grammar is not supported")]
    Grammar,
    #[error("grammar is not valid: {0}")]
    InvalidGrammar(String),
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::default_parameters;
    use crate::tests::get_tokenizer;

    // Without a tokenizer, an unset truncate + oversized max_new_tokens must
    // hit the fallback MaxNewTokens check
    #[tokio::test]
    async fn test_validation_max_new_tokens() {
        let tokenizer = None;
        let max_best_of = 2;
        let max_stop_sequence = 3;
        let max_top_n_tokens = 4;
        let max_input_length = 5;
        let max_total_tokens = 6;
        let workers = 1;
        let disable_grammar_support = true;
        let validation = Validation::new(
            workers,
            tokenizer,
            max_best_of,
            max_stop_sequence,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
            disable_grammar_support,
        );

        let max_new_tokens = 10;
        match validation
            .validate_input("Hello".to_string(), None, Some(max_new_tokens))
            .await
        {
            Err(ValidationError::MaxNewTokens(1, 10)) => (),
            _ => panic!("Unexpected not max new tokens"),
        }
    }

    // With a tokenizer, the same request must hit the exact MaxTotalTokens check
    #[tokio::test]
    async fn test_validation_input_length() {
        let tokenizer = Some(get_tokenizer().await);
        let max_best_of = 2;
        let max_stop_sequence = 3;
        let max_top_n_tokens = 4;
        let max_input_length = 5;
        let max_total_tokens = 6;
        let disable_grammar_support = true;
        let workers = 1;
        let validation = Validation::new(
            workers,
            tokenizer,
            max_best_of,
            max_stop_sequence,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
            disable_grammar_support,
        );

        let max_new_tokens = 10;
        match validation
            .validate_input("Hello".to_string(), None, Some(max_new_tokens))
            .await
        {
            Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (),
            _ => panic!("Unexpected not max new tokens"),
        }
    }

    // best_of > 1 without any sampling parameter must be rejected
    #[tokio::test]
    async fn test_validation_best_of_sampling() {
        let tokenizer = Some(get_tokenizer().await);
        let max_best_of = 2;
        let max_stop_sequence = 3;
        let max_top_n_tokens = 4;
        let max_input_length = 5;
        let max_total_tokens = 6;
        let workers = 1;
        let disable_grammar_support = true;
        let validation = Validation::new(
            workers,
            tokenizer,
            max_best_of,
            max_stop_sequence,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
            disable_grammar_support,
        );
        match validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    best_of: Some(2),
                    do_sample: false,
                    ..default_parameters()
                },
            })
            .await
        {
            Err(ValidationError::BestOfSampling) => (),
            _ => panic!("Unexpected not best of sampling"),
        }
    }

    // top_p must be strictly inside (0, 1) when user-supplied; None resolves to 1.0
    #[tokio::test]
    async fn test_validation_top_p() {
        let tokenizer = Some(get_tokenizer().await);
        let max_best_of = 2;
        let max_stop_sequence = 3;
        let max_top_n_tokens = 4;
        let max_input_length = 5;
        let max_total_tokens = 106;
        let workers = 1;
        let disable_grammar_support = true;
        let validation = Validation::new(
            workers,
            tokenizer,
            max_best_of,
            max_stop_sequence,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
            disable_grammar_support,
        );
        match validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_p: Some(1.0),
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
        {
            Err(ValidationError::TopP) => (),
            _ => panic!("Unexpected top_p"),
        }

        match validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_p: Some(0.99),
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
        {
            Ok(_) => (),
            _ => panic!("Unexpected top_p error"),
        }

        let valid_request = validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_p: None,
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
            .unwrap();

        // top_p == 1.0 is invalid for users to ask for but it's the default resolved value.
        assert_eq!(valid_request.parameters.top_p, 1.0);
    }

    // top_n_tokens must be <= max_top_n_tokens; None resolves to 0
    #[tokio::test]
    async fn test_validation_top_n_tokens() {
        let tokenizer = Some(get_tokenizer().await);
        let max_best_of = 2;
        let max_stop_sequences = 3;
        let max_top_n_tokens = 4;
        let max_input_length = 5;
        let max_total_tokens = 106;
        let workers = 1;
        let disable_grammar_support = true;
        let validation = Validation::new(
            workers,
            tokenizer,
            max_best_of,
            max_stop_sequences,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
            disable_grammar_support,
        );
        match validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_n_tokens: Some(5),
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
        {
            Err(ValidationError::TopNTokens(4, 5)) => (),
            _ => panic!("Unexpected top_n_tokens"),
        }

        validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_n_tokens: Some(4),
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
            .unwrap();

        validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_n_tokens: Some(0),
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
            .unwrap();

        let valid_request = validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_n_tokens: None,
                    max_new_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
            .unwrap();

        assert_eq!(valid_request.top_n_tokens, 0);
    }
}