validation.rs 20 KB
Newer Older
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
1
/// Payload validation logic
2
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
3
use crate::{GenerateParameters, GenerateRequest};
4
use rand::{thread_rng, Rng};
5
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
Olivier Dehaene's avatar
Olivier Dehaene committed
6
use thiserror::Error;
Olivier Dehaene's avatar
Olivier Dehaene committed
7
use tokenizers::tokenizer::Tokenizer;
8
use tokenizers::TruncationDirection;
9
use tokio::sync::oneshot;
10
use tracing::{instrument, Span};
Olivier Dehaene's avatar
Olivier Dehaene committed
11

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
12
/// Validation
///
/// Holds the server-configured request limits and, when a fast tokenizer is
/// available, a channel to background tokenization workers (see `new`).
#[derive(Debug, Clone)]
pub struct Validation {
    /// Validation parameters
    // Maximum allowed value for the `best_of` request parameter.
    max_best_of: usize,
    // Maximum number of stop sequences a request may supply.
    max_stop_sequences: usize,
    // Maximum allowed value for the `top_n_tokens` request parameter.
    max_top_n_tokens: u32,
    // Maximum number of input tokens accepted for a single request.
    max_input_length: usize,
    // Budget for input tokens + generated tokens per request.
    max_total_tokens: usize,
    /// Channel to communicate with the background tokenization task
    // `None` when no fast tokenizer was provided; in that case inputs are
    // not measured here and are truncated downstream instead.
    sender: Option<flume::Sender<TokenizerRequest>>,
}

impl Validation {
26
27
    pub(crate) fn new(
        workers: usize,
28
        tokenizer: Option<Tokenizer>,
29
        max_best_of: usize,
30
        max_stop_sequences: usize,
Nicolas Patry's avatar
Nicolas Patry committed
31
        max_top_n_tokens: u32,
32
33
34
        max_input_length: usize,
        max_total_tokens: usize,
    ) -> Self {
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
        // If we have a fast tokenizer
        let sender = if let Some(tokenizer) = tokenizer {
            // Create channel
            let (validation_sender, validation_receiver) = flume::unbounded();

            // Create workers
            for _ in 0..workers {
                let tokenizer_clone = tokenizer.clone();
                let receiver_clone = validation_receiver.clone();

                // Spawn worker
                tokio::task::spawn_blocking(move || {
                    tokenizer_worker(tokenizer_clone, receiver_clone)
                });
            }
            Some(validation_sender)
        } else {
            None
        };

        Self {
            max_best_of,
            sender,
58
            max_stop_sequences,
Nicolas Patry's avatar
Nicolas Patry committed
59
            max_top_n_tokens,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
60
            max_input_length,
61
            max_total_tokens,
62
63
        }
    }
Olivier Dehaene's avatar
Olivier Dehaene committed
64

65
66
67
68
69
    #[instrument(skip_all)]
    async fn validate_input(
        &self,
        inputs: String,
        truncate: Option<usize>,
70
71
        max_new_tokens: Option<u32>,
    ) -> Result<(String, usize, u32), ValidationError> {
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
        // If we have a fast tokenizer
        if let Some(sender) = &self.sender {
            // Create response channel
            let (response_sender, response_receiver) = oneshot::channel();
            // Send request to the background validation task
            // Unwrap is safe here
            sender
                .send(((inputs, truncate), response_sender, Span::current()))
                .unwrap();

            // Await on response channel
            // Unwrap is safe here
            let (inputs, input_length) = response_receiver.await.unwrap()?;

            // Get total tokens
87
88
89
90
91
            let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
                max_new_tokens
            } else {
                self.max_total_tokens.saturating_sub(input_length) as u32
            };
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
            let total_tokens = input_length + max_new_tokens as usize;

            // Validate MaxTotalTokens
            if total_tokens > self.max_total_tokens {
                return Err(ValidationError::MaxTotalTokens(
                    self.max_total_tokens,
                    input_length,
                    max_new_tokens,
                ));
            }

            // Validate InputLength
            if input_length > self.max_input_length {
                return Err(ValidationError::InputLength(
                    self.max_input_length,
                    input_length,
                ));
            }

            metrics::histogram!("tgi_request_input_length", input_length as f64);
112
            Ok((inputs, input_length, max_new_tokens))
113
114
115
116
117
118
        }
        // Return inputs without validation
        else {
            // In this case, we don't know the real length in tokens of the inputs
            // However, the inputs will be truncated by the python servers
            // We make sure that truncate + max_new_tokens <= self.max_total_tokens
119
120
121
            let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
                max_new_tokens
            } else {
122
123
124
                if let Some(truncate) = truncate {
                    self.max_total_tokens.saturating_sub(truncate) as u32
                } else {
OlivierDehaene's avatar
OlivierDehaene committed
125
                    return Err(ValidationError::UnsetMaxNewTokens);
126
                }
127
            };
128
            let input_length = truncate.unwrap_or(self.max_input_length);
129
130

