lib.rs 2.45 KB
Newer Older
1
mod infer;
2
3
/// Text Generation Inference Webserver
mod queue;
Olivier Dehaene's avatar
Olivier Dehaene committed
4
pub mod server;
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
5
mod validation;
Olivier Dehaene's avatar
Olivier Dehaene committed
6

7
use infer::Infer;
8
use queue::{Entry, Queue};
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
9
use serde::{Deserialize, Serialize};
Olivier Dehaene's avatar
Olivier Dehaene committed
10
use validation::Validation;
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
11
12
13
14
15

/// Optional generation settings deserialized from the `parameters` field of a
/// [`GenerateRequest`].
///
/// Every field carries a serde default, so a client may send any subset of
/// them (including an empty object); omitted fields take the values listed
/// on each field below.
#[derive(Clone, Debug, Deserialize)]
pub(crate) struct GenerateParameters {
    /// Defaults to `1.0` (`default_temperature`) when omitted.
    // NOTE(review): presumably the sampling temperature — confirm against the consumer.
    #[serde(default = "default_temperature")]
    pub temperature: f32,
    /// Defaults to `1.0` (`default_repetition_penalty`) when omitted.
    #[serde(default = "default_repetition_penalty")]
    pub repetition_penalty: f32,
    /// Defaults to `0` (`default_top_k`) when omitted.
    // NOTE(review): `0` looks like a "disabled" marker for top-k filtering — confirm.
    #[serde(default = "default_top_k")]
    pub top_k: i32,
    /// Defaults to `1.0` (`default_top_p`) when omitted.
    #[serde(default = "default_top_p")]
    pub top_p: f32,
    /// Defaults to `false` (`default_do_sample`) when omitted.
    #[serde(default = "default_do_sample")]
    pub do_sample: bool,
    /// Defaults to `20` (`default_max_new_tokens`) when omitted.
    #[serde(default = "default_max_new_tokens")]
    pub max_new_tokens: u32,
    /// Defaults to an empty list when omitted.
    // NOTE(review): presumably stop sequences that end generation — confirm.
    #[serde(default)]
    pub stop: Vec<String>,
    /// Defaults to `false` when omitted.
    // NOTE(review): presumably toggles the `details` section of the response — confirm.
    #[serde(default)]
    pub details: bool,
    /// Defaults to `None` when omitted.
    #[serde(default)]
    pub seed: Option<u64>,
}

/// Fallback for the `temperature` field when a request does not supply it.
///
/// Always returns `1.0`.
fn default_temperature() -> f32 {
    const TEMPERATURE: f32 = 1.0;
    TEMPERATURE
}
37
38
39
/// Fallback for the `repetition_penalty` field when a request does not supply it.
///
/// Always returns `1.0`.
fn default_repetition_penalty() -> f32 {
    const REPETITION_PENALTY: f32 = 1.0;
    REPETITION_PENALTY
}
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59

/// Fallback for the `top_k` field when a request does not supply it.
///
/// Always returns `0`.
fn default_top_k() -> i32 {
    const TOP_K: i32 = 0;
    TOP_K
}

/// Fallback for the `top_p` field when a request does not supply it.
///
/// Always returns `1.0`.
fn default_top_p() -> f32 {
    const TOP_P: f32 = 1.0;
    TOP_P
}

/// Fallback for the `do_sample` field when a request does not supply it.
///
/// Always returns `false`.
fn default_do_sample() -> bool {
    const DO_SAMPLE: bool = false;
    DO_SAMPLE
}

/// Fallback for the `max_new_tokens` field when a request does not supply it.
///
/// Always returns `20`.
fn default_max_new_tokens() -> u32 {
    const MAX_NEW_TOKENS: u32 = 20;
    MAX_NEW_TOKENS
}

fn default_parameters() -> GenerateParameters {
    GenerateParameters {
        temperature: default_temperature(),
60
        repetition_penalty: default_repetition_penalty(),
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
61
62
63
64
        top_k: default_top_k(),
        top_p: default_top_p(),
        do_sample: default_do_sample(),
        max_new_tokens: default_max_new_tokens(),
65
        stop: vec![],
OlivierDehaene's avatar
OlivierDehaene committed
66
        details: false,
67
        seed: None,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
68
69
70
71
72
73
74
75
76
77
    }
}

/// Request payload for text generation, deserialized from the client.
///
/// `inputs` is mandatory; `parameters` may be left out entirely, in which
/// case `default_parameters()` supplies the value.
#[derive(Clone, Debug, Deserialize)]
pub(crate) struct GenerateRequest {
    /// Required input text.
    // NOTE(review): presumably the prompt to continue — confirm against the handler.
    pub inputs: String,
    /// Generation settings; falls back to `default_parameters()` when absent.
    #[serde(default = "default_parameters")]
    pub parameters: GenerateParameters,
}

78
79
80
/// A single token record made of three positional values: `u32`, `String`, `f32`.
///
/// Being a tuple struct, serde's derived `Serialize` emits it as a
/// three-element sequence in field order rather than as a keyed object.
// NOTE(review): the fields look like (token id, token text, log-probability) —
// confirm against the code that constructs `Token`.
#[derive(Debug, Serialize)]
pub struct Token(u32, String, f32);

OlivierDehaene's avatar
OlivierDehaene committed
81
82
83
84
#[derive(Serialize)]
pub(crate) struct Details {
    pub finish_reason: String,
    pub generated_tokens: u32,
85
    pub seed: Option<u64>,
86
87
88
89
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prefill: Option<Vec<Token>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tokens: Option<Vec<Token>>,
OlivierDehaene's avatar
OlivierDehaene committed
90
91
}

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
92
#[derive(Serialize)]
93
pub(crate) struct GenerateResponse {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
94
    pub generated_text: String,
OlivierDehaene's avatar
OlivierDehaene committed
95
96
    #[serde(skip_serializing_if = "Option::is_none")]
    pub details: Option<Details>,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
97
}
98

99
100
101
102
103
104
105
/// One serialized message of a streamed generation, carrying a single [`Token`].
///
/// Neither optional field has a `skip_serializing_if` attribute, so both are
/// always present in the output and serialize as a null value when `None`.
// NOTE(review): presumably `generated_text` and `details` are only populated
// on the final message of a stream — confirm against the streaming handler.
#[derive(Serialize)]
pub(crate) struct StreamResponse {
    /// The token carried by this message.
    pub token: Token,
    /// Generated text, when available; always serialized.
    pub generated_text: Option<String>,
    /// Generation details, when available; always serialized.
    pub details: Option<Details>,
}

106
107
108
109
/// Serialized error payload with a single `error` string field.
#[derive(Serialize)]
pub(crate) struct ErrorResponse {
    /// Error message text.
    pub error: String,
}