/// Text Generation Inference Webserver
mod infer;
4
mod queue;
Olivier Dehaene's avatar
Olivier Dehaene committed
5
pub mod server;
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
6
mod validation;
Olivier Dehaene's avatar
Olivier Dehaene committed
7

8
use infer::Infer;
9
use queue::{Entry, Queue};
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
10
use serde::{Deserialize, Serialize};
Olivier Dehaene's avatar
Olivier Dehaene committed
11
use validation::Validation;
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
12
13
14
15
16

/// Generation parameters carried by a [`GenerateRequest`] next to its input string.
///
/// Every field has a serde default, so the caller may supply any subset of the
/// fields (including none of them) and the remaining ones are filled in.
#[derive(Clone, Debug, Deserialize)]
pub(crate) struct GenerateParameters {
    /// Temperature setting; `default_temperature()` supplies the value when omitted.
    #[serde(default = "default_temperature")]
    pub temperature: f32,
    /// Repetition penalty; `default_repetition_penalty()` supplies the value when omitted.
    #[serde(default = "default_repetition_penalty")]
    pub repetition_penalty: f32,
    /// Top-k setting; `default_top_k()` supplies the value when omitted.
    // NOTE(review): how a value of `0` is interpreted is decided by the consumer
    // of this struct, not here — confirm before relying on it.
    #[serde(default = "default_top_k")]
    pub top_k: i32,
    /// Top-p setting; `default_top_p()` supplies the value when omitted.
    #[serde(default = "default_top_p")]
    pub top_p: f32,
    /// Sampling switch; `default_do_sample()` supplies the value when omitted.
    #[serde(default = "default_do_sample")]
    pub do_sample: bool,
    /// Maximum number of new tokens; `default_max_new_tokens()` supplies the value when omitted.
    #[serde(default = "default_max_new_tokens")]
    pub max_new_tokens: u32,
    /// List of stop strings; an empty `Vec` when omitted.
    // NOTE(review): presumably sequences that end generation — confirm in the consumer.
    #[serde(default)]
    pub stop: Vec<String>,
    /// `details` flag; `false` when omitted.
    // NOTE(review): presumably asks for a `Details` payload in the response — confirm.
    #[serde(default)]
    pub details: bool,
    /// Optional seed; `None` when omitted.
    #[serde(default)]
    pub seed: Option<u64>,
}

/// Serde fallback for `GenerateParameters::temperature`; always yields `1.0`.
fn default_temperature() -> f32 {
    const TEMPERATURE: f32 = 1.0;
    TEMPERATURE
}
38
39
40
/// Serde fallback for `GenerateParameters::repetition_penalty`; always yields `1.0`.
fn default_repetition_penalty() -> f32 {
    const REPETITION_PENALTY: f32 = 1.0;
    REPETITION_PENALTY
}
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60

/// Serde fallback for `GenerateParameters::top_k`; always yields `0`.
fn default_top_k() -> i32 {
    const TOP_K: i32 = 0;
    TOP_K
}

/// Serde fallback for `GenerateParameters::top_p`; always yields `1.0`.
fn default_top_p() -> f32 {
    const TOP_P: f32 = 1.0;
    TOP_P
}

/// Serde fallback for `GenerateParameters::do_sample`; always yields `false`.
fn default_do_sample() -> bool {
    const DO_SAMPLE: bool = false;
    DO_SAMPLE
}

/// Serde fallback for `GenerateParameters::max_new_tokens`; always yields `20`.
fn default_max_new_tokens() -> u32 {
    const MAX_NEW_TOKENS: u32 = 20;
    MAX_NEW_TOKENS
}

fn default_parameters() -> GenerateParameters {
    GenerateParameters {
        temperature: default_temperature(),
61
        repetition_penalty: default_repetition_penalty(),
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
62
63
64
65
        top_k: default_top_k(),
        top_p: default_top_p(),
        do_sample: default_do_sample(),
        max_new_tokens: default_max_new_tokens(),
66
        stop: vec![],
OlivierDehaene's avatar
OlivierDehaene committed
67
        details: false,
68
        seed: None,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
69
70
71
72
73
74
75
76
77
78
    }
}

/// A generation request as deserialized by serde: an input string plus parameters.
#[derive(Clone, Debug, Deserialize)]
pub(crate) struct GenerateRequest {
    /// Input string; required, since this field declares no serde default.
    pub inputs: String,
    /// Generation parameters; `default_parameters()` supplies the value when the
    /// field is missing from the input.
    #[serde(default = "default_parameters")]
    pub parameters: GenerateParameters,
}

79
80
81
/// A tuple struct of `(u32, String, f32)` whose fields are private to this module.
///
/// Being a tuple struct, it is serialized by serde positionally rather than with
/// field names.
// NOTE(review): the three positions are presumably (token id, token text,
// log-probability) — confirm against the code that constructs `Token`.
#[derive(Debug, Serialize)]
pub struct Token(u32, String, f32);

OlivierDehaene's avatar
OlivierDehaene committed
82
83
84
85
#[derive(Serialize)]
pub(crate) struct Details {
    pub finish_reason: String,
    pub generated_tokens: u32,
86
    pub seed: Option<u64>,
87
88
89
90
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prefill: Option<Vec<Token>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tokens: Option<Vec<Token>>,
OlivierDehaene's avatar
OlivierDehaene committed
91
92
}

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
93
#[derive(Serialize)]
94
pub(crate) struct GenerateResponse {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
95
    pub generated_text: String,
OlivierDehaene's avatar
OlivierDehaene committed
96
97
    #[serde(skip_serializing_if = "Option::is_none")]
    pub details: Option<Details>,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
98
}
99

100
101
102
103
104
105
106
/// Serializable message carrying one `Token` plus optional text and details.
///
/// Unlike `GenerateResponse`, the optional fields here have no
/// `skip_serializing_if` attribute, so they are serialized even when `None`.
// NOTE(review): presumably one message is emitted per generated token, with the
// optional fields populated only on the final one — confirm in the server code.
#[derive(Serialize)]
pub(crate) struct StreamResponse {
    /// The token carried by this message.
    pub token: Token,
    /// Generated text, when present.
    pub generated_text: Option<String>,
    /// Generation metadata, when present.
    pub details: Option<Details>,
}

107
108
109
110
/// Serializable error payload with a single `error` message field.
#[derive(Serialize)]
pub(crate) struct ErrorResponse {
    /// Error message text.
    pub error: String,
}