            // Validate MaxNewTokens
131
            if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
132
133
134
135
136
137
                return Err(ValidationError::MaxNewTokens(
                    self.max_total_tokens - self.max_input_length,
                    max_new_tokens,
                ));
            }

138
            Ok((inputs, input_length, max_new_tokens))
Olivier Dehaene's avatar
Olivier Dehaene committed
139
140
141
        }
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
142
    /// Validate a payload and get the number of tokens in the input
    ///
    /// Checks every sampling/stopping parameter against its allowed range and
    /// the server-configured limits, fills in defaults (e.g. random `seed`,
    /// `temperature = 1.0`), measures/truncates the input via
    /// `validate_input`, and assembles the protobuf parameter structs.
    ///
    /// Returns a `ValidGenerateRequest` ready for scheduling, or the first
    /// `ValidationError` encountered (checks run in a fixed order, so only
    /// one offending parameter is reported per call).
    #[instrument(skip_all)]
    pub(crate) async fn validate(
        &self,
        request: GenerateRequest,
    ) -> Result<ValidGenerateRequest, ValidationError> {
        let GenerateParameters {
            best_of,
            temperature,
            repetition_penalty,
            top_k,
            top_p,
            typical_p,
            do_sample,
            max_new_tokens,
            stop: stop_sequences,
            truncate,
            seed,
            watermark,
            decoder_input_details,
            top_n_tokens,
            ..
        } = request.parameters;

        // sampling must be true when best_of > 1
        // Setting any sampling-related parameter counts as enabling sampling.
        let best_of = best_of.unwrap_or(1);
        let sampling = do_sample
            || temperature.is_some()
            || top_k.is_some()
            || top_p.is_some()
            || typical_p.is_some();

        if best_of > 1 && !sampling {
            return Err(BestOfSampling);
        }

        let temperature = temperature.unwrap_or(1.0);
        if temperature <= 0.0 {
            return Err(ValidationError::Temperature);
        }

        let repetition_penalty = repetition_penalty.unwrap_or(1.0);
        if repetition_penalty <= 0.0 {
            return Err(ValidationError::RepetitionPenalty);
        }

        // Different because the proto default value is not a valid value
        // for the user
        // (1.0 means "disabled" downstream but users must pass a value in (0, 1)).
        let top_p = top_p
            .map(|value| {
                if value <= 0.0 || value >= 1.0 {
                    return Err(ValidationError::TopP);
                }
                Ok(value)
            })
            .unwrap_or(Ok(1.0))?;

        let typical_p = typical_p
            .map(|value| {
                if value <= 0.0 || value >= 1.0 {
                    return Err(ValidationError::TypicalP);
                }
                Ok(value)
            })
            .unwrap_or(Ok(0.0).map(|_| 1.0))?;

        // `top_k == 0` is the proto's "disabled" sentinel, hence the default.
        let top_k: u32 = top_k
            .map(|value| {
                if value <= 0 {
                    return Err(ValidationError::TopK);
                }
                Ok(value as u32)
            })
            .unwrap_or(Ok(0))?;

        // Note: `None` is allowed here; a concrete budget is resolved later
        // in `validate_input`. Only an explicit 0 is rejected.
        if max_new_tokens == Some(0) {
            return Err(ValidationError::NegativeMaxNewTokens);
        }

        if stop_sequences.len() > self.max_stop_sequences {
            return Err(ValidationError::StopSequence(
                self.max_stop_sequences,
                stop_sequences.len(),
            ));
        }

        // If seed is None, assign a random one
        // A fixed seed contradicts `best_of` (all candidates would be identical).
        let seed = match seed {
            None => thread_rng().gen(),
            Some(seed) => {
                if best_of > 1 {
                    return Err(BestOfSeed);
                }
                seed
            }
        };

        // `0` disables top-n-token reporting downstream.
        let top_n_tokens = top_n_tokens
            .map(|value| {
                if value > self.max_top_n_tokens {
                    return Err(ValidationError::TopNTokens(self.max_top_n_tokens, value));
                }
                Ok(value)
            })
            .unwrap_or(Ok(0))?;

        // Check if inputs is empty
        if request.inputs.is_empty() {
            return Err(EmptyInput);
        }

        // Check if truncate is strictly positive and less than max_input_length
        let truncate = truncate
            .map(|value| {
                if value == 0 || value > self.max_input_length {
                    return Err(ValidationError::Truncate(self.max_input_length, value));
                }
                Ok(Some(value))
            })
            .unwrap_or(Ok(None))?;

        // Validate inputs
        let (inputs, input_length, max_new_tokens) = self
            .validate_input(request.inputs, truncate, max_new_tokens)
            .await?;

        let parameters = NextTokenChooserParameters {
            temperature,
            repetition_penalty,
            top_k,
            top_p,
            typical_p,
            do_sample,
            seed,
            watermark,
        };
        let stopping_parameters = StoppingCriteriaParameters {
            max_new_tokens,
            stop_sequences,
            ignore_eos_token: false,
        };

        metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64);

        Ok(ValidGenerateRequest {
            inputs,
            decoder_input_details,
            input_length: input_length as u32,
            truncate: truncate.unwrap_or(self.max_input_length) as u32,
            parameters,
            stopping_parameters,
            top_n_tokens,
        })
    }
