validation.rs 9.13 KB
Newer Older
1
use crate::validation::ValidationError::EmptyInput;
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
2
/// Payload validation logic
3
use crate::{GenerateParameters, GenerateRequest};
4
5
use rand::rngs::ThreadRng;
use rand::Rng;
6
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
Olivier Dehaene's avatar
Olivier Dehaene committed
7
use thiserror::Error;
Olivier Dehaene's avatar
Olivier Dehaene committed
8
9
use tokenizers::tokenizer::Tokenizer;
use tokio::sync::{mpsc, oneshot};
10
use tracing::{instrument, Span};
Olivier Dehaene's avatar
Olivier Dehaene committed
11

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
12
/// Validation
Olivier Dehaene's avatar
Olivier Dehaene committed
13
#[derive(Debug, Clone)]
Olivier Dehaene's avatar
Olivier Dehaene committed
14
pub struct Validation {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
15
    /// Channel to communicate with the background validation task
16
    sender: mpsc::UnboundedSender<ValidationRequest>,
Olivier Dehaene's avatar
Olivier Dehaene committed
17
18
19
}

impl Validation {
20
21
22
23
24
25
26
    pub(crate) fn new(
        workers: usize,
        tokenizer: Tokenizer,
        max_stop_sequences: usize,
        max_input_length: usize,
        max_total_tokens: usize,
    ) -> Self {
27
        // Create channel
28
        let (validation_sender, validation_receiver) = mpsc::unbounded_channel();
Olivier Dehaene's avatar
Olivier Dehaene committed
29

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
30
31
32
33
        // Launch background validation task
        tokio::spawn(validation_task(
            workers,
            tokenizer,
34
            max_stop_sequences,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
35
            max_input_length,
36
            max_total_tokens,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
37
38
            validation_receiver,
        ));
Olivier Dehaene's avatar
Olivier Dehaene committed
39
40
41
42
43
44

        Self {
            sender: validation_sender,
        }
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
45
    /// Validate a payload and get the number of tokens in the input
46
    #[instrument(skip_all)]
Olivier Dehaene's avatar
Olivier Dehaene committed
47
48
49
    pub(crate) async fn validate(
        &self,
        request: GenerateRequest,
50
    ) -> Result<ValidGenerateRequest, ValidationError> {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
51
        // Create response channel
Olivier Dehaene's avatar
Olivier Dehaene committed
52
        let (sender, receiver) = oneshot::channel();
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
53
54
        // Send request to the background validation task
        // Unwrap is safe here
55
56
57
        self.sender
            .send((request, sender, Span::current()))
            .unwrap();
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
58
59
        // Await on response channel
        // Unwrap is safe here
Olivier Dehaene's avatar
Olivier Dehaene committed
60
61
62
63
        receiver.await.unwrap()
    }
}

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
64
65
66
67
68
/// Validation task
/// Load balance the validation requests between multiple validation workers
async fn validation_task(
    workers: usize,
    tokenizer: Tokenizer,
69
    max_stop_sequences: usize,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
70
    max_input_length: usize,
71
    max_total_tokens: usize,
72
    mut receiver: mpsc::UnboundedReceiver<ValidationRequest>,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
73
74
75
76
77
) {
    let mut workers_senders = Vec::with_capacity(workers);

    // Create workers
    for _ in 0..workers {
78
        let tokenizer_clone: Tokenizer = tokenizer.clone().into();
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
79
80
81
82
83
84
        // Create channel to communicate with worker
        let (worker_sender, worker_receiver) = mpsc::channel(workers);
        workers_senders.push(worker_sender);

        // Spawn worker
        tokio::task::spawn_blocking(move || {
85
86
87
88
89
90
91
            validation_worker(
                tokenizer_clone,
                max_stop_sequences,
                max_input_length,
                max_total_tokens,
                worker_receiver,
            )
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
        });
    }

    loop {
        // Load balance requests between workers
        for sender in workers_senders.iter() {
            if let Some(validation_request) = receiver.recv().await {
                sender.send(validation_request).await.unwrap();
            } else {
                return;
            }
        }
    }
}

/// Check the parameters inside the payload and get the number of tokens inside the input using
/// the tokenizer
fn validation_worker(
110
    tokenizer: Tokenizer,
111
    max_stop_sequences: usize,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
112
    max_input_length: usize,
113
    max_total_tokens: usize,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
114
115
    mut receiver: mpsc::Receiver<ValidationRequest>,
) {
116
117
118
    // Seed rng
    let mut rng = rand::thread_rng();

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
119
    // Loop over requests
120
121
122
123
    while let Some((request, response_tx, parent_span)) = receiver.blocking_recv() {
        parent_span.in_scope(|| {
            response_tx
                .send(
124
125
126
127
128
129
130
131
132
                    validate(
                        request,
                        &tokenizer,
                        max_stop_sequences,
                        max_input_length,
                        max_total_tokens,
                        &mut rng,
                    )
                    .map_err(|err| {
133
                        metrics::increment_counter!("tgi_request_failure", "err" => "validation");
134
135
136
137
138
139
                        tracing::error!("{err}");
                        err
                    }),
                )
                .unwrap_or(())
        })
140
141
    }
}
Olivier Dehaene's avatar
Olivier Dehaene committed
142

143
fn validate(
144
    request: GenerateRequest,
145
    tokenizer: &Tokenizer,
146
    max_stop_sequences: usize,
147
    max_input_length: usize,
148
    max_total_tokens: usize,
149
    rng: &mut ThreadRng,
150
) -> Result<ValidGenerateRequest, ValidationError> {
151
152
153
154
155
    let GenerateParameters {
        temperature,
        repetition_penalty,
        top_k,
        top_p,
156
        typical_p,
157
158
159
160
        do_sample,
        max_new_tokens,
        stop: stop_sequences,
        seed,
161
        watermark,
162
163
164
165
166
        ..
    } = request.parameters;

    let temperature = temperature.unwrap_or(1.0);
    if temperature <= 0.0 {
167
168
        return Err(ValidationError::Temperature);
    }
169
170
171

    let repetition_penalty = repetition_penalty.unwrap_or(1.0);
    if repetition_penalty <= 0.0 {
172
173
        return Err(ValidationError::RepetitionPenalty);
    }
174

175
    // Different because the proto default value is not a valid value
176
    // for the user
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
    let top_p = top_p
        .map(|value| {
            if value <= 0.0 || value >= 1.0 {
                return Err(ValidationError::TopP);
            }
            Ok(value)
        })
        .unwrap_or(Ok(1.0))?;

    let typical_p = typical_p
        .map(|value| {
            if value <= 0.0 || value >= 1.0 {
                return Err(ValidationError::TypicalP);
            }
            Ok(value)
        })
        .unwrap_or(Ok(1.0))?;

    let top_k: u32 = top_k
        .map(|value| {
            if value <= 0 {
198
199
                return Err(ValidationError::TopK);
            }
200
201
202
            Ok(value as u32)
        })
        .unwrap_or(Ok(0))?;
203

204
205
    if max_new_tokens == 0 {
        return Err(ValidationError::MaxNewTokens);
206
    }
207

208
    if stop_sequences.len() > max_stop_sequences {
209
        return Err(ValidationError::StopSequence(
210
            max_stop_sequences,
211
            stop_sequences.len(),
212
        ));
213
214
    }

215
    // If seed is None, assign a random one
216
    let seed = match seed {
217
218
219
        None => rng.gen(),
        Some(seed) => seed,
    };
220

221
222
223
224
225
    // Check if inputs is empty
    if request.inputs.is_empty() {
        return Err(EmptyInput);
    }

226
227
    // Get the number of tokens in the input
    match tokenizer.encode(request.inputs.clone(), true) {
228
229
        Ok(encoding) => {
            let input_length = encoding.len();
230
            let total_tokens = input_length + max_new_tokens as usize;
231

232
            if input_length > max_input_length {
233
234
235
236
237
238
239
                Err(ValidationError::InputLength(max_input_length, input_length))
            } else if total_tokens > max_total_tokens {
                Err(ValidationError::MaxTotalTokens(
                    max_total_tokens,
                    input_length,
                    max_new_tokens,
                ))
240
            } else {
241
242
243
                // Return ValidGenerateRequest
                let parameters = NextTokenChooserParameters {
                    temperature,
244
                    repetition_penalty,
245
                    top_k,
246
                    top_p,
247
                    typical_p,
248
249
                    do_sample,
                    seed,
250
                    watermark,
251
252
253
254
255
256
                };
                let stopping_parameters = StoppingCriteriaParameters {
                    max_new_tokens,
                    stop_sequences,
                };

257
258
259
                metrics::histogram!("tgi_request_input_length", input_length as f64);
                metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64);

260
261
262
263
264
265
                Ok(ValidGenerateRequest {
                    inputs: request.inputs,
                    input_length: input_length as u32,
                    parameters,
                    stopping_parameters,
                })
266
            }
267
        }
268
        Err(err) => Err(ValidationError::Tokenizer(err.to_string())),
Olivier Dehaene's avatar
Olivier Dehaene committed
269
270
    }
}
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
271
272
273

