lib.rs 9.03 KB
Newer Older
1
mod health;
2
/// Text Generation Inference Webserver
3
mod infer;
4
mod queue;
Olivier Dehaene's avatar
Olivier Dehaene committed
5
pub mod server;
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
6
mod validation;
Olivier Dehaene's avatar
Olivier Dehaene committed
7

8
use infer::Infer;
9
use queue::{Entry, Queue};
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
10
use serde::{Deserialize, Serialize};
11
use utoipa::ToSchema;
Olivier Dehaene's avatar
Olivier Dehaene committed
12
use validation::Validation;
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
13

14
15
/// Hub type
#[derive(Clone, Debug, Deserialize)]
16
pub struct HubModelInfo {
17
18
19
20
21
22
23
24
    #[serde(rename(deserialize = "id"))]
    pub model_id: String,
    pub sha: Option<String>,
    pub pipeline_tag: Option<String>,
}

#[derive(Clone, Debug, Serialize, ToSchema)]
pub struct Info {
25
    /// Model info
26
27
28
29
    #[schema(example = "bigscience/blomm-560m")]
    pub model_id: String,
    #[schema(nullable = true, example = "e985a63cdc139290c5f700ff1929f0b5942cced2")]
    pub model_sha: Option<String>,
30
31
32
33
    #[schema(example = "torch.float16")]
    pub model_dtype: String,
    #[schema(example = "cuda")]
    pub model_device_type: String,
34
35
    #[schema(nullable = true, example = "text-generation")]
    pub model_pipeline_tag: Option<String>,
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
    /// Router Parameters
    #[schema(example = "128")]
    pub max_concurrent_requests: usize,
    #[schema(example = "2")]
    pub max_best_of: usize,
    #[schema(example = "4")]
    pub max_stop_sequences: usize,
    #[schema(example = "1024")]
    pub max_input_length: usize,
    #[schema(example = "2048")]
    pub max_total_tokens: usize,
    #[schema(example = "1.2")]
    pub waiting_served_ratio: f32,
    #[schema(example = "32000")]
    pub max_batch_total_tokens: u32,
    #[schema(example = "20")]
    pub max_waiting_tokens: usize,
    #[schema(example = "2")]
    pub validation_workers: usize,
    /// Router Info
56
57
58
59
    #[schema(example = "0.5.0")]
    pub version: &'static str,
    #[schema(nullable = true, example = "null")]
    pub sha: Option<&'static str>,
60
61
    #[schema(nullable = true, example = "null")]
    pub docker_label: Option<&'static str>,
62
63
}

64
#[derive(Clone, Debug, Deserialize, ToSchema)]
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
65
pub(crate) struct GenerateParameters {
66
67
68
    #[serde(default)]
    #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 1)]
    pub best_of: Option<usize>,
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
    #[serde(default)]
    #[schema(
        exclusive_minimum = 0.0,
        nullable = true,
        default = "null",
        example = 0.5
    )]
    pub temperature: Option<f32>,
    #[serde(default)]
    #[schema(
        exclusive_minimum = 0.0,
        nullable = true,
        default = "null",
        example = 1.03
    )]
    pub repetition_penalty: Option<f32>,
    #[serde(default)]
    #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 10)]
    pub top_k: Option<i32>,
    #[serde(default)]
    #[schema(
        exclusive_minimum = 0.0,
        maximum = 1.0,
        nullable = true,
        default = "null",
        example = 0.95
    )]
    pub top_p: Option<f32>,
97
    #[serde(default)]
98
99
100
101
102
103
104
105
106
    #[schema(
        exclusive_minimum = 0.0,
        maximum = 1.0,
        nullable = true,
        default = "null",
        example = 0.95
    )]
    pub typical_p: Option<f32>,
    #[serde(default)]
107
    #[schema(default = "false", example = true)]
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
108
109
    pub do_sample: bool,
    #[serde(default = "default_max_new_tokens")]
110
111
    #[schema(nullable = true, default = "null", example = "20")]
    pub max_new_tokens: Option<u32>,
OlivierDehaene's avatar
OlivierDehaene committed
112
    #[serde(default)]
113
    #[schema(nullable = true, default = "null", example = false)]
114
115
    pub return_full_text: Option<bool>,
    #[serde(default)]
116
    #[schema(inline, max_items = 4, example = json ! (["photographer"]))]
