validation.rs 19.7 KB
Newer Older
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
1
/// Payload validation logic
2
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
3
use crate::{GenerateParameters, GenerateRequest};
4
use rand::{thread_rng, Rng};
5
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
Olivier Dehaene's avatar
Olivier Dehaene committed
6
use thiserror::Error;
Olivier Dehaene's avatar
Olivier Dehaene committed
7
use tokenizers::tokenizer::Tokenizer;
8
use tokenizers::TruncationDirection;
9
use tokio::sync::oneshot;
10
use tracing::{instrument, Span};
Olivier Dehaene's avatar
Olivier Dehaene committed
11

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
12
/// Validation
///
/// Validates incoming generate requests against server-wide limits and
/// (optionally) tokenizes inputs via background workers.
#[derive(Debug, Clone)]
pub struct Validation {
    /// Validation parameters
    // Maximum allowed value for the `best_of` parameter
    max_best_of: usize,
    // Maximum number of stop sequences accepted per request
    max_stop_sequences: usize,
    // Maximum allowed value for the `top_n_tokens` parameter
    max_top_n_tokens: u32,
    // Maximum number of tokens allowed in the prompt
    max_input_length: usize,
    // Maximum prompt + generated tokens per request
    max_total_tokens: usize,
    /// Channel to communicate with the background tokenization task
    // `None` when no tokenizer was provided; validation then estimates lengths.
    sender: Option<flume::Sender<TokenizerRequest>>,
}

impl Validation {
26
27
    pub(crate) fn new(
        workers: usize,
28
        tokenizer: Option<Tokenizer>,
29
        max_best_of: usize,
30
        max_stop_sequences: usize,
Nicolas Patry's avatar
Nicolas Patry committed
31
        max_top_n_tokens: u32,
32
33
34
        max_input_length: usize,
        max_total_tokens: usize,
    ) -> Self {
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
        // If we have a fast tokenizer
        let sender = if let Some(tokenizer) = tokenizer {
            // Create channel
            let (validation_sender, validation_receiver) = flume::unbounded();

            // Create workers
            for _ in 0..workers {
                let tokenizer_clone = tokenizer.clone();
                let receiver_clone = validation_receiver.clone();

                // Spawn worker
                tokio::task::spawn_blocking(move || {
                    tokenizer_worker(tokenizer_clone, receiver_clone)
                });
            }
            Some(validation_sender)
        } else {
            None
        };

        Self {
            max_best_of,
            sender,
58
            max_stop_sequences,
Nicolas Patry's avatar
Nicolas Patry committed
59
            max_top_n_tokens,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
60
            max_input_length,
61
            max_total_tokens,
62
63
        }
    }
Olivier Dehaene's avatar
Olivier Dehaene committed
64

65
66
67
68
69
    #[instrument(skip_all)]
    async fn validate_input(
        &self,
        inputs: String,
        truncate: Option<usize>,
70
71
        max_new_tokens: Option<u32>,
    ) -> Result<(String, usize, u32), ValidationError> {
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
        // If we have a fast tokenizer
        if let Some(sender) = &self.sender {
            // Create response channel
            let (response_sender, response_receiver) = oneshot::channel();
            // Send request to the background validation task
            // Unwrap is safe here
            sender
                .send(((inputs, truncate), response_sender, Span::current()))
                .unwrap();

            // Await on response channel
            // Unwrap is safe here
            let (inputs, input_length) = response_receiver.await.unwrap()?;

            // Get total tokens
87
88
89
90
91
            let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
                max_new_tokens
            } else {
                self.max_total_tokens.saturating_sub(input_length) as u32
            };
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
            let total_tokens = input_length + max_new_tokens as usize;

            // Validate MaxTotalTokens
            if total_tokens > self.max_total_tokens {
                return Err(ValidationError::MaxTotalTokens(
                    self.max_total_tokens,
                    input_length,
                    max_new_tokens,
                ));
            }

            // Validate InputLength
            if input_length > self.max_input_length {
                return Err(ValidationError::InputLength(
                    self.max_input_length,
                    input_length,
                ));
            }

            metrics::histogram!("tgi_request_input_length", input_length as f64);
112
            Ok((inputs, input_length, max_new_tokens))
113
114
115
116
117
118
        }
        // Return inputs without validation
        else {
            // In this case, we don't know the real length in tokens of the inputs
            // However, the inputs will be truncated by the python servers
            // We make sure that truncate + max_new_tokens <= self.max_total_tokens
119
            let input_length = truncate.unwrap_or(self.max_input_length);
120
121
122
123
124
            let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
                max_new_tokens
            } else {
                self.max_total_tokens.saturating_sub(input_length) as u32
            };
125
126

