/// Payload validation logic
use crate::validation::ValidationError::EmptyInput;
use crate::{GenerateParameters, GenerateRequest};
use rand::rngs::ThreadRng;
use rand::Rng;
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
use thiserror::Error;
use tokenizers::tokenizer::Tokenizer;
use tokio::sync::{mpsc, oneshot};
use tracing::{instrument, Span};
/// Validation
///
/// Cheap-to-clone handle over a background validation task: every clone shares
/// the same unbounded channel, so all handles feed the same worker pool.
#[derive(Debug, Clone)]
pub struct Validation {
    /// Channel to communicate with the background validation task
    sender: mpsc::UnboundedSender<ValidationRequest>,
}

impl Validation {
20
21
22
23
24
25
26
    pub(crate) fn new(
        workers: usize,
        tokenizer: Tokenizer,
        max_stop_sequences: usize,
        max_input_length: usize,
        max_total_tokens: usize,
    ) -> Self {
27
        // Create channel
28
        let (validation_sender, validation_receiver) = mpsc::unbounded_channel();
Olivier Dehaene's avatar
Olivier Dehaene committed
29

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
30
31
32
33
        // Launch background validation task
        tokio::spawn(validation_task(
            workers,
            tokenizer,
34
            max_stop_sequences,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
35
            max_input_length,
36
            max_total_tokens,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
37
38
            validation_receiver,
        ));
Olivier Dehaene's avatar
Olivier Dehaene committed
39
40
41
42
43
44

        Self {
            sender: validation_sender,
        }
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
45
    /// Validate a payload and get the number of tokens in the input
46
    #[instrument(skip_all)]
Olivier Dehaene's avatar
Olivier Dehaene committed
47
48
49
    pub(crate) async fn validate(
        &self,
        request: GenerateRequest,
50
    ) -> Result<ValidGenerateRequest, ValidationError> {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
51
        // Create response channel
Olivier Dehaene's avatar
Olivier Dehaene committed
52
        let (sender, receiver) = oneshot::channel();
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
53
54
        // Send request to the background validation task
        // Unwrap is safe here
55
56
57
        self.sender
            .send((request, sender, Span::current()))
            .unwrap();
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
58
59
        // Await on response channel
        // Unwrap is safe here
Olivier Dehaene's avatar
Olivier Dehaene committed
60
61
62
63
        receiver.await.unwrap()
    }
}

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
64
65
66
67
68
/// Validation task
/// Load balance the validation requests between multiple validation workers
async fn validation_task(
    workers: usize,
    tokenizer: Tokenizer,
69
    max_stop_sequences: usize,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
70
    max_input_length: usize,
71
    max_total_tokens: usize,
72
    mut receiver: mpsc::UnboundedReceiver<ValidationRequest>,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
73
74
75
76
77
) {
    let mut workers_senders = Vec::with_capacity(workers);

    // Create workers
    for _ in 0..workers {
78
        let tokenizer_clone: Tokenizer = tokenizer.clone().into();
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
79
80
81
82
83
84
        // Create channel to communicate with worker
        let (worker_sender, worker_receiver) = mpsc::channel(workers);
        workers_senders.push(worker_sender);

        // Spawn worker
        tokio::task::spawn_blocking(move || {
85
86
87
88
89
90
91
            validation_worker(
                tokenizer_clone,
                max_stop_sequences,
                max_input_length,
                max_total_tokens,
                worker_receiver,
            )
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
        });
    }

    loop {
        // Load balance requests between workers
        for sender in workers_senders.iter() {
            if let Some(validation_request) = receiver.recv().await {
                sender.send(validation_request).await.unwrap();
            } else {
                return;
            }
        }
    }
}

/// Check the parameters inside the payload and get the number of tokens inside the input using
/// the tokenizer
fn validation_worker(
110
    tokenizer: Tokenizer,
111
    max_stop_sequences: usize,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
112
    max_input_length: usize,
113
    max_total_tokens: usize,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
114
115
    mut receiver: mpsc::Receiver<ValidationRequest>,
) {
116
117
118
    // Seed rng
    let mut rng = rand::thread_rng();

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
119
    // Loop over requests
120
121
122
123
    while let Some((request, response_tx, parent_span)) = receiver.blocking_recv() {
        parent_span.in_scope(|| {
            response_tx
                .send(
124
125
126
127
128
129
130
131
132
                    validate(
                        request,
                        &tokenizer,
                        max_stop_sequences,
                        max_input_length,
                        max_total_tokens,
                        &mut rng,
                    )
                    .map_err(|err| {
133
                        metrics::increment_counter!("tgi_request_failure", "err" => "validation");
134
135
136
137
138
139
                        tracing::error!("{err}");
                        err
                    }),
                )
                .unwrap_or(())
        })
140
141
    }
}
Olivier Dehaene's avatar
Olivier Dehaene committed
142

143
/// Validate a single request against the configured limits.
///
/// Checks sampling parameters in a fixed order (temperature, repetition
/// penalty, top_p, top_k, max_new_tokens, stop sequences), fills in a random
/// seed when none was provided, then tokenizes the input to enforce the
/// length budgets. Returns a `ValidGenerateRequest` ready for the model
/// server, or the first `ValidationError` encountered.
fn validate(
    request: GenerateRequest,
    tokenizer: &Tokenizer,
    max_stop_sequences: usize,
    max_input_length: usize,
    max_total_tokens: usize,
    rng: &mut ThreadRng,
) -> Result<ValidGenerateRequest, ValidationError> {
    // Destructure the user-supplied parameters; `..` tolerates fields this
    // validator does not inspect
    let GenerateParameters {
        temperature,
        repetition_penalty,
        top_k,
        top_p,
        do_sample,
        max_new_tokens,
        stop: stop_sequences,
        seed,
        watermark,
        ..
    } = request.parameters;

    // Temperature must be strictly positive (default 1.0 = no scaling)
    let temperature = temperature.unwrap_or(1.0);
    if temperature <= 0.0 {
        return Err(ValidationError::Temperature);
    }

    // Repetition penalty must be strictly positive (default 1.0 = disabled)
    let repetition_penalty = repetition_penalty.unwrap_or(1.0);
    if repetition_penalty <= 0.0 {
        return Err(ValidationError::RepetitionPenalty);
    }

    // top_p is a probability mass: must be in (0, 1]
    let top_p = top_p.unwrap_or(1.0);
    if top_p <= 0.0 || top_p > 1.0 {
        return Err(ValidationError::TopP);
    }

    // Different because the proto default value is 0 while it is not a valid value
    // for the user
    let top_k: u32 = match top_k {
        None => Ok(0),
        Some(top_k) => {
            if top_k <= 0 {
                return Err(ValidationError::TopK);
            }
            Ok(top_k as u32)
        }
    }?;

    // Asking for zero new tokens is a no-op and therefore rejected
    if max_new_tokens == 0 {
        return Err(ValidationError::MaxNewTokens);
    }

    // Cap the number of stop sequences per request
    if stop_sequences.len() > max_stop_sequences {
        return Err(ValidationError::StopSequence(
            max_stop_sequences,
            stop_sequences.len(),
        ));
    }

    // If seed is None, assign a random one
    let seed = match seed {
        None => rng.gen(),
        Some(seed) => seed,
    };

    // Check if inputs is empty
    if request.inputs.is_empty() {
        return Err(EmptyInput);
    }

    // Get the number of tokens in the input
    match tokenizer.encode(request.inputs.clone(), true) {
        Ok(encoding) => {
            let input_length = encoding.len();
            // Total budget = prompt tokens + requested generation length
            let total_tokens = input_length + max_new_tokens as usize;

            if input_length > max_input_length {
                Err(ValidationError::InputLength(max_input_length, input_length))
            } else if total_tokens > max_total_tokens {
                Err(ValidationError::MaxTotalTokens(
                    max_total_tokens,
                    input_length,
                    max_new_tokens,
                ))
            } else {
                // Return ValidGenerateRequest
                let parameters = NextTokenChooserParameters {
                    temperature,
                    repetition_penalty,
                    top_k,
                    top_p,
                    do_sample,
                    seed,
                    watermark,
                };
                let stopping_parameters = StoppingCriteriaParameters {
                    max_new_tokens,
                    stop_sequences,
                };

                // Record accepted-request statistics
                metrics::histogram!("tgi_request_input_length", input_length as f64);
                metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64);

                Ok(ValidGenerateRequest {
                    inputs: request.inputs,
                    input_length: input_length as u32,
                    parameters,
                    stopping_parameters,
                })
            }
        }
        Err(err) => Err(ValidationError::Tokenizer(err.to_string())),
    }
}
/// Unit of work sent from a `Validation` handle to the background task:
/// the raw request, a one-shot channel for the validation result, and the
/// caller's tracing span so worker-side logs stay correlated.
type ValidationRequest = (
    GenerateRequest,
    oneshot::Sender<Result<ValidGenerateRequest, ValidationError>>,
    Span,
);
/// A request that passed validation, augmented with its tokenized input length
#[derive(Debug)]
pub(crate) struct ValidGenerateRequest {
    // Raw input text, unchanged from the incoming request
    pub inputs: String,
    // Number of tokens in `inputs` as measured by the tokenizer
    pub input_length: u32,
    // Sampling parameters forwarded to the model server
    pub parameters: NextTokenChooserParameters,
    // Stopping criteria forwarded to the model server
    pub stopping_parameters: StoppingCriteriaParameters,
}
/// Errors produced while validating a payload; each `#[error]` string is the
/// user-facing message returned to the client.
#[derive(Error, Debug)]
pub enum ValidationError {
    #[error("`temperature` must be strictly positive")]
    Temperature,
    #[error("`repetition_penalty` must be strictly positive")]
    RepetitionPenalty,
    #[error("`top_p` must be > 0.0 and <= 1.0")]
    TopP,
    #[error("`top_k` must be strictly positive")]
    TopK,
    #[error("`max_new_tokens` must be strictly positive")]
    MaxNewTokens,
    #[error("`inputs` tokens + `max_new_tokens` must be <= {0}. Given: {1} `inputs` tokens and {2} `max_new_tokens`")]
    MaxTotalTokens(usize, usize, u32),
    #[error("`inputs` must have less than {0} tokens. Given: {1}")]
    InputLength(usize, usize),
    #[error("`inputs` cannot be empty")]
    EmptyInput,
    #[error("`stop` supports up to {0} stop sequences. Given: {1}")]
    StopSequence(usize, usize),
    // Wraps the tokenizer's own error text
    #[error("tokenizer error {0}")]
    Tokenizer(String),
}