117
    pub stop: Vec<String>,
OlivierDehaene's avatar
OlivierDehaene committed
118
    #[serde(default)]
119
    #[schema(nullable = true, default = "null", example = "null")]
120
121
    pub truncate: Option<usize>,
    #[serde(default)]
122
123
124
    #[schema(default = "false", example = true)]
    pub watermark: bool,
    #[serde(default)]
125
    #[schema(default = "true")]
OlivierDehaene's avatar
OlivierDehaene committed
126
    pub details: bool,
127
    #[serde(default)]
128
129
130
    #[schema(default = "true")]
    pub decoder_input_details: bool,
    #[serde(default)]
131
132
133
134
135
136
    #[schema(
        exclusive_minimum = 0,
        nullable = true,
        default = "null",
        example = "null"
    )]
137
    pub seed: Option<u64>,
Nicolas Patry's avatar
Nicolas Patry committed
138
139
140
    #[serde(default)]
    #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)]
    pub top_n_tokens: Option<u32>,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
141
142
}

/// Default for `GenerateParameters::max_new_tokens`.
///
/// Always returns `None`. Referenced by name from the field's
/// `#[serde(default = "...")]` attribute and reused by `default_parameters`.
fn default_max_new_tokens() -> Option<u32> {
    None
}

fn default_parameters() -> GenerateParameters {
    GenerateParameters {
149
        best_of: None,
150
151
152
153
        temperature: None,
        repetition_penalty: None,
        top_k: None,
        top_p: None,
154
        typical_p: None,
155
        do_sample: false,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
156
        max_new_tokens: default_max_new_tokens(),
157
        return_full_text: None,
158
        stop: Vec::new(),
159
        truncate: None,
160
        watermark: false,
OlivierDehaene's avatar
OlivierDehaene committed
161
        details: false,
162
        decoder_input_details: false,
163
        seed: None,
Nicolas Patry's avatar
Nicolas Patry committed
164
        top_n_tokens: None,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
165
166
167
    }
}

168
#[derive(Clone, Debug, Deserialize, ToSchema)]
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
169
pub(crate) struct GenerateRequest {
170
    #[schema(example = "My name is Olivier and I")]
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
171
172
173
174
175
    pub inputs: String,
    #[serde(default = "default_parameters")]
    pub parameters: GenerateParameters,
}

176
177
178
179
180
181
182
#[derive(Clone, Debug, Deserialize, ToSchema)]
pub(crate) struct CompatGenerateRequest {
    #[schema(example = "My name is Olivier and I")]
    pub inputs: String,
    #[serde(default = "default_parameters")]
    pub parameters: GenerateParameters,
    #[serde(default)]
OlivierDehaene's avatar
OlivierDehaene committed
183
    #[schema(default = "false")]
184
185
186
187
188
189
190
191
192
193
194
195
    pub stream: bool,
}

impl From<CompatGenerateRequest> for GenerateRequest {
    fn from(req: CompatGenerateRequest) -> Self {
        Self {
            inputs: req.inputs,
            parameters: req.parameters,
        }
    }
}

196
197
198
199
200
201
#[derive(Debug, Serialize, ToSchema)]
pub struct PrefillToken {
    #[schema(example = 0)]
    id: u32,
    #[schema(example = "test")]
    text: String,
202
    #[schema(nullable = true, example = - 0.34)]
203
204
205
    logprob: f32,
}

206
207
208
209
210
211
#[derive(Debug, Serialize, ToSchema)]
pub struct Token {
    #[schema(example = 0)]
    id: u32,
    #[schema(example = "test")]
    text: String,
212
    #[schema(nullable = true, example = - 0.34)]
213
    logprob: f32,
214
215
    #[schema(example = "false")]
    special: bool,
216
217
218
219
220
221
222
223
224
225
226
227
228
}

#[derive(Serialize, ToSchema)]
#[serde(rename_all(serialize = "snake_case"))]
pub(crate) enum FinishReason {
    #[schema(rename = "length")]
    Length,
    #[serde(rename = "eos_token")]
    #[schema(rename = "eos_token")]
    EndOfSequenceToken,
    #[schema(rename = "stop_sequence")]
    StopSequence,
}
229