            // Validate MaxNewTokens
127
            if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
128
129
130
131
132
133
                return Err(ValidationError::MaxNewTokens(
                    self.max_total_tokens - self.max_input_length,
                    max_new_tokens,
                ));
            }

134
            Ok((inputs, input_length, max_new_tokens))
Olivier Dehaene's avatar
Olivier Dehaene committed
135
136
137
        }
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
138
    /// Validate a payload and get the number of tokens in the input
139
    #[instrument(skip_all)]
Olivier Dehaene's avatar
Olivier Dehaene committed
140
141
142
    pub(crate) async fn validate(
        &self,
        request: GenerateRequest,
143
    ) -> Result<ValidGenerateRequest, ValidationError> {
144
145
146
147
148
149
150
151
152
153
154
155
156
        let GenerateParameters {
            best_of,
            temperature,
            repetition_penalty,
            top_k,
            top_p,
            typical_p,
            do_sample,
            max_new_tokens,
            stop: stop_sequences,
            truncate,
            seed,
            watermark,
157
            decoder_input_details,
Nicolas Patry's avatar
Nicolas Patry committed
158
            top_n_tokens,
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
            ..
        } = request.parameters;

        // sampling must be true when best_of > 1
        let best_of = best_of.unwrap_or(1);
        let sampling = do_sample
            || temperature.is_some()
            || top_k.is_some()
            || top_p.is_some()
            || typical_p.is_some();

        if best_of > 1 && !sampling {
            return Err(BestOfSampling);
        }

        let temperature = temperature.unwrap_or(1.0);
        if temperature <= 0.0 {
            return Err(ValidationError::Temperature);
        }

        let repetition_penalty = repetition_penalty.unwrap_or(1.0);
        if repetition_penalty <= 0.0 {
            return Err(ValidationError::RepetitionPenalty);
        }

        // Different because the proto default value is not a valid value
        // for the user
        let top_p = top_p
            .map(|value| {
                if value <= 0.0 || value >= 1.0 {
                    return Err(ValidationError::TopP);
                }
                Ok(value)
            })
            .unwrap_or(Ok(1.0))?;

        let typical_p = typical_p
            .map(|value| {
                if value <= 0.0 || value >= 1.0 {
                    return Err(ValidationError::TypicalP);
                }
                Ok(value)
            })
            .unwrap_or(Ok(1.0))?;

        let top_k: u32 = top_k
            .map(|value| {
                if value <= 0 {
                    return Err(ValidationError::TopK);
                }
                Ok(value as u32)
            })
            .unwrap_or(Ok(0))?;

213
        if max_new_tokens == Some(0) {
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
            return Err(ValidationError::NegativeMaxNewTokens);
        }

        if stop_sequences.len() > self.max_stop_sequences {
            return Err(ValidationError::StopSequence(
                self.max_stop_sequences,
                stop_sequences.len(),
            ));
        }

        // If seed is None, assign a random one
        let seed = match seed {
            None => thread_rng().gen(),
            Some(seed) => {
                if best_of > 1 {
                    return Err(BestOfSeed);
                }
                seed
            }
        };

Nicolas Patry's avatar
Nicolas Patry committed
235
236
237
238
239
240
241
242
243
        let top_n_tokens = top_n_tokens
            .map(|value| {
                if value > self.max_top_n_tokens {
                    return Err(ValidationError::TopNTokens(self.max_top_n_tokens, value));
                }
                Ok(value)
            })
            .unwrap_or(Ok(0))?;

244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
        // Check if inputs is empty
        if request.inputs.is_empty() {
            return Err(EmptyInput);
        }

        // Check if truncate is strictly positive and less than max_input_length
        let truncate = truncate
            .map(|value| {
                if value == 0 || value > self.max_input_length {
                    return Err(ValidationError::Truncate(self.max_input_length, value));
                }
                Ok(Some(value))
            })
            .unwrap_or(Ok(None))?;

        // Validate inputs
260
        let (inputs, input_length, max_new_tokens) = self
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
            .validate_input(request.inputs, truncate, max_new_tokens)
            .await?;

        let parameters = NextTokenChooserParameters {
            temperature,
            repetition_penalty,
            top_k,
            top_p,
            typical_p,
            do_sample,
            seed,
            watermark,
        };
        let stopping_parameters = StoppingCriteriaParameters {
            max_new_tokens,
            stop_sequences,
            ignore_eos_token: false,
        };

        metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64);

        Ok(ValidGenerateRequest {
            inputs,
284
            decoder_input_details,
285
            input_length: input_length as u32,
286
287
288
            truncate: truncate.unwrap_or(self.max_input_length) as u32,
            parameters,
            stopping_parameters,
Nicolas Patry's avatar
Nicolas Patry committed
289
            top_n_tokens,
290
        })
