validation.rs 9.63 KB
Newer Older
1
use crate::validation::ValidationError::EmptyInput;
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
2
/// Payload validation logic
3
use crate::{GenerateParameters, GenerateRequest};
4
5
use rand::rngs::ThreadRng;
use rand::Rng;
6
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
Olivier Dehaene's avatar
Olivier Dehaene committed
7
use thiserror::Error;
Olivier Dehaene's avatar
Olivier Dehaene committed
8
use tokenizers::tokenizer::Tokenizer;
9
use tokenizers::TruncationDirection;
Olivier Dehaene's avatar
Olivier Dehaene committed
10
use tokio::sync::{mpsc, oneshot};
11
use tracing::{instrument, Span};
Olivier Dehaene's avatar
Olivier Dehaene committed
12

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
13
/// Validation
Olivier Dehaene's avatar
Olivier Dehaene committed
14
#[derive(Debug, Clone)]
Olivier Dehaene's avatar
Olivier Dehaene committed
15
pub struct Validation {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
16
    /// Channel to communicate with the background validation task
17
    sender: mpsc::UnboundedSender<ValidationRequest>,
Olivier Dehaene's avatar
Olivier Dehaene committed
18
19
20
}

impl Validation {
21
22
23
24
25
26
27
    pub(crate) fn new(
        workers: usize,
        tokenizer: Tokenizer,
        max_stop_sequences: usize,
        max_input_length: usize,
        max_total_tokens: usize,
    ) -> Self {
28
        // Create channel
29
        let (validation_sender, validation_receiver) = mpsc::unbounded_channel();
Olivier Dehaene's avatar
Olivier Dehaene committed
30

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
31
32
33
34
        // Launch background validation task
        tokio::spawn(validation_task(
            workers,
            tokenizer,
35
            max_stop_sequences,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
36
            max_input_length,
37
            max_total_tokens,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
38
39
            validation_receiver,
        ));
Olivier Dehaene's avatar
Olivier Dehaene committed
40
41
42
43
44
45

        Self {
            sender: validation_sender,
        }
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
46
    /// Validate a payload and get the number of tokens in the input
47
    #[instrument(skip_all)]
Olivier Dehaene's avatar
Olivier Dehaene committed
48
49
50
    pub(crate) async fn validate(
        &self,
        request: GenerateRequest,
51
    ) -> Result<ValidGenerateRequest, ValidationError> {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
52
        // Create response channel
Olivier Dehaene's avatar
Olivier Dehaene committed
53
        let (sender, receiver) = oneshot::channel();
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
54
55
        // Send request to the background validation task
        // Unwrap is safe here
56
57
58
        self.sender
            .send((request, sender, Span::current()))
            .unwrap();
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
59
60
        // Await on response channel
        // Unwrap is safe here
Olivier Dehaene's avatar
Olivier Dehaene committed
61
62
63
64
        receiver.await.unwrap()
    }
}

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
65
66
67
68
69
/// Validation task
/// Load balance the validation requests between multiple validation workers
async fn validation_task(
    workers: usize,
    tokenizer: Tokenizer,
70
    max_stop_sequences: usize,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
71
    max_input_length: usize,
72
    max_total_tokens: usize,
73
    mut receiver: mpsc::UnboundedReceiver<ValidationRequest>,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
74
75
76
77
78
) {
    let mut workers_senders = Vec::with_capacity(workers);

    // Create workers
    for _ in 0..workers {
79
        let tokenizer_clone: Tokenizer = tokenizer.clone().into();
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
80
81
82
83
84
85
        // Create channel to communicate with worker
        let (worker_sender, worker_receiver) = mpsc::channel(workers);
        workers_senders.push(worker_sender);

        // Spawn worker
        tokio::task::spawn_blocking(move || {
86
87
88
89
90
91
92
            validation_worker(
                tokenizer_clone,
                max_stop_sequences,
                max_input_length,
                max_total_tokens,
                worker_receiver,
            )
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
        });
    }

    loop {
        // Load balance requests between workers
        for sender in workers_senders.iter() {
            if let Some(validation_request) = receiver.recv().await {
                sender.send(validation_request).await.unwrap();
            } else {
                return;
            }
        }
    }
}

/// Check the parameters inside the payload and get the number of tokens inside the input using
/// the tokenizer
fn validation_worker(
111
    tokenizer: Tokenizer,
112
    max_stop_sequences: usize,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
113
    max_input_length: usize,
114
    max_total_tokens: usize,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
115
116
    mut receiver: mpsc::Receiver<ValidationRequest>,
) {
117
118
119
    // Seed rng
    let mut rng = rand::thread_rng();

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
120
    // Loop over requests
121
122
123
124
    while let Some((request, response_tx, parent_span)) = receiver.blocking_recv() {
        parent_span.in_scope(|| {
            response_tx
                .send(
125
126
127
128
129
130
131
132
133
                    validate(
                        request,
                        &tokenizer,
                        max_stop_sequences,
                        max_input_length,
                        max_total_tokens,
                        &mut rng,
                    )
                    .map_err(|err| {
134
                        metrics::increment_counter!("tgi_request_failure", "err" => "validation");
135
136
137
138
139
140
                        tracing::error!("{err}");
                        err
                    }),
                )
                .unwrap_or(())
        })
141
142
    }
}
Olivier Dehaene's avatar
Olivier Dehaene committed
143

144
fn validate(
145
    request: GenerateRequest,
146
    tokenizer: &Tokenizer,
147
    max_stop_sequences: usize,
148
    max_input_length: usize,
149
    max_total_tokens: usize,
150
    rng: &mut ThreadRng,
151
) -> Result<ValidGenerateRequest, ValidationError> {
152
153
154
155
156
    let GenerateParameters {
        temperature,
        repetition_penalty,
        top_k,
        top_p,
157
        typical_p,
158
159
160
        do_sample,
        max_new_tokens,
        stop: stop_sequences,
161
        truncate,
162
        seed,
163
        watermark,
164
165
166
167
168
        ..
    } = request.parameters;

    let temperature = temperature.unwrap_or(1.0);
    if temperature <= 0.0 {
169
170
        return Err(ValidationError::Temperature);
    }
171
172
173

    let repetition_penalty = repetition_penalty.unwrap_or(1.0);
    if repetition_penalty <= 0.0 {
174
175
        return Err(ValidationError::RepetitionPenalty);
    }
176

177
    // Different because the proto default value is not a valid value
178
    // for the user
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
    let top_p = top_p
        .map(|value| {
            if value <= 0.0 || value >= 1.0 {
                return Err(ValidationError::TopP);
            }
            Ok(value)
        })
        .unwrap_or(Ok(1.0))?;

    let typical_p = typical_p
        .map(|value| {
            if value <= 0.0 || value >= 1.0 {
                return Err(ValidationError::TypicalP);
            }
            Ok(value)
        })
        .unwrap_or(Ok(1.0))?;

    let top_k: u32 = top_k
        .map(|value| {
            if value <= 0 {
200
201
                return Err(ValidationError::TopK);
            }
202
203
204
            Ok(value as u32)
        })
        .unwrap_or(Ok(0))?;
205

206
207
    if max_new_tokens == 0 {
        return Err(ValidationError::MaxNewTokens);
208
    }
209

210
    if stop_sequences.len() > max_stop_sequences {
211
        return Err(ValidationError::StopSequence(
212
            max_stop_sequences,
213
            stop_sequences.len(),
214
        ));
215
216
    }

217
    // If seed is None, assign a random one
218
    let seed = match seed {
219
220
221
        None => rng.gen(),
        Some(seed) => seed,
    };
222

223
224
225
226
227
    // Check if inputs is empty
    if request.inputs.is_empty() {
        return Err(EmptyInput);
    }

228
229
230
231
232
    // Check if truncate is strictly positive and less than max_input_length
    let truncate = truncate
        .map(|value| {
            if value == 0 || value > max_input_length {
                return Err(ValidationError::Truncate(max_input_length, value));
233
            }
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
            Ok(Some(value))
        })
        .unwrap_or(Ok(None))?;

    // Get the number of tokens in the input
    let mut encoding = tokenizer
        .encode(request.inputs.clone(), true)
        .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;

    let (inputs, input_length) = if let Some(truncate) = truncate {
        // truncate encoding and decode new inputs
        encoding.truncate(truncate, 0, TruncationDirection::Left);
        let inputs = tokenizer
            .decode(Vec::from(encoding.get_ids()), false)
            .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
        (inputs, encoding.len())
    } else {
        (request.inputs, encoding.len())
    };

    if input_length > max_input_length {
        return Err(ValidationError::InputLength(max_input_length, input_length));
    }

    let total_tokens = input_length + max_new_tokens as usize;
    if total_tokens > max_total_tokens {
        return Err(ValidationError::MaxTotalTokens(
            max_total_tokens,
            input_length,
            max_new_tokens,
        ));
Olivier Dehaene's avatar
Olivier Dehaene committed
265
    }
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291

    // Return ValidGenerateRequest
    let parameters = NextTokenChooserParameters {
        temperature,
        repetition_penalty,
        top_k,
        top_p,
        typical_p,
        do_sample,
        seed,
        watermark,
    };
    let stopping_parameters = StoppingCriteriaParameters {
        max_new_tokens,
        stop_sequences,
    };

    metrics::histogram!("tgi_request_input_length", input_length as f64);
    metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64);