230
231
232
233
234
235
236
237
238
239
240
241
#[derive(Serialize, ToSchema)]
pub(crate) struct BestOfSequence {
    #[schema(example = "test")]
    pub generated_text: String,
    #[schema(example = "length")]
    pub finish_reason: FinishReason,
    #[schema(example = 1)]
    pub generated_tokens: u32,
    #[schema(nullable = true, example = 42)]
    pub seed: Option<u64>,
    pub prefill: Vec<PrefillToken>,
    pub tokens: Vec<Token>,
Nicolas Patry's avatar
Nicolas Patry committed
242
243
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub top_tokens: Vec<Vec<Token>>,
244
245
}

246
#[derive(Serialize, ToSchema)]
OlivierDehaene's avatar
OlivierDehaene committed
247
pub(crate) struct Details {
248
249
250
    #[schema(example = "length")]
    pub finish_reason: FinishReason,
    #[schema(example = 1)]
OlivierDehaene's avatar
OlivierDehaene committed
251
    pub generated_tokens: u32,
252
    #[schema(nullable = true, example = 42)]
253
    pub seed: Option<u64>,
254
255
    pub prefill: Vec<PrefillToken>,
    pub tokens: Vec<Token>,
256
257
    #[serde(skip_serializing_if = "Option::is_none")]
    pub best_of_sequences: Option<Vec<BestOfSequence>>,
Nicolas Patry's avatar
Nicolas Patry committed
258
259
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub top_tokens: Vec<Vec<Token>>,
OlivierDehaene's avatar
OlivierDehaene committed
260
261
}

262
#[derive(Serialize, ToSchema)]
263
pub(crate) struct GenerateResponse {
264
    #[schema(example = "test")]
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
265
    pub generated_text: String,
OlivierDehaene's avatar
OlivierDehaene committed
266
267
    #[serde(skip_serializing_if = "Option::is_none")]
    pub details: Option<Details>,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
268
}
269

270
271
272
273
274
275
#[derive(Serialize, ToSchema)]
pub(crate) struct StreamDetails {
    #[schema(example = "length")]
    pub finish_reason: FinishReason,
    #[schema(example = 1)]
    pub generated_tokens: u32,
276
    #[schema(nullable = true, example = 42)]
277
278
279
280
    pub seed: Option<u64>,
}

#[derive(Serialize, ToSchema)]
281
282
pub(crate) struct StreamResponse {
    pub token: Token,
Nicolas Patry's avatar
Nicolas Patry committed
283
284
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub top_tokens: Vec<Token>,
285
    #[schema(nullable = true, default = "null", example = "test")]
286
    pub generated_text: Option<String>,
287
288
    #[schema(nullable = true, default = "null")]
    pub details: Option<StreamDetails>,
289
290
}

291
#[derive(Serialize, ToSchema)]
292
293
pub(crate) struct ErrorResponse {
    pub error: String,
294
    pub error_type: String,
295
}
296
297

#[cfg(test)]
mod tests {
    use tokenizers::Tokenizer;

    /// Loads the tokenizer stored in `tokenizer.json` in the current directory,
    /// downloading the GPT-2 tokenizer file first when it is missing.
    ///
    /// The download is written to `tokenizer.json.temp` and then renamed, so a
    /// partially written file is never visible under the final name.
    ///
    /// # Panics
    /// Panics when the download, a file-system operation or the tokenizer
    /// parsing fails.
    pub(crate) async fn get_tokenizer() -> Tokenizer {
        let filename = std::path::Path::new("tokenizer.json");
        if !filename.exists() {
            let content = reqwest::get("https://huggingface.co/gpt2/raw/main/tokenizer.json")
                .await
                .unwrap()
                .bytes()
                .await
                .unwrap();
            let tmp_filename = "tokenizer.json.temp";
            // `fs::write` creates, fills and closes the file, so no handle is
            // still open on the temp file when it gets renamed below.
            std::fs::write(tmp_filename, &content).unwrap();
            // Re-check if another process has written this file maybe.
            if filename.exists() {
                // Somebody else installed the file first: drop our copy
                // (best effort, it may already be gone).
                let _ = std::fs::remove_file(tmp_filename);
            } else if let Err(err) = std::fs::rename(tmp_filename, filename) {
                // A concurrent caller may have renamed the shared temp file
                // just before us; that is only an error if the target is
                // still missing.
                assert!(
                    filename.exists(),
                    "failed to install tokenizer.json: {}",
                    err
                );
            }
        }
        Tokenizer::from_file(filename).unwrap()
    }
}