use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
/// Payload validation logic
use crate::{GenerateParameters, GenerateRequest};
use rand::rngs::ThreadRng;
use rand::Rng;
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
use thiserror::Error;
use tokenizers::tokenizer::Tokenizer;
use tokenizers::TruncationDirection;
use tokio::sync::{mpsc, oneshot};
use tracing::{instrument, Span};

/// Validation
Olivier Dehaene's avatar
Olivier Dehaene committed
14
#[derive(Debug, Clone)]
Olivier Dehaene's avatar
Olivier Dehaene committed
15
pub struct Validation {
16
17
18
    /// maximum value for the best_of parameter
    #[allow(dead_code)]
    max_best_of: usize,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
19
    /// Channel to communicate with the background validation task
20
    sender: mpsc::UnboundedSender<ValidationRequest>,
Olivier Dehaene's avatar
Olivier Dehaene committed
21
22
23
}

impl Validation {
24
25
26
    pub(crate) fn new(
        workers: usize,
        tokenizer: Tokenizer,
27
        max_best_of: usize,
28
29
30
31
        max_stop_sequences: usize,
        max_input_length: usize,
        max_total_tokens: usize,
    ) -> Self {
32
        // Create channel
33
        let (validation_sender, validation_receiver) = mpsc::unbounded_channel();
Olivier Dehaene's avatar
Olivier Dehaene committed
34

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
35
36
37
38
        // Launch background validation task
        tokio::spawn(validation_task(
            workers,
            tokenizer,
39
            max_stop_sequences,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
40
            max_input_length,
41
            max_total_tokens,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
42
43
            validation_receiver,
        ));
Olivier Dehaene's avatar
Olivier Dehaene committed
44
45

        Self {
46
            max_best_of,
Olivier Dehaene's avatar
Olivier Dehaene committed
47
48
49
50
            sender: validation_sender,
        }
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
51
    /// Validate a payload and get the number of tokens in the input
52
    #[instrument(skip_all)]
Olivier Dehaene's avatar
Olivier Dehaene committed
53
54
55
    pub(crate) async fn validate(
        &self,
        request: GenerateRequest,
56
    ) -> Result<ValidGenerateRequest, ValidationError> {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
57
        // Create response channel
Olivier Dehaene's avatar
Olivier Dehaene committed
58
        let (sender, receiver) = oneshot::channel();
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
59
60
        // Send request to the background validation task
        // Unwrap is safe here
61
62
63
        self.sender
            .send((request, sender, Span::current()))
            .unwrap();
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
64
65
        // Await on response channel
        // Unwrap is safe here
Olivier Dehaene's avatar
Olivier Dehaene committed
66
67
        receiver.await.unwrap()
    }
68
69
70
71
72
73
74
75
76
77
78
79
80
81

    /// Validate the best_of parameter
    #[instrument(skip_all)]
    pub(crate) fn validate_best_of(&self, best_of: usize) -> Result<usize, ValidationError> {
        if self.max_best_of == 1 && best_of != 1 {
            return Err(ValidationError::BestOfDisabled);
        }

        if best_of > self.max_best_of {
            return Err(ValidationError::BestOf(self.max_best_of, best_of));
        }

        Ok(best_of)
    }
Olivier Dehaene's avatar
Olivier Dehaene committed
82
83
}

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
84
85
86
87
88
/// Validation task
/// Load balance the validation requests between multiple validation workers
async fn validation_task(
    workers: usize,
    tokenizer: Tokenizer,
89
    max_stop_sequences: usize,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
90
    max_input_length: usize,
91
    max_total_tokens: usize,
92
    mut receiver: mpsc::UnboundedReceiver<ValidationRequest>,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
93
94
95
96
97
) {
    let mut workers_senders = Vec::with_capacity(workers);

    // Create workers
    for _ in 0..workers {
98
        let tokenizer_clone: Tokenizer = tokenizer.clone().into();
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
99
100
101
102
103
104
        // Create channel to communicate with worker
        let (worker_sender, worker_receiver) = mpsc::channel(workers);
        workers_senders.push(worker_sender);

        // Spawn worker
        tokio::task::spawn_blocking(move || {
105
106
107
108
109
110
111
            validation_worker(
                tokenizer_clone,
                max_stop_sequences,
                max_input_length,
                max_total_tokens,
                worker_receiver,
            )
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
        });
    }

    loop {
        // Load balance requests between workers
        for sender in workers_senders.iter() {
            if let Some(validation_request) = receiver.recv().await {
                sender.send(validation_request).await.unwrap();
            } else {
                return;
            }
        }
    }
}

/// Check the parameters inside the payload and get the number of tokens inside the input using
/// the tokenizer
fn validation_worker(
130
    tokenizer: Tokenizer,
131
    max_stop_sequences: usize,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
132
    max_input_length: usize,
133
    max_total_tokens: usize,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
134
135
    mut receiver: mpsc::Receiver<ValidationRequest>,
) {
136
137
138
    // Seed rng
    let mut rng = rand::thread_rng();

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
139
    // Loop over requests
140
141
142
143
    while let Some((request, response_tx, parent_span)) = receiver.blocking_recv() {
        parent_span.in_scope(|| {
            response_tx
                .send(
144
145
146
147
148
149
150
151
152
                    validate(
                        request,
                        &tokenizer,
                        max_stop_sequences,
                        max_input_length,
                        max_total_tokens,
                        &mut rng,
                    )
                    .map_err(|err| {
153
                        metrics::increment_counter!("tgi_request_failure", "err" => "validation");
154
155
156
157
158
159
                        tracing::error!("{err}");
                        err
                    }),
                )
                .unwrap_or(())
        })
160
161
    }
}
Olivier Dehaene's avatar
Olivier Dehaene committed
162

163
fn validate(
164
    request: GenerateRequest,
165
    tokenizer: &Tokenizer,
166
    max_stop_sequences: usize,
167
    max_input_length: usize,
168
    max_total_tokens: usize,
169
    rng: &mut ThreadRng,
170
) -> Result<ValidGenerateRequest, ValidationError> {
171
    let GenerateParameters {
172
        best_of,
173
174
175
176
        temperature,
        repetition_penalty,
        top_k,
        top_p,
177
        typical_p,
178
179
180
        do_sample,
        max_new_tokens,
        stop: stop_sequences,
181
        truncate,
182
        seed,
183
        watermark,
184
185
186
        ..
    } = request.parameters;

187
188
189
190
191
192
193
194
195
196
197
198
    // sampling must be true when best_of > 1
    let best_of = best_of.unwrap_or(1);
    let sampling = do_sample
        || temperature.is_some()
        || top_k.is_some()
        || top_p.is_some()
        || typical_p.is_some();

    if best_of > 1 && !sampling {
        return Err(BestOfSampling);
    }

199
200
    let temperature = temperature.unwrap_or(1.0);
    if temperature <= 0.0 {
201
202
        return Err(ValidationError::Temperature);
    }
203
204
205

    let repetition_penalty = repetition_penalty.unwrap_or(1.0);
    if repetition_penalty <= 0.0 {
206
207
        return Err(ValidationError::RepetitionPenalty);
    }
208

209
    // Different because the proto default value is not a valid value
210
    // for the user
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
    let top_p = top_p
        .map(|value| {
            if value <= 0.0 || value >= 1.0 {
                return Err(ValidationError::TopP);
            }
            Ok(value)
        })
        .unwrap_or(Ok(1.0))?;

    let typical_p = typical_p
        .map(|value| {
            if value <= 0.0 || value >= 1.0 {
                return Err(ValidationError::TypicalP);
            }
            Ok(value)
        })
        .unwrap_or(Ok(1.0))?;

    let top_k: u32 = top_k
        .map(|value| {
            if value <= 0 {
232
233
                return Err(ValidationError::TopK);
            }
234
235
236
            Ok(value as u32)
        })
        .unwrap_or(Ok(0))?;
237

238
239
    if max_new_tokens == 0 {
        return Err(ValidationError::MaxNewTokens);
240
    }
241

242
    if stop_sequences.len() > max_stop_sequences {
243
        return Err(ValidationError::StopSequence(
244
            max_stop_sequences,
245
            stop_sequences.len(),
246
        ));
247
248
    }

249
    // If seed is None, assign a random one
250
    let seed = match seed {
251
        None => rng.gen(),
252
253
254
255
256
257
        Some(seed) => {
            if best_of > 1 {
                return Err(BestOfSeed);
            }
            seed
        }
258
    };
259

260
261
262
263
264
    // Check if inputs is empty
    if request.inputs.is_empty() {
        return Err(EmptyInput);
    }

265
266
267
268
269
    // Check if truncate is strictly positive and less than max_input_length
    let truncate = truncate
        .map(|value| {
            if value == 0 || value > max_input_length {
                return Err(ValidationError::Truncate(max_input_length, value));
270
            }
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
            Ok(Some(value))
        })
        .unwrap_or(Ok(None))?;

    // Get the number of tokens in the input
    let mut encoding = tokenizer
        .encode(request.inputs.clone(), true)
        .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;

    let (inputs, input_length) = if let Some(truncate) = truncate {
        // truncate encoding and decode new inputs
        encoding.truncate(truncate, 0, TruncationDirection::Left);
        let inputs = tokenizer
            .decode(Vec::from(encoding.get_ids()), false)
            .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
        (inputs, encoding.len())
    } else {
        (request.inputs, encoding.len())
    };

    if input_length > max_input_length {
        return Err(ValidationError::InputLength(max_input_length, input_length));
    }

    let total_tokens = input_length + max_new_tokens as usize;
    if total_tokens > max_total_tokens {
        return Err(ValidationError::MaxTotalTokens(
            max_total_tokens,
            input_length,
            max_new_tokens,
        ));
Olivier Dehaene's avatar
Olivier Dehaene committed
302
    }
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328

    // Return ValidGenerateRequest
    let parameters = NextTokenChooserParameters {
        temperature,
        repetition_penalty,
        top_k,
        top_p,
        typical_p,
        do_sample,
        seed,
        watermark,
    };
    let stopping_parameters = StoppingCriteriaParameters {
        max_new_tokens,
        stop_sequences,
    };

    metrics::histogram!("tgi_request_input_length", input_length as f64);
    metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64);