296
297
298
299
300
301
302
303
304
305
306
307
308
309

    /// Check the `best_of` parameter against the configured maximum.
    ///
    /// Returns the value unchanged when valid. Distinguishes "the feature is
    /// disabled entirely" (`max_best_of == 1`) from "the limit was exceeded".
    #[instrument(skip_all)]
    pub(crate) fn validate_best_of(&self, best_of: usize) -> Result<usize, ValidationError> {
        let limit = self.max_best_of;
        if limit == 1 && best_of != 1 {
            Err(ValidationError::BestOfDisabled)
        } else if best_of > limit {
            Err(ValidationError::BestOf(limit, best_of))
        } else {
            Ok(best_of)
        }
    }
Olivier Dehaene's avatar
Olivier Dehaene committed
310
311
}

312
313
/// Start tokenization workers
fn tokenizer_worker(tokenizer: Tokenizer, receiver: flume::Receiver<TokenizerRequest>) {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
314
    // Loop over requests
315
    while let Ok(((inputs, truncate), response_tx, parent_span)) = receiver.recv() {
316
317
        parent_span.in_scope(|| {
            response_tx
318
                .send(prepare_input(inputs, truncate, &tokenizer))
319
320
                .unwrap_or(())
        })
321
322
    }
}
Olivier Dehaene's avatar
Olivier Dehaene committed
323

324
325
326
327
/// Get input length and optionally truncate it
fn prepare_input(
    inputs: String,
    truncate: Option<usize>,
328
    tokenizer: &Tokenizer,
329
) -> Result<(String, usize), ValidationError> {
330
331
    // Get the number of tokens in the input
    let mut encoding = tokenizer
332
        .encode(inputs.clone(), true)
333
334
        .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;

335
336
    // Optionally truncate
    let (inputs, input_length) = match truncate {
337
338
        // Truncate is some and < encoding length
        Some(truncate) if truncate < encoding.len() => {
339
340
341
            // truncate encoding and decode new inputs
            encoding.truncate(truncate, 0, TruncationDirection::Left);
            let inputs = tokenizer
342
                .decode(encoding.get_ids(), false)
343
344
345
346
347
                .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
            (inputs, encoding.len())
        }
        // Nothing to do
        _ => (inputs, encoding.len()),
348
349
    };

350
    Ok((inputs, input_length))
Olivier Dehaene's avatar
Olivier Dehaene committed
351
}
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
352

353
354
355
/// Message sent to a tokenization worker:
/// - `(input text, optional truncation length in tokens)`,
/// - oneshot channel for the `(possibly truncated text, token count)` reply,
/// - tracing span of the originating request (workers re-enter it).
type TokenizerRequest = (
    (String, Option<usize>),
    oneshot::Sender<Result<(String, usize), ValidationError>>,
    Span,
);