Olivier Dehaene's avatar
Olivier Dehaene committed
291
    }
292
293
294
295
296
297
298
299
300
301
302
303
304
305

    /// Validate the best_of parameter
    #[instrument(skip_all)]
    pub(crate) fn validate_best_of(&self, best_of: usize) -> Result<usize, ValidationError> {
        if self.max_best_of == 1 && best_of != 1 {
            return Err(ValidationError::BestOfDisabled);
        }

        if best_of > self.max_best_of {
            return Err(ValidationError::BestOf(self.max_best_of, best_of));
        }

        Ok(best_of)
    }
Olivier Dehaene's avatar
Olivier Dehaene committed
306
307
}

308
309
/// Start tokenization workers
fn tokenizer_worker(tokenizer: Tokenizer, receiver: flume::Receiver<TokenizerRequest>) {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
310
    // Loop over requests
311
    while let Ok(((inputs, truncate), response_tx, parent_span)) = receiver.recv() {
312
313
        parent_span.in_scope(|| {
            response_tx
314
                .send(prepare_input(inputs, truncate, &tokenizer))
315
316
                .unwrap_or(())
        })
317
318
    }
}
Olivier Dehaene's avatar
Olivier Dehaene committed
319

320
321
322
323
/// Get input length and optionally truncate it
fn prepare_input(
    inputs: String,
    truncate: Option<usize>,
324
    tokenizer: &Tokenizer,
325
) -> Result<(String, usize), ValidationError> {
326
327
    // Get the number of tokens in the input
    let mut encoding = tokenizer
328
        .encode(inputs.clone(), true)
329
330
        .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;

331
332
    // Optionally truncate
    let (inputs, input_length) = match truncate {
333
334
        // Truncate is some and < encoding length
        Some(truncate) if truncate < encoding.len() => {
335
336
337
            // truncate encoding and decode new inputs
            encoding.truncate(truncate, 0, TruncationDirection::Left);
            let inputs = tokenizer
338
                .decode(encoding.get_ids(), false)
339
340
341
342
343
                .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
            (inputs, encoding.len())
        }
        // Nothing to do
        _ => (inputs, encoding.len()),
344
345
    };

346
    Ok((inputs, input_length))
Olivier Dehaene's avatar
Olivier Dehaene committed
347
}
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
348

349
350
351
/// Request handled by a tokenizer worker:
/// ((inputs, optional truncation length), response channel, caller's tracing span).
type TokenizerRequest = (
    (String, Option<usize>),
    oneshot::Sender<Result<(String, usize), ValidationError>>,
    Span,
);

355
356
357
/// A generate request that passed validation, with all defaults resolved.
#[derive(Debug)]
pub(crate) struct ValidGenerateRequest {
    // Possibly-truncated prompt text
    pub inputs: String,
    // Prompt length in tokens (estimated when no tokenizer is available)
    pub input_length: u32,
    // Truncation length forwarded to the model servers
    pub truncate: u32,
    // Whether to return logprob details for the prompt tokens
    pub decoder_input_details: bool,
    pub parameters: NextTokenChooserParameters,
    pub stopping_parameters: StoppingCriteriaParameters,
    // Number of top tokens to return per generated position
    pub top_n_tokens: u32,
}

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
366
367
/// Errors returned when a generate request fails validation.
/// The `#[error]` strings are user-facing API messages.
#[derive(Error, Debug)]
pub enum ValidationError {
    #[error("`best_of` must be > 0 and <= {0}. Given: {1}")]
    BestOf(usize, usize),
    #[error("`best_of` != 1 is not allowed for this endpoint")]
    BestOfDisabled,
    #[error("you must use sampling when `best_of` is > 1")]
    BestOfSampling,
    #[error("`seed` must not be set when `best_of` > 1")]
    BestOfSeed,
    #[error("`best_of` != 1 is not supported when streaming tokens")]
    BestOfStream,
    #[error("`top_n_tokens` must be >= 0 and <= {0}. Given: {1}")]
    TopNTokens(u32, u32),
    #[error("`top_n_tokens` != 0 is not allowed for this endpoint")]
    TopNTokensDisabled,
    #[error("`decoder_input_details` == true is not supported when streaming tokens")]
    PrefillDetailsStream,
    #[error("`temperature` must be strictly positive")]
    Temperature,
    #[error("`repetition_penalty` must be strictly positive")]
    RepetitionPenalty,
    #[error("`top_p` must be > 0.0 and < 1.0")]
    TopP,
    #[error("`top_k` must be strictly positive")]
    TopK,
    #[error("`truncate` must be strictly positive and less than {0}. Given: {1}")]
    Truncate(usize, usize),
    #[error("`typical_p` must be > 0.0 and < 1.0")]
    TypicalP,
    #[error("`max_new_tokens` must be strictly positive")]
    NegativeMaxNewTokens,
    #[error("`max_new_tokens` must be <= {0}. Given: {1}")]
    MaxNewTokens(usize, u32),
    #[error("`inputs` tokens + `max_new_tokens` must be <= {0}. Given: {1} `inputs` tokens and {2} `max_new_tokens`")]
    MaxTotalTokens(usize, usize, u32),
    #[error("`inputs` must have less than {0} tokens. Given: {1}")]
    InputLength(usize, usize),
    #[error("`inputs` cannot be empty")]
    EmptyInput,
    #[error("`stop` supports up to {0} stop sequences. Given: {1}")]
    StopSequence(usize, usize),
    #[error("tokenizer error {0}")]
    Tokenizer(String),
}
411
412