    Ok(ValidGenerateRequest {
        inputs,
        input_length: input_length as u32,
        parameters,
        stopping_parameters,
    })
Olivier Dehaene's avatar
Olivier Dehaene committed
329
}
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
330
331
332

type ValidationRequest = (
    GenerateRequest,
333
    oneshot::Sender<Result<ValidGenerateRequest, ValidationError>>,
334
    Span,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
335
336
);

337
338
339
340
341
342
343
344
/// A generation request that passed validation.
#[derive(Debug)]
pub(crate) struct ValidGenerateRequest {
    /// Input text; re-decoded from the truncated token ids when `truncate` was set
    pub inputs: String,
    /// Number of tokens in the (possibly truncated) input encoding
    pub input_length: u32,
    /// Validated sampling parameters
    pub parameters: NextTokenChooserParameters,
    /// Validated `max_new_tokens` and stop sequences
    pub stopping_parameters: StoppingCriteriaParameters,
}

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
345
346
#[derive(Error, Debug)]
pub enum ValidationError {
347
348
349
350
351
352
353
354
355
356
    #[error("`best_of` must be > 0 and <= {0}. Given: {1}")]
    BestOf(usize, usize),
    #[error("`best_of` != 1 is not allowed for this endpoint")]
    BestOfDisabled,
    #[error("you must use sampling when `best_of` is > 1")]
    BestOfSampling,
    #[error("`seed` must not be set when `best_of` > 1")]
    BestOfSeed,
    #[error("`best_of` != 1 is not supported when streaming tokens")]
    BestOfStream,
357
    #[error("`temperature` must be strictly positive")]
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
358
    Temperature,
359
    #[error("`repetition_penalty` must be strictly positive")]
360
    RepetitionPenalty,
361
    #[error("`top_p` must be > 0.0 and < 1.0")]
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
362
    TopP,
363
    #[error("`top_k` must be strictly positive")]
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
364
    TopK,
365
366
    #[error("`truncate` must be strictly positive and less than {0}. Given: {1}")]
    Truncate(usize, usize),
367
368
    #[error("`typical_p` must be > 0.0 and < 1.0")]
    TypicalP,
369
    #[error("`max_new_tokens` must be strictly positive")]
370
    MaxNewTokens,
371
    #[error("`inputs` tokens + `max_new_tokens` must be <= {0}. Given: {1} `inputs` tokens and {2} `max_new_tokens`")]
372
    MaxTotalTokens(usize, usize, u32),
373
    #[error("`inputs` must have less than {0} tokens. Given: {1}")]
374
    InputLength(usize, usize),
375
    #[error("`inputs` cannot be empty")]
376
    EmptyInput,
377
    #[error("`stop` supports up to {0} stop sequences. Given: {1}")]
378
    StopSequence(usize, usize),
379
380
    #[error("tokenizer error {0}")]
    Tokenizer(String),
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
381
}