type ValidationRequest = (
    GenerateRequest,
274
    oneshot::Sender<Result<ValidGenerateRequest, ValidationError>>,
275
    Span,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
276
277
);

278
279
280
281
282
283
284
285
/// A generation request that passed validation, as produced by `validate`.
#[derive(Debug)]
pub(crate) struct ValidGenerateRequest {
    /// Input text, taken unchanged from the original request
    pub inputs: String,
    /// Number of tokens in the tokenizer encoding of `inputs`
    pub input_length: u32,
    /// Validated sampling parameters, with defaults filled in for the ones not provided
    pub parameters: NextTokenChooserParameters,
    /// Validated `max_new_tokens` and stop sequences
    pub stopping_parameters: StoppingCriteriaParameters,
}

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
286
287
#[derive(Error, Debug)]
pub enum ValidationError {
288
    #[error("`temperature` must be strictly positive")]
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
289
    Temperature,
290
    #[error("`repetition_penalty` must be strictly positive")]
291
    RepetitionPenalty,
292
    #[error("`top_p` must be > 0.0 and < 1.0")]
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
293
    TopP,
294
    #[error("`top_k` must be strictly positive")]
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
295
    TopK,
296
297
    #[error("`typical_p` must be > 0.0 and < 1.0")]
    TypicalP,
298
    #[error("`max_new_tokens` must be strictly positive")]
299
    MaxNewTokens,
300
    #[error("`inputs` tokens + `max_new_tokens` must be <= {0}. Given: {1} `inputs` tokens and {2} `max_new_tokens`")]
301
    MaxTotalTokens(usize, usize, u32),
302
    #[error("`inputs` must have less than {0} tokens. Given: {1}")]
303
    InputLength(usize, usize),
304
    #[error("`inputs` cannot be empty")]
305
    EmptyInput,
306
    #[error("`stop` supports up to {0} stop sequences. Given: {1}")]
307
    StopSequence(usize, usize),
308
309
    #[error("tokenizer error {0}")]
    Tokenizer(String),
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
310
}