#[cfg(test)]
mod tests {
    use super::*;
    use crate::default_parameters;
    use crate::tests::get_tokenizer;

    /// Build a `Validation` with the small limits shared by every test:
    /// one worker, best_of <= 2, 3 stop sequences, top_n_tokens <= 4,
    /// 5 input tokens, 6 total tokens.
    fn build_validation(tokenizer: Option<Tokenizer>) -> Validation {
        let workers = 1;
        let max_best_of = 2;
        let max_stop_sequences = 3;
        let max_top_n_tokens = 4;
        let max_input_length = 5;
        let max_total_tokens = 6;
        Validation::new(
            workers,
            tokenizer,
            max_best_of,
            max_stop_sequences,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
        )
    }

    #[tokio::test]
    async fn test_validation_max_new_tokens() {
        // No tokenizer: input length is estimated from `max_input_length`.
        let validation = build_validation(None);

        let max_new_tokens = 10;
        match validation
            .validate_input("Hello".to_string(), None, Some(max_new_tokens))
            .await
        {
            Err(ValidationError::MaxNewTokens(1, 10)) => (),
            _ => panic!("Unexpected not max new tokens"),
        }
    }

    #[tokio::test]
    async fn test_validation_input_length() {
        // With a real tokenizer, exceeding the total budget is reported
        // as MaxTotalTokens with the actual prompt length.
        let validation = build_validation(Some(get_tokenizer().await));

        let max_new_tokens = 10;
        match validation
            .validate_input("Hello".to_string(), None, Some(max_new_tokens))
            .await
        {
            Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (),
            _ => panic!("Unexpected not max new tokens"),
        }
    }

    #[tokio::test]
    async fn test_validation_best_of_sampling() {
        // best_of > 1 without any sampling parameter must be rejected.
        let validation = build_validation(Some(get_tokenizer().await));
        match validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    best_of: Some(2),
                    do_sample: false,
                    ..default_parameters()
                },
            })
            .await
        {
            Err(ValidationError::BestOfSampling) => (),
            _ => panic!("Unexpected not best of sampling"),
        }
    }

    #[tokio::test]
    async fn test_validation_top_p() {
        let validation = build_validation(Some(get_tokenizer().await));

        // top_p == 1.0 is rejected when explicitly requested...
        match validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_p: Some(1.0),
                    ..default_parameters()
                },
            })
            .await
        {
            Err(ValidationError::TopP) => (),
            _ => panic!("Unexpected top_p"),
        }

        // ...while any value strictly inside (0, 1) is accepted.
        match validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_p: Some(0.99),
                    ..default_parameters()
                },
            })
            .await
        {
            Ok(_) => (),
            _ => panic!("Unexpected top_p error"),
        }

        let valid_request = validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_p: None,
                    ..default_parameters()
                },
            })
            .await
            .unwrap();
        // top_p == 1.0 is invalid for users to ask for but it's the default resolved value.
        assert_eq!(valid_request.parameters.top_p, 1.0);
    }

    #[tokio::test]
    async fn test_validation_top_n_tokens() {
        let validation = build_validation(Some(get_tokenizer().await));

        // Above the configured maximum: rejected with both values reported.
        match validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_n_tokens: Some(5),
                    ..default_parameters()
                },
            })
            .await
        {
            Err(ValidationError::TopNTokens(4, 5)) => (),
            _ => panic!("Unexpected top_n_tokens"),
        }

        // Exactly at the maximum: accepted.
        validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_n_tokens: Some(4),
                    ..default_parameters()
                },
            })
            .await
            .unwrap();

        // Zero: accepted (feature disabled for the request).
        validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_n_tokens: Some(0),
                    ..default_parameters()
                },
            })
            .await
            .unwrap();

        // Unset: resolves to the default of 0.
        let valid_request = validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_n_tokens: None,
                    ..default_parameters()
                },
            })
            .await
            .unwrap();

        assert_eq!(valid_request.top_n_tokens, 0);
    }
}