//! Payload validation logic
use crate::{ErrorResponse, GenerateRequest};
use axum::http::StatusCode;
use axum::Json;
use rand::rngs::ThreadRng;
use rand::Rng;
use thiserror::Error;
use tokenizers::tokenizer::Tokenizer;
use tokio::sync::{mpsc, oneshot};

/// Largest number of entries allowed in a request's `stop` list.
const MAX_STOP_SEQUENCES: usize = 4;
/// Largest `max_new_tokens` value a request may ask for.
const MAX_MAX_NEW_TOKENS: u32 = 512;

/// Validation
Olivier Dehaene's avatar
Olivier Dehaene committed
15
#[derive(Debug, Clone)]
Olivier Dehaene's avatar
Olivier Dehaene committed
16
pub struct Validation {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
17
    /// Channel to communicate with the background validation task
Olivier Dehaene's avatar
Olivier Dehaene committed
18
19
20
21
    sender: mpsc::Sender<ValidationRequest>,
}

impl Validation {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
22
23
    pub(crate) fn new(workers: usize, tokenizer: Tokenizer, max_input_length: usize) -> Self {
        // Crate channel
Olivier Dehaene's avatar
Olivier Dehaene committed
24
25
        let (validation_sender, validation_receiver) = mpsc::channel(128);

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
26
27
28
29
30
31
32
        // Launch background validation task
        tokio::spawn(validation_task(
            workers,
            tokenizer,
            max_input_length,
            validation_receiver,
        ));
Olivier Dehaene's avatar
Olivier Dehaene committed
33
34
35
36
37
38

        Self {
            sender: validation_sender,
        }
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
39
    /// Validate a payload and get the number of tokens in the input
Olivier Dehaene's avatar
Olivier Dehaene committed
40
41
42
43
    pub(crate) async fn validate(
        &self,
        request: GenerateRequest,
    ) -> Result<(usize, GenerateRequest), ValidationError> {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
44
        // Create response channel
Olivier Dehaene's avatar
Olivier Dehaene committed
45
        let (sender, receiver) = oneshot::channel();
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
46
47
        // Send request to the background validation task
        // Unwrap is safe here
Olivier Dehaene's avatar
Olivier Dehaene committed
48
        self.sender.send((request, sender)).await.unwrap();
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
49
50
        // Await on response channel
        // Unwrap is safe here
Olivier Dehaene's avatar
Olivier Dehaene committed
51
52
53
54
        receiver.await.unwrap()
    }
}

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
55
56
57
58
59
60
61
62
63
64
65
66
/// Validation task
/// Load balance the validation requests between multiple validation workers
async fn validation_task(
    workers: usize,
    tokenizer: Tokenizer,
    max_input_length: usize,
    mut receiver: mpsc::Receiver<ValidationRequest>,
) {
    let mut workers_senders = Vec::with_capacity(workers);

    // Create workers
    for _ in 0..workers {
67
        let tokenizer_clone: Tokenizer = tokenizer.clone().into();
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
        // Create channel to communicate with worker
        let (worker_sender, worker_receiver) = mpsc::channel(workers);
        workers_senders.push(worker_sender);

        // Spawn worker
        tokio::task::spawn_blocking(move || {
            validation_worker(tokenizer_clone, max_input_length, worker_receiver)
        });
    }

    loop {
        // Load balance requests between workers
        for sender in workers_senders.iter() {
            if let Some(validation_request) = receiver.recv().await {
                sender.send(validation_request).await.unwrap();
            } else {
                return;
            }
        }
    }
}

/// Check the parameters inside the payload and get the number of tokens inside the input using
/// the tokenizer
fn validation_worker(
93
    tokenizer: Tokenizer,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
94
95
96
    max_input_length: usize,
    mut receiver: mpsc::Receiver<ValidationRequest>,
) {
97
98
99
    // Seed rng
    let mut rng = rand::thread_rng();

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
100
101
    // Loop over requests
    while let Some((request, response_tx)) = receiver.blocking_recv() {
102
        response_tx
103
            .send(validate(request, &tokenizer, max_input_length, &mut rng))
104
            .unwrap_or(())
105
106
    }
}
Olivier Dehaene's avatar
Olivier Dehaene committed
107

108
fn validate(
109
    mut request: GenerateRequest,
110
111
    tokenizer: &Tokenizer,
    max_input_length: usize,
112
    rng: &mut ThreadRng,
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
) -> Result<(usize, GenerateRequest), ValidationError> {
    if request.parameters.temperature <= 0.0 {
        return Err(ValidationError::Temperature);
    }
    if request.parameters.top_p <= 0.0 || request.parameters.top_p > 1.0 {
        return Err(ValidationError::TopP);
    }
    if request.parameters.top_k < 0 {
        return Err(ValidationError::TopK);
    }
    if request.parameters.max_new_tokens > MAX_MAX_NEW_TOKENS {
        return Err(ValidationError::MaxNewTokens(MAX_MAX_NEW_TOKENS));
    }
    if request.parameters.stop.len() > MAX_STOP_SEQUENCES {
        return Err(ValidationError::StopSequence(
128
129
130
            MAX_STOP_SEQUENCES,
            request.parameters.stop.len(),
        ));
131
132
    }

133
134
135
136
137
    // If seed is None, assign a random one
    if request.parameters.seed.is_none() {
        request.parameters.seed = Some(rng.gen());
    }

138
139
140
141
142
143
    // Get the number of tokens in the input
    match tokenizer.encode(request.inputs.clone(), true) {
        Ok(inputs) => {
            let input_length = inputs.len();

            if input_length > max_input_length {
144
                Err(ValidationError::InputLength(input_length, max_input_length))
145
146
            } else {
                Ok((input_length, request))
147
            }
148
        }
149
        Err(err) => Err(ValidationError::Tokenizer(err.to_string())),
Olivier Dehaene's avatar
Olivier Dehaene committed
150
151
    }
}
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
152
153
154
155
156
157
158
159

/// Outcome of validating one request: the input token count and the request, or the error
type ValidationResponse = Result<(usize, GenerateRequest), ValidationError>;

/// Message handed to the validation workers: the request to check, plus the oneshot sender
/// on which the worker reports the outcome
type ValidationRequest = (GenerateRequest, oneshot::Sender<ValidationResponse>);

#[derive(Error, Debug)]
pub enum ValidationError {
160
    #[error("temperature must be strictly positive")]
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
161
    Temperature,
162
    #[error("top_p must be > 0.0 and <= 1.0")]
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
163
    TopP,
164
    #[error("top_k must be strictly positive")]
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
165
    TopK,
166
167
    #[error("max_new_tokens must be <= {0}")]
    MaxNewTokens(u32),
168
    #[error("inputs must have less than {1} tokens. Given: {0}")]
169
    InputLength(usize, usize),
170
171
    #[error("stop supports up to {0} stop sequences. Given: {1}")]
    StopSequence(usize, usize),
172
173
    #[error("tokenizer error {0}")]
    Tokenizer(String),
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
174
}
175
176
177
178
179
180
181
182
183
184
185

impl From<ValidationError> for (StatusCode, Json<ErrorResponse>) {
    fn from(err: ValidationError) -> Self {
        (
            StatusCode::UNPROCESSABLE_ENTITY,
            Json(ErrorResponse {
                error: err.to_string(),
            }),
        )
    }
}