validation.rs 4.87 KB
Newer Older
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
1
//! Payload validation logic
2
use crate::GenerateRequest;
Olivier Dehaene's avatar
Olivier Dehaene committed
3
use thiserror::Error;
Olivier Dehaene's avatar
Olivier Dehaene committed
4
5
6
use tokenizers::tokenizer::Tokenizer;
use tokio::sync::{mpsc, oneshot};

7
8
9
/// Largest `max_new_tokens` value a request may ask for; `validate` rejects anything above it.
const MAX_MAX_NEW_TOKENS: u32 = 512;
/// Largest number of stop sequences a request may carry; `validate` rejects longer lists.
const MAX_STOP_SEQUENCES: usize = 4;

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
10
/// Validation
Olivier Dehaene's avatar
Olivier Dehaene committed
11
#[derive(Debug, Clone)]
Olivier Dehaene's avatar
Olivier Dehaene committed
12
pub struct Validation {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
13
    /// Channel to communicate with the background validation task
Olivier Dehaene's avatar
Olivier Dehaene committed
14
15
16
17
    sender: mpsc::Sender<ValidationRequest>,
}

impl Validation {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
18
19
    pub(crate) fn new(workers: usize, tokenizer: Tokenizer, max_input_length: usize) -> Self {
        // Crate channel
Olivier Dehaene's avatar
Olivier Dehaene committed
20
21
        let (validation_sender, validation_receiver) = mpsc::channel(128);

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
22
23
24
25
26
27
28
        // Launch background validation task
        tokio::spawn(validation_task(
            workers,
            tokenizer,
            max_input_length,
            validation_receiver,
        ));
Olivier Dehaene's avatar
Olivier Dehaene committed
29
30
31
32
33
34

        Self {
            sender: validation_sender,
        }
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
35
    /// Validate a payload and get the number of tokens in the input
Olivier Dehaene's avatar
Olivier Dehaene committed
36
37
38
39
    pub(crate) async fn validate(
        &self,
        request: GenerateRequest,
    ) -> Result<(usize, GenerateRequest), ValidationError> {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
40
        // Create response channel
Olivier Dehaene's avatar
Olivier Dehaene committed
41
        let (sender, receiver) = oneshot::channel();
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
42
43
        // Send request to the background validation task
        // Unwrap is safe here
Olivier Dehaene's avatar
Olivier Dehaene committed
44
        self.sender.send((request, sender)).await.unwrap();
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
45
46
        // Await on response channel
        // Unwrap is safe here
Olivier Dehaene's avatar
Olivier Dehaene committed
47
48
49
50
        receiver.await.unwrap()
    }
}

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
51
52
53
54
55
56
57
58
59
60
61
62
/// Validation task
/// Load balance the validation requests between multiple validation workers
async fn validation_task(
    workers: usize,
    tokenizer: Tokenizer,
    max_input_length: usize,
    mut receiver: mpsc::Receiver<ValidationRequest>,
) {
    let mut workers_senders = Vec::with_capacity(workers);

    // Create workers
    for _ in 0..workers {
63
        let tokenizer_clone: Tokenizer = tokenizer.clone().into();
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
        // Create channel to communicate with worker
        let (worker_sender, worker_receiver) = mpsc::channel(workers);
        workers_senders.push(worker_sender);

        // Spawn worker
        tokio::task::spawn_blocking(move || {
            validation_worker(tokenizer_clone, max_input_length, worker_receiver)
        });
    }

    loop {
        // Load balance requests between workers
        for sender in workers_senders.iter() {
            if let Some(validation_request) = receiver.recv().await {
                sender.send(validation_request).await.unwrap();
            } else {
                return;
            }
        }
    }
}

/// Check the parameters inside the payload and get the number of tokens inside the input using
/// the tokenizer
fn validation_worker(
89
    tokenizer: Tokenizer,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
90
91
92
93
94
    max_input_length: usize,
    mut receiver: mpsc::Receiver<ValidationRequest>,
) {
    // Loop over requests
    while let Some((request, response_tx)) = receiver.blocking_recv() {
95
96
97
        response_tx
            .send(validate(request, &tokenizer, max_input_length))
            .unwrap_or(())
98
99
    }
}
Olivier Dehaene's avatar
Olivier Dehaene committed
100

101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
fn validate(
    request: GenerateRequest,
    tokenizer: &Tokenizer,
    max_input_length: usize,
) -> Result<(usize, GenerateRequest), ValidationError> {
    if request.parameters.temperature <= 0.0 {
        return Err(ValidationError::Temperature);
    }
    if request.parameters.top_p <= 0.0 || request.parameters.top_p > 1.0 {
        return Err(ValidationError::TopP);
    }
    if request.parameters.top_k < 0 {
        return Err(ValidationError::TopK);
    }
    if request.parameters.max_new_tokens > MAX_MAX_NEW_TOKENS {
        return Err(ValidationError::MaxNewTokens(MAX_MAX_NEW_TOKENS));
    }
    if request.parameters.stop.len() > MAX_STOP_SEQUENCES {
        return Err(ValidationError::StopSequence(
120
121
122
            MAX_STOP_SEQUENCES,
            request.parameters.stop.len(),
        ));
123
124
125
126
127
128
129
130
    }

    // Get the number of tokens in the input
    match tokenizer.encode(request.inputs.clone(), true) {
        Ok(inputs) => {
            let input_length = inputs.len();

            if input_length > max_input_length {
131
                Err(ValidationError::InputLength(input_length, max_input_length))
132
133
            } else {
                Ok((input_length, request))
134
            }
135
        }
136
        Err(err) => Err(ValidationError::Tokenizer(err.to_string())),
Olivier Dehaene's avatar
Olivier Dehaene committed
137
138
    }
}
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
139
140
141
142
143
144
145
146

/// Message exchanged over the validation channels: the request to validate,
/// paired with the one-shot sender on which the validation result is returned.
type ValidationRequest = (
    GenerateRequest,
    oneshot::Sender<Result<(usize, GenerateRequest), ValidationError>>,
);

#[derive(Error, Debug)]
pub enum ValidationError {
147
    #[error("temperature must be strictly positive")]
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
148
    Temperature,
149
    #[error("top_p must be > 0.0 and <= 1.0")]
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
150
    TopP,
151
    #[error("top_k must be strictly positive")]
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
152
    TopK,
153
154
    #[error("max_new_tokens must be <= {0}")]
    MaxNewTokens(u32),
155
    #[error("inputs must have less than {1} tokens. Given: {0}")]
156
    InputLength(usize, usize),
157
158
    #[error("stop supports up to {0} stop sequences. Given: {1}")]
    StopSequence(usize, usize),
159
160
    #[error("tokenizer error {0}")]
    Tokenizer(String),
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
161
}