validation.rs 4.86 KB
Newer Older
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
1
2
//! Payload validation logic
use crate::GenerateRequest;
Olivier Dehaene's avatar
Olivier Dehaene committed
3
4
use axum::http::StatusCode;
use thiserror::Error;
Olivier Dehaene's avatar
Olivier Dehaene committed
5
use tokenizers::tokenizer::Tokenizer;
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
6
7
8
9
use tokenizers::{
    DecoderWrapper, ModelWrapper, NormalizerWrapper, PostProcessorWrapper, PreTokenizerWrapper,
    TokenizerImpl,
};
Olivier Dehaene's avatar
Olivier Dehaene committed
10
11
use tokio::sync::{mpsc, oneshot};

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
12
/// Validation
Olivier Dehaene's avatar
Olivier Dehaene committed
13
#[derive(Debug, Clone)]
Olivier Dehaene's avatar
Olivier Dehaene committed
14
pub struct Validation {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
15
    /// Channel to communicate with the background validation task
Olivier Dehaene's avatar
Olivier Dehaene committed
16
17
18
19
    sender: mpsc::Sender<ValidationRequest>,
}

impl Validation {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
20
21
    pub(crate) fn new(workers: usize, tokenizer: Tokenizer, max_input_length: usize) -> Self {
        // Crate channel
Olivier Dehaene's avatar
Olivier Dehaene committed
22
23
        let (validation_sender, validation_receiver) = mpsc::channel(128);

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
24
25
26
27
28
29
30
        // Launch background validation task
        tokio::spawn(validation_task(
            workers,
            tokenizer,
            max_input_length,
            validation_receiver,
        ));
Olivier Dehaene's avatar
Olivier Dehaene committed
31
32
33
34
35
36

        Self {
            sender: validation_sender,
        }
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
37
    /// Validate a payload and get the number of tokens in the input
Olivier Dehaene's avatar
Olivier Dehaene committed
38
39
40
41
    pub(crate) async fn validate(
        &self,
        request: GenerateRequest,
    ) -> Result<(usize, GenerateRequest), ValidationError> {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
42
        // Create response channel
Olivier Dehaene's avatar
Olivier Dehaene committed
43
        let (sender, receiver) = oneshot::channel();
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
44
45
        // Send request to the background validation task
        // Unwrap is safe here
Olivier Dehaene's avatar
Olivier Dehaene committed
46
        self.sender.send((request, sender)).await.unwrap();
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
47
48
        // Await on response channel
        // Unwrap is safe here
Olivier Dehaene's avatar
Olivier Dehaene committed
49
50
51
52
        receiver.await.unwrap()
    }
}

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
/// Validation task
///
/// Spawns the blocking validation workers, then load balances the incoming validation
/// requests between them in round-robin order. Returns once `receiver` yields `None`.
///
/// A `workers` value of 0 is treated as 1: without any worker the dispatch loop would have
/// nothing to iterate over and would spin forever without ever awaiting, so no request could
/// be answered.
///
/// # Panics
///
/// Panics if a request cannot be handed to a worker because its channel is closed.
async fn validation_task(
    workers: usize,
    tokenizer: Tokenizer,
    max_input_length: usize,
    mut receiver: mpsc::Receiver<ValidationRequest>,
) {
    // At least one worker is required; this also keeps the channel capacity below non-zero
    let workers = workers.max(1);
    let mut workers_senders = Vec::with_capacity(workers);

    // Create workers
    for _ in 0..workers {
        let tokenizer_clone = tokenizer.clone();
        // Create channel to communicate with worker
        let (worker_sender, worker_receiver) = mpsc::channel(workers);
        workers_senders.push(worker_sender);

        // Spawn worker on the blocking thread pool: it waits with `blocking_recv`
        tokio::task::spawn_blocking(move || {
            validation_worker(tokenizer_clone, max_input_length, worker_receiver)
        });
    }

    // Load balance requests between workers: each incoming request goes to the next worker
    // in turn. `workers_senders` is never empty here, so `cycle` never runs out.
    for sender in workers_senders.iter().cycle() {
        match receiver.recv().await {
            Some(validation_request) => sender
                .send(validation_request)
                .await
                .expect("validation worker stopped: its request channel is closed"),
            // Every request sender is gone: nothing left to dispatch
            None => return,
        }
    }
}

/// Check the parameters inside the payload and get the number of tokens inside the input using
/// the tokenizer
fn validation_worker(
    tokenizer: TokenizerImpl<
        ModelWrapper,
        NormalizerWrapper,
        PreTokenizerWrapper,
        PostProcessorWrapper,
        DecoderWrapper,
    >,
    max_input_length: usize,
    mut receiver: mpsc::Receiver<ValidationRequest>,
) {
    // Loop over requests
    while let Some((request, response_tx)) = receiver.blocking_recv() {
Olivier Dehaene's avatar
Olivier Dehaene committed
103
        if request.parameters.temperature < 0.0 {
Olivier Dehaene's avatar
Olivier Dehaene committed
104
105
106
            response_tx
                .send(Err(ValidationError::Temperature))
                .unwrap_or(());
Olivier Dehaene's avatar
Olivier Dehaene committed
107
108
109
            continue;
        }
        if request.parameters.top_p <= 0.0 || request.parameters.top_p > 1.0 {
Olivier Dehaene's avatar
Olivier Dehaene committed
110
111
112
113
114
            response_tx.send(Err(ValidationError::TopP)).unwrap_or(());
            continue;
        }
        if request.parameters.top_k < 0 {
            response_tx.send(Err(ValidationError::TopK)).unwrap_or(());
Olivier Dehaene's avatar
Olivier Dehaene committed
115
116
117
            continue;
        }
        if request.parameters.max_new_tokens > 512 {
Olivier Dehaene's avatar
Olivier Dehaene committed
118
119
120
            response_tx
                .send(Err(ValidationError::MaxNewTokens))
                .unwrap_or(());
Olivier Dehaene's avatar
Olivier Dehaene committed
121
122
123
            continue;
        }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
124
        // Get the number of tokens in the input
Olivier Dehaene's avatar
Olivier Dehaene committed
125
126
127
        let inputs = tokenizer.encode(request.inputs.clone(), false).unwrap();
        let input_length = inputs.len();

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
128
        if input_length > max_input_length {
Olivier Dehaene's avatar
Olivier Dehaene committed
129
            response_tx
130
                .send(Err(ValidationError::InputLength(input_length, max_input_length)))
Olivier Dehaene's avatar
Olivier Dehaene committed
131
                .unwrap_or(());
Olivier Dehaene's avatar
Olivier Dehaene committed
132
133
134
135
136
137
            continue;
        }

        response_tx.send(Ok((input_length, request))).unwrap_or(());
    }
}
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
138
139
140
141
142
143
144
145
146
147

/// Message carried by the validation channels: the request to validate, paired with the
/// one-shot sender on which the validation result is sent back.
type ValidationRequest = (
    GenerateRequest,
    oneshot::Sender<Result<(usize, GenerateRequest), ValidationError>>,
);

#[derive(Error, Debug)]
pub enum ValidationError {
    #[error("Temperature must be strictly positive")]
    Temperature,
148
    #[error("Top p must be >= 0.0 or < 1.0")]
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
149
150
151
    TopP,
    #[error("Top k must be strictly positive")]
    TopK,
152
    #[error("Max New Tokens must be <= 512")]
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
153
    MaxNewTokens,
154
155
    #[error("Inputs must have less than {1} tokens. Given: {0}")]
    InputLength(usize, usize),
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
156
157
158
159
160
161
162
}

impl From<ValidationError> for (StatusCode, String) {
    fn from(err: ValidationError) -> Self {
        (StatusCode::BAD_REQUEST, err.to_string())
    }
}