validation.rs 5.73 KB
Newer Older
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
1
/// Payload validation logic
2
use crate::{ErrorResponse, GenerateRequest};
Olivier Dehaene's avatar
Olivier Dehaene committed
3
use axum::http::StatusCode;
4
use axum::Json;
Olivier Dehaene's avatar
Olivier Dehaene committed
5
use thiserror::Error;
Olivier Dehaene's avatar
Olivier Dehaene committed
6
use tokenizers::tokenizer::Tokenizer;
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
7
8
9
10
use tokenizers::{
    DecoderWrapper, ModelWrapper, NormalizerWrapper, PostProcessorWrapper, PreTokenizerWrapper,
    TokenizerImpl,
};
Olivier Dehaene's avatar
Olivier Dehaene committed
11
12
use tokio::sync::{mpsc, oneshot};

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
13
/// Validation
///
/// Cheap, clonable handle to the background validation task: cloning only
/// clones the underlying mpsc sender.
#[derive(Debug, Clone)]
pub struct Validation {
    /// Channel to communicate with the background validation task
    sender: mpsc::Sender<ValidationRequest>,
}

impl Validation {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
21
22
    pub(crate) fn new(workers: usize, tokenizer: Tokenizer, max_input_length: usize) -> Self {
        // Crate channel
Olivier Dehaene's avatar
Olivier Dehaene committed
23
24
        let (validation_sender, validation_receiver) = mpsc::channel(128);

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
25
26
27
28
29
30
31
        // Launch background validation task
        tokio::spawn(validation_task(
            workers,
            tokenizer,
            max_input_length,
            validation_receiver,
        ));
Olivier Dehaene's avatar
Olivier Dehaene committed
32
33
34
35
36
37

        Self {
            sender: validation_sender,
        }
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
38
    /// Validate a payload and get the number of tokens in the input
Olivier Dehaene's avatar
Olivier Dehaene committed
39
40
41
42
    pub(crate) async fn validate(
        &self,
        request: GenerateRequest,
    ) -> Result<(usize, GenerateRequest), ValidationError> {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
43
        // Create response channel
Olivier Dehaene's avatar
Olivier Dehaene committed
44
        let (sender, receiver) = oneshot::channel();
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
45
46
        // Send request to the background validation task
        // Unwrap is safe here
Olivier Dehaene's avatar
Olivier Dehaene committed
47
        self.sender.send((request, sender)).await.unwrap();
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
48
49
        // Await on response channel
        // Unwrap is safe here
Olivier Dehaene's avatar
Olivier Dehaene committed
50
51
52
53
        receiver.await.unwrap()
    }
}

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
/// Validation task
/// Load balance the validation requests between multiple validation workers
///
/// # Panics
/// Panics if `workers` is zero: the round-robin loop below would otherwise
/// busy-spin forever without ever yielding to the executor.
async fn validation_task(
    workers: usize,
    tokenizer: Tokenizer,
    max_input_length: usize,
    mut receiver: mpsc::Receiver<ValidationRequest>,
) {
    assert!(workers > 0, "validation_task requires at least one worker");

    let mut workers_senders = Vec::with_capacity(workers);

    // Create workers
    for _ in 0..workers {
        let tokenizer_clone = tokenizer.clone();
        // Create channel to communicate with worker
        let (worker_sender, worker_receiver) = mpsc::channel(workers);
        workers_senders.push(worker_sender);

        // Spawn worker on the blocking thread pool: tokenization is CPU-bound
        // and must not block the async executor
        tokio::task::spawn_blocking(move || {
            validation_worker(tokenizer_clone, max_input_length, worker_receiver)
        });
    }

    loop {
        // Load balance requests between workers (round-robin)
        for sender in workers_senders.iter() {
            if let Some(validation_request) = receiver.recv().await {
                // Workers run until their sender is dropped, so a send can
                // only fail if the worker panicked
                sender
                    .send(validation_request)
                    .await
                    .expect("validation worker is dead");
            } else {
                // All `Validation` handles were dropped: shut down
                return;
            }
        }
    }
}

/// Check the parameters inside the payload and get the number of tokens inside the input using
/// the tokenizer
fn validation_worker(
    tokenizer: TokenizerImpl<
        ModelWrapper,
        NormalizerWrapper,
        PreTokenizerWrapper,
        PostProcessorWrapper,
        DecoderWrapper,
    >,
    max_input_length: usize,
    mut receiver: mpsc::Receiver<ValidationRequest>,
) {
    // Loop over requests
    while let Some((request, response_tx)) = receiver.blocking_recv() {
104
        if request.parameters.temperature <= 0.0 {
Olivier Dehaene's avatar
Olivier Dehaene committed
105
106
107
            response_tx
                .send(Err(ValidationError::Temperature))
                .unwrap_or(());
Olivier Dehaene's avatar
Olivier Dehaene committed
108
109
110
            continue;
        }
        if request.parameters.top_p <= 0.0 || request.parameters.top_p > 1.0 {
Olivier Dehaene's avatar
Olivier Dehaene committed
111
112
113
114
115
            response_tx.send(Err(ValidationError::TopP)).unwrap_or(());
            continue;
        }
        if request.parameters.top_k < 0 {
            response_tx.send(Err(ValidationError::TopK)).unwrap_or(());
Olivier Dehaene's avatar
Olivier Dehaene committed
116
117
118
            continue;
        }
        if request.parameters.max_new_tokens > 512 {
Olivier Dehaene's avatar
Olivier Dehaene committed
119
120
121
            response_tx
                .send(Err(ValidationError::MaxNewTokens))
                .unwrap_or(());
Olivier Dehaene's avatar
Olivier Dehaene committed
122
123
            continue;
        }
124
125
126
127
128
129
130
131
        if request.parameters.stop.len() > 4 {
            response_tx
                .send(Err(ValidationError::StopSequence(
                    request.parameters.stop.len(),
                )))
                .unwrap_or(());
            continue;
        }
Olivier Dehaene's avatar
Olivier Dehaene committed
132

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
133
        // Get the number of tokens in the input
134
        match tokenizer.encode(request.inputs.clone(), true) {
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
            Ok(inputs) => {
                let input_length = inputs.len();

                if input_length > max_input_length {
                    response_tx
                        .send(Err(ValidationError::InputLength(
                            input_length,
                            max_input_length,
                        )))
                        .unwrap_or(());
                    continue;
                }

                response_tx.send(Ok((input_length, request))).unwrap_or(());
            }
            Err(err) => response_tx
                .send(Err(ValidationError::Tokenizer(err.to_string())))
                .unwrap_or(()),
        };
Olivier Dehaene's avatar
Olivier Dehaene committed
154
155
    }
}
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
156
157
158
159
160
161
162
163

/// Message flowing through the validation channels: the payload to validate
/// plus a oneshot sender used to return the validation outcome to the caller.
type ValidationRequest = (
    GenerateRequest,
    oneshot::Sender<Result<(usize, GenerateRequest), ValidationError>>,
);

#[derive(Error, Debug)]
pub enum ValidationError {
164
    #[error("temperature must be strictly positive")]
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
165
    Temperature,
166
    #[error("top_p must be > 0.0 and <= 1.0")]
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
167
    TopP,
168
    #[error("top_k must be strictly positive")]
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
169
    TopK,
170
    #[error("max_new_tokens must be <= 512")]
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
171
    MaxNewTokens,
172
    #[error("inputs must have less than {1} tokens. Given: {0}")]
173
    InputLength(usize, usize),
174
175
    #[error("stop supports up to 4 stop sequences. Given: {0}")]
    StopSequence(usize),
176
177
    #[error("tokenizer error {0}")]
    Tokenizer(String),
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
178
179
}

180
impl From<ValidationError> for (StatusCode, Json<ErrorResponse>) {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
181
    fn from(err: ValidationError) -> Self {
182
        (
183
            StatusCode::UNPROCESSABLE_ENTITY,
184
185
186
187
            Json(ErrorResponse {
                error: err.to_string(),
            }),
        )
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
188
189
    }
}