    Ok(ValidGenerateRequest {
        inputs,
        input_length: input_length as u32,
        parameters,
        stopping_parameters,
    })
Olivier Dehaene's avatar
Olivier Dehaene committed
292
}
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
293
294
295

type ValidationRequest = (
    GenerateRequest,
296
    oneshot::Sender<Result<ValidGenerateRequest, ValidationError>>,
297
    Span,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
298
299
);

300
301
302
303
304
305
306
307
/// A fully validated generation request, ready to be queued.
#[derive(Debug)]
pub(crate) struct ValidGenerateRequest {
    /// Input text, possibly re-decoded after truncation
    pub inputs: String,
    /// Number of tokens in `inputs` as measured by the tokenizer
    pub input_length: u32,
    /// Next-token sampling parameters (proto type from `text_generation_client`)
    pub parameters: NextTokenChooserParameters,
    /// Stopping criteria (proto type from `text_generation_client`)
    pub stopping_parameters: StoppingCriteriaParameters,
}

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
308
309
#[derive(Error, Debug)]
pub enum ValidationError {
310
    #[error("`temperature` must be strictly positive")]
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
311
    Temperature,
312
    #[error("`repetition_penalty` must be strictly positive")]
313
    RepetitionPenalty,
314
    #[error("`top_p` must be > 0.0 and < 1.0")]
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
315
    TopP,
316
    #[error("`top_k` must be strictly positive")]
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
317
    TopK,
318
319
    #[error("`truncate` must be strictly positive and less than {0}. Given: {1}")]
    Truncate(usize, usize),
320
321
    #[error("`typical_p` must be > 0.0 and < 1.0")]
    TypicalP,
322
    #[error("`max_new_tokens` must be strictly positive")]
323
    MaxNewTokens,
324
    #[error("`inputs` tokens + `max_new_tokens` must be <= {0}. Given: {1} `inputs` tokens and {2} `max_new_tokens`")]
325
    MaxTotalTokens(usize, usize, u32),
326
    #[error("`inputs` must have less than {0} tokens. Given: {1}")]
327
    InputLength(usize, usize),
328
    #[error("`inputs` cannot be empty")]
329
    EmptyInput,
330
    #[error("`stop` supports up to {0} stop sequences. Given: {1}")]
331
    StopSequence(usize, usize),
332
333
    #[error("tokenizer error {0}")]
    Tokenizer(String),
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
334
}