359
360
361
/// A generate request that has passed validation and is ready to be queued.
#[derive(Debug)]
pub(crate) struct ValidGenerateRequest {
    // Input text, possibly re-decoded after token-level truncation.
    pub inputs: String,
    // Token count of `inputs` (measured, or assumed when no fast tokenizer
    // is configured).
    pub input_length: u32,
    // Truncation length forwarded downstream; defaults to the server's
    // `max_input_length` when the request did not set one.
    pub truncate: u32,
    // Whether the client asked for logprob details on the prompt tokens.
    pub decoder_input_details: bool,
    // Sampling parameters forwarded to the model server.
    pub parameters: NextTokenChooserParameters,
    // Stopping criteria forwarded to the model server.
    pub stopping_parameters: StoppingCriteriaParameters,
    // How many top tokens to report per generated position (0 = disabled).
    pub top_n_tokens: u32,
}

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
370
371
/// Errors produced while validating a generate request.
///
/// The `#[error]` strings are user-facing: they name the offending request
/// parameter in backticks and, where useful, the configured limit and the
/// value the client sent.
#[derive(Error, Debug)]
pub enum ValidationError {
    #[error("`best_of` must be > 0 and <= {0}. Given: {1}")]
    BestOf(usize, usize),
    #[error("`best_of` != 1 is not allowed for this endpoint")]
    BestOfDisabled,
    #[error("you must use sampling when `best_of` is > 1")]
    BestOfSampling,
    #[error("`seed` must not be set when `best_of` > 1")]
    BestOfSeed,
    #[error("`best_of` != 1 is not supported when streaming tokens")]
    BestOfStream,
    #[error("`top_n_tokens` must be >= 0 and <= {0}. Given: {1}")]
    TopNTokens(u32, u32),
    #[error("`top_n_tokens` != 0 is not allowed for this endpoint")]
    TopNTokensDisabled,
    #[error("`decoder_input_details` == true is not supported when streaming tokens")]
    PrefillDetailsStream,
    #[error("`temperature` must be strictly positive")]
    Temperature,
    #[error("`repetition_penalty` must be strictly positive")]
    RepetitionPenalty,
    #[error("`top_p` must be > 0.0 and < 1.0")]
    TopP,
    #[error("`top_k` must be strictly positive")]
    TopK,
    #[error("`truncate` must be strictly positive and less than {0}. Given: {1}")]
    Truncate(usize, usize),
    #[error("`typical_p` must be > 0.0 and < 1.0")]
    TypicalP,
    #[error("one of `max_new_tokens` or `truncate` must be set if a fast tokenizer is not in use")]
    UnsetMaxNewTokens,
    #[error("`max_new_tokens` must be strictly positive")]
    NegativeMaxNewTokens,
    #[error("`max_new_tokens` must be <= {0}. Given: {1}")]
    MaxNewTokens(usize, u32),
    #[error("`inputs` tokens + `max_new_tokens` must be <= {0}. Given: {1} `inputs` tokens and {2} `max_new_tokens`")]
    MaxTotalTokens(usize, usize, u32),
    #[error("`inputs` must have less than {0} tokens. Given: {1}")]
    InputLength(usize, usize),
    #[error("`inputs` cannot be empty")]
    EmptyInput,
    #[error("`stop` supports up to {0} stop sequences. Given: {1}")]
    StopSequence(usize, usize),
    // Wraps the opaque error string from the `tokenizers` crate.
    #[error("tokenizer error {0}")]
    Tokenizer(String),
}
417
418

#[cfg(test)]
mod tests {
    use super::*;
    use crate::default_parameters;
    use crate::tests::get_tokenizer;

    // Without a tokenizer, `validate_input` assumes the worst-case input
    // length (`max_input_length`) and must reject a `max_new_tokens` that
    // exceeds the remaining budget (max_total_tokens - max_input_length = 1).
    #[tokio::test]
    async fn test_validation_max_new_tokens() {
        let tokenizer = None;
        let max_best_of = 2;
        let max_stop_sequence = 3;
        let max_top_n_tokens = 4;
        let max_input_length = 5;
        let max_total_tokens = 6;
        let workers = 1;
        let validation = Validation::new(
            workers,
            tokenizer,
            max_best_of,
            max_stop_sequence,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
        );

        let max_new_tokens = 10;
        match validation
            .validate_input("Hello".to_string(), None, Some(max_new_tokens))
            .await
        {
            Err(ValidationError::MaxNewTokens(1, 10)) => (),
            _ => panic!("Unexpected not max new tokens"),
        }
    }

    // With a real tokenizer the measured input length (1 token for "Hello")
    // plus `max_new_tokens` must stay within `max_total_tokens`.
    #[tokio::test]
    async fn test_validation_input_length() {
        let tokenizer = Some(get_tokenizer().await);
        let max_best_of = 2;
        let max_stop_sequence = 3;
        let max_top_n_tokens = 4;
        let max_input_length = 5;
        let max_total_tokens = 6;
        let workers = 1;
        let validation = Validation::new(
            workers,
            tokenizer,
            max_best_of,
            max_stop_sequence,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
        );

        let max_new_tokens = 10;
        match validation
            .validate_input("Hello".to_string(), None, Some(max_new_tokens))
            .await
        {
            Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (),
            _ => panic!("Unexpected not max new tokens"),
        }
    }

    // `best_of > 1` without any sampling parameter set must be rejected.
    #[tokio::test]
    async fn test_validation_best_of_sampling() {
        let tokenizer = Some(get_tokenizer().await);
        let max_best_of = 2;
        let max_stop_sequence = 3;
        let max_top_n_tokens = 4;
        let max_input_length = 5;
        let max_total_tokens = 6;
        let workers = 1;
        let validation = Validation::new(
            workers,
            tokenizer,
            max_best_of,
            max_stop_sequence,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
        );
        match validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    best_of: Some(2),
                    do_sample: false,
                    ..default_parameters()
                },
            })
            .await
        {
            Err(ValidationError::BestOfSampling) => (),
            _ => panic!("Unexpected not best of sampling"),
        }
    }

    // `top_p` must lie strictly inside (0, 1): 1.0 is rejected, 0.99 is
    // accepted, and `None` resolves to the proto default of 1.0.
    #[tokio::test]
    async fn test_validation_top_p() {
        let tokenizer = Some(get_tokenizer().await);
        let max_best_of = 2;
        let max_stop_sequence = 3;
        let max_top_n_tokens = 4;
        let max_input_length = 5;
        let max_total_tokens = 6;
        let workers = 1;
        let validation = Validation::new(
            workers,
            tokenizer,
            max_best_of,
            max_stop_sequence,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
        );
        match validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_p: Some(1.0),
                    ..default_parameters()
                },
            })
            .await
        {
            Err(ValidationError::TopP) => (),
            _ => panic!("Unexpected top_p"),
        }

        match validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_p: Some(0.99),
                    ..default_parameters()
                },
            })
            .await
        {
            Ok(_) => (),
            _ => panic!("Unexpected top_p error"),
        }

        let valid_request = validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_p: None,
                    ..default_parameters()
                },
            })
            .await
            .unwrap();

        // top_p == 1.0 is invalid for users to ask for but it's the default resolved value.
        assert_eq!(valid_request.parameters.top_p, 1.0);
    }

    // `top_n_tokens` may be 0..=max_top_n_tokens; above the limit is an
    // error and `None` resolves to 0 (reporting disabled).
    #[tokio::test]
    async fn test_validation_top_n_tokens() {
        let tokenizer = Some(get_tokenizer().await);
        let max_best_of = 2;
        let max_stop_sequences = 3;
        let max_top_n_tokens = 4;
        let max_input_length = 5;
        let max_total_tokens = 6;
        let workers = 1;
        let validation = Validation::new(
            workers,
            tokenizer,
            max_best_of,
            max_stop_sequences,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
        );
        match validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_n_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
        {
            Err(ValidationError::TopNTokens(4, 5)) => (),
            _ => panic!("Unexpected top_n_tokens"),
        }

        validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_n_tokens: Some(4),
                    ..default_parameters()
                },
            })
            .await
            .unwrap();

        validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_n_tokens: Some(0),
                    ..default_parameters()
                },
            })
            .await
            .unwrap();

        let valid_request = validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_n_tokens: None,
                    ..default_parameters()
                },
            })
            .await
            .unwrap();

        assert_eq!(valid_request.top_n_tokens, 0);
    }
}