validation.rs 1.94 KB
Newer Older
Olivier Dehaene's avatar
Olivier Dehaene committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
use crate::server::GenerateRequest;
use tokenizers::tokenizer::Tokenizer;
use tokio::sync::{mpsc, oneshot};

/// Error reported when a generation request is rejected by validation.
///
/// The type carries no payload: it only signals that the request was refused,
/// not which check refused it.
#[derive(Debug)]
pub struct ValidationError {}

impl std::fmt::Display for ValidationError {
    /// Writes a fixed, human-readable description of the failure.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("request failed validation")
    }
}

// Implementing `Error` lets callers propagate this with `?` into
// `Box<dyn std::error::Error>` and similar error containers.
impl std::error::Error for ValidationError {}

/// Message exchanged with the validation task.
///
/// The first element is the request to check. The second is the one-shot
/// sender on which the task reports the outcome: on success, a `usize`
/// (the tokenized input length computed by `validation_task`) together with
/// the request itself; on failure, a [`ValidationError`].
type ValidationRequest = (
    GenerateRequest,
    oneshot::Sender<Result<(usize, GenerateRequest), ValidationError>>,
);

/// Handle used to submit requests to the background validation task.
///
/// The handle only wraps the sending half of an mpsc channel, so cloning it
/// is cheap and every clone feeds the same task.
#[derive(Debug, Clone)]
pub(crate) struct Validation {
    /// Sending half of the channel consumed by `validation_task`.
    sender: mpsc::Sender<ValidationRequest>,
}

impl Validation {
    /// Builds a validation handle and starts its background task.
    ///
    /// A bounded mpsc channel with a capacity of 128 messages is created; the
    /// receiving half, together with `tokenizer`, is moved into a
    /// `validation_task` spawned on the Tokio runtime, while the sending half
    /// is kept in the returned handle.
    pub(crate) fn new(tokenizer: Tokenizer) -> Self {
        let (sender, task_receiver) = mpsc::channel(128);

        tokio::spawn(validation_task(tokenizer, task_receiver));

        Self { sender }
    }

    /// Submits `request` to the validation task and waits for its verdict.
    ///
    /// Returns whatever the task replies: `Ok` with a `usize` and the request
    /// handed back, or `Err(ValidationError)` when the request is rejected.
    ///
    /// # Panics
    ///
    /// Panics if the message cannot be delivered to the validation task, or
    /// if the task drops the reply channel without answering.
    pub(crate) async fn validate(
        &self,
        request: GenerateRequest,
    ) -> Result<(usize, GenerateRequest), ValidationError> {
        // One-shot channel dedicated to this request's reply.
        let (reply_tx, reply_rx) = oneshot::channel();
        let message = (request, reply_tx);

        self.sender.send(message).await.unwrap();

        let verdict = reply_rx.await.unwrap();
        verdict
    }
}

/// Background loop that validates incoming generation requests.
///
/// For every `(request, response_tx)` pair received, the request is rejected
/// with `ValidationError` when any of the following holds:
///
/// * `parameters.temperature` is below `0.0`;
/// * `parameters.top_p` is outside the interval `(0.0, 1.0]`;
/// * `parameters.max_new_tokens` exceeds 512;
/// * the tokenizer fails to encode `inputs`;
/// * the encoded input is longer than 512 tokens.
///
/// Otherwise the encoded input length and the request itself are sent back.
/// The loop ends once `receiver.recv()` yields `None`.
async fn validation_task(tokenizer: Tokenizer, mut receiver: mpsc::Receiver<ValidationRequest>) {
    while let Some((request, response_tx)) = receiver.recv().await {
        // Cheap parameter checks run first so that obviously bad requests
        // never reach the tokenizer.
        let invalid_parameters = request.parameters.temperature < 0.0
            || request.parameters.top_p <= 0.0
            || request.parameters.top_p > 1.0
            || request.parameters.max_new_tokens > 512;

        let outcome = if invalid_parameters {
            Err(ValidationError {})
        } else {
            // A tokenizer failure is reported to the caller as a validation
            // error rather than unwrapped: a panic here would kill this task
            // and make every later `validate` call panic as well.
            match tokenizer.encode(request.inputs.clone(), false) {
                Ok(encoding) if encoding.len() <= 512 => Ok((encoding.len(), request)),
                _ => Err(ValidationError {}),
            }
        };

        // The caller may have stopped waiting for the answer; a failed send
        // is deliberately ignored.
        let _ = response_tx.send(outcome);
    }
    println!("drop here");
}