validation.rs 8.37 KB
Newer Older
1
use crate::validation::ValidationError::EmptyInput;
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
2
/// Payload validation logic
3
use crate::{GenerateParameters, GenerateRequest};
4
5
use rand::rngs::ThreadRng;
use rand::Rng;
6
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
Olivier Dehaene's avatar
Olivier Dehaene committed
7
use thiserror::Error;
Olivier Dehaene's avatar
Olivier Dehaene committed
8
9
use tokenizers::tokenizer::Tokenizer;
use tokio::sync::{mpsc, oneshot};
10
use tracing::{instrument, Span};
Olivier Dehaene's avatar
Olivier Dehaene committed
11

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
12
/// Validation
Olivier Dehaene's avatar
Olivier Dehaene committed
13
#[derive(Debug, Clone)]
Olivier Dehaene's avatar
Olivier Dehaene committed
14
pub struct Validation {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
15
    /// Channel to communicate with the background validation task
Olivier Dehaene's avatar
Olivier Dehaene committed
16
17
18
19
    sender: mpsc::Sender<ValidationRequest>,
}

impl Validation {
20
21
22
23
24
25
26
    pub(crate) fn new(
        workers: usize,
        tokenizer: Tokenizer,
        max_stop_sequences: usize,
        max_input_length: usize,
        max_total_tokens: usize,
    ) -> Self {
27
        // Create channel
Olivier Dehaene's avatar
Olivier Dehaene committed
28
29
        let (validation_sender, validation_receiver) = mpsc::channel(128);

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
30
31
32
33
        // Launch background validation task
        tokio::spawn(validation_task(
            workers,
            tokenizer,
34
            max_stop_sequences,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
35
            max_input_length,
36
            max_total_tokens,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
37
38
            validation_receiver,
        ));
Olivier Dehaene's avatar
Olivier Dehaene committed
39
40
41
42
43
44

        Self {
            sender: validation_sender,
        }
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
45
    /// Validate a payload and get the number of tokens in the input
46
    #[instrument(skip_all)]
Olivier Dehaene's avatar
Olivier Dehaene committed
47
48
49
    pub(crate) async fn validate(
        &self,
        request: GenerateRequest,
50
    ) -> Result<ValidGenerateRequest, ValidationError> {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
51
        // Create response channel
Olivier Dehaene's avatar
Olivier Dehaene committed
52
        let (sender, receiver) = oneshot::channel();
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
53
54
        // Send request to the background validation task
        // Unwrap is safe here
55
56
57
58
        self.sender
            .send((request, sender, Span::current()))
            .await
            .unwrap();
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
59
60
        // Await on response channel
        // Unwrap is safe here
Olivier Dehaene's avatar
Olivier Dehaene committed
61
62
63
64
        receiver.await.unwrap()
    }
}

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
65
66
67
68
69
/// Validation task
/// Load balance the validation requests between multiple validation workers
async fn validation_task(
    workers: usize,
    tokenizer: Tokenizer,
70
    max_stop_sequences: usize,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
71
    max_input_length: usize,
72
    max_total_tokens: usize,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
73
74
75
76
77
78
    mut receiver: mpsc::Receiver<ValidationRequest>,
) {
    let mut workers_senders = Vec::with_capacity(workers);

    // Create workers
    for _ in 0..workers {
79
        let tokenizer_clone: Tokenizer = tokenizer.clone().into();
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
80
81
82
83
84
85
        // Create channel to communicate with worker
        let (worker_sender, worker_receiver) = mpsc::channel(workers);
        workers_senders.push(worker_sender);

        // Spawn worker
        tokio::task::spawn_blocking(move || {
86
87
88
89
90
91
92
            validation_worker(
                tokenizer_clone,
                max_stop_sequences,
                max_input_length,
                max_total_tokens,
                worker_receiver,
            )
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
        });
    }

    loop {
        // Load balance requests between workers
        for sender in workers_senders.iter() {
            if let Some(validation_request) = receiver.recv().await {
                sender.send(validation_request).await.unwrap();
            } else {
                return;
            }
        }
    }
}

/// Check the parameters inside the payload and get the number of tokens inside the input using
/// the tokenizer
fn validation_worker(
111
    tokenizer: Tokenizer,
112
    max_stop_sequences: usize,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
113
    max_input_length: usize,
114
    max_total_tokens: usize,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
115
116
    mut receiver: mpsc::Receiver<ValidationRequest>,
) {
117
118
119
    // Seed rng
    let mut rng = rand::thread_rng();

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
120
    // Loop over requests
121
122
123
124
    while let Some((request, response_tx, parent_span)) = receiver.blocking_recv() {
        parent_span.in_scope(|| {
            response_tx
                .send(
125
126
127
128
129
130
131
132
133
                    validate(
                        request,
                        &tokenizer,
                        max_stop_sequences,
                        max_input_length,
                        max_total_tokens,
                        &mut rng,
                    )
                    .map_err(|err| {
134
135
136
137
138
139
                        tracing::error!("{err}");
                        err
                    }),
                )
                .unwrap_or(())
        })
140
141
    }
}
Olivier Dehaene's avatar
Olivier Dehaene committed
142

143
fn validate(
144
    request: GenerateRequest,
145
    tokenizer: &Tokenizer,
146
    max_stop_sequences: usize,
147
    max_input_length: usize,
148
    max_total_tokens: usize,
149
    rng: &mut ThreadRng,
150
) -> Result<ValidGenerateRequest, ValidationError> {
151
152
153
154
155
156
157
158
159
160
161
162
163
164
    let GenerateParameters {
        temperature,
        repetition_penalty,
        top_k,
        top_p,
        do_sample,
        max_new_tokens,
        stop: stop_sequences,
        seed,
        ..
    } = request.parameters;

    let temperature = temperature.unwrap_or(1.0);
    if temperature <= 0.0 {
165
166
        return Err(ValidationError::Temperature);
    }
167
168
169

    let repetition_penalty = repetition_penalty.unwrap_or(1.0);
    if repetition_penalty <= 0.0 {
170
171
        return Err(ValidationError::RepetitionPenalty);
    }
172
173
174

    let top_p = top_p.unwrap_or(1.0);
    if top_p <= 0.0 || top_p > 1.0 {
175
176
        return Err(ValidationError::TopP);
    }
177
178
179
180
181
182
183
184
185
186
187
188
189

    // Different because the proto default value is 0 while it is not a valid value
    // for the user
    let top_k: u32 = match top_k {
        None => Ok(0),
        Some(top_k) => {
            if top_k <= 0 {
                return Err(ValidationError::TopK);
            }
            Ok(top_k as u32)
        }
    }?;

190
191
    if max_new_tokens == 0 {
        return Err(ValidationError::MaxNewTokens);
192
    }
193

194
    if stop_sequences.len() > max_stop_sequences {
195
        return Err(ValidationError::StopSequence(
196
            max_stop_sequences,
197
            stop_sequences.len(),
198
        ));
199
200
    }

201
    // If seed is None, assign a random one
202
    let seed = match seed {
203
204
205
        None => rng.gen(),
        Some(seed) => seed,
    };
206

207
208
209
210
211
    // Check if inputs is empty
    if request.inputs.is_empty() {
        return Err(EmptyInput);
    }

212
213
    // Get the number of tokens in the input
    match tokenizer.encode(request.inputs.clone(), true) {
214
215
        Ok(encoding) => {
            let input_length = encoding.len();
216
            let total_tokens = input_length + max_new_tokens as usize;
217
            if input_length > max_input_length {
218
219
220
221
222
223
224
                Err(ValidationError::InputLength(max_input_length, input_length))
            } else if total_tokens > max_total_tokens {
                Err(ValidationError::MaxTotalTokens(
                    max_total_tokens,
                    input_length,
                    max_new_tokens,
                ))
225
            } else {
226
227
228
                // Return ValidGenerateRequest
                let parameters = NextTokenChooserParameters {
                    temperature,
229
                    repetition_penalty,
230
                    top_k,
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
                    top_p,
                    do_sample,
                    seed,
                };
                let stopping_parameters = StoppingCriteriaParameters {
                    max_new_tokens,
                    stop_sequences,
                };

                Ok(ValidGenerateRequest {
                    inputs: request.inputs,
                    input_length: input_length as u32,
                    parameters,
                    stopping_parameters,
                })
246
            }
247
        }
248
        Err(err) => Err(ValidationError::Tokenizer(err.to_string())),
Olivier Dehaene's avatar
Olivier Dehaene committed
249
250
    }
}
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
251
252
253

type ValidationRequest = (
    GenerateRequest,
254
    oneshot::Sender<Result<ValidGenerateRequest, ValidationError>>,
255
    Span,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
256
257
);

258
259
260
261
262
263
264
265
/// A request whose parameters passed validation, ready to be handed to the client.
#[derive(Debug)]
pub(crate) struct ValidGenerateRequest {
    /// Original, non-empty input text
    pub inputs: String,
    /// Number of tokens the tokenizer produced for `inputs`
    pub input_length: u32,
    /// Sampling parameters with defaults filled in and a seed always set
    pub parameters: NextTokenChooserParameters,
    /// Generation stopping criteria (token budget and stop sequences)
    pub stopping_parameters: StoppingCriteriaParameters,
}

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
266
267
#[derive(Error, Debug)]
pub enum ValidationError {
268
    #[error("temperature must be strictly positive")]
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
269
    Temperature,
270
271
    #[error("repetition_penalty must be strictly positive")]
    RepetitionPenalty,
272
    #[error("top_p must be > 0.0 and <= 1.0")]
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
273
    TopP,
274
    #[error("top_k must be strictly positive")]
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
275
    TopK,
276
277
278
279
280
    #[error("max_new_tokens must be strictly positive")]
    MaxNewTokens,
    #[error("input tokens + max_new_tokens must be <= {0}. Given: {1} input tokens and {2} max_new_tokens")]
    MaxTotalTokens(usize, usize, u32),
    #[error("inputs must have less than {0} tokens. Given: {1}")]
281
    InputLength(usize, usize),
282
283
    #[error("inputs cannot be empty")]
    EmptyInput,
284
285
    #[error("stop supports up to {0} stop sequences. Given: {1}")]
    StopSequence(usize, usize),
286
287
    #[error("tokenizer error {0}")]
    Tokenizer(String),
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
288
}