"torchvision/extension.py" did not exist on "1f2c15f23643a62318c3c803cc65b652c2da198e"
lib.rs 22.4 KB
Newer Older
1
mod health;
2
/// Text Generation Inference Webserver
3
mod infer;
4
mod queue;
Olivier Dehaene's avatar
Olivier Dehaene committed
5
pub mod server;
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
6
mod validation;
Olivier Dehaene's avatar
Olivier Dehaene committed
7

use infer::{Infer, InferError, InferStreamResponse};
use queue::{Entry, Queue};
use serde::{Deserialize, Serialize};
use tokio::sync::OwnedSemaphorePermit;
use tokio_stream::wrappers::UnboundedReceiverStream;
use utoipa::ToSchema;
use validation::Validation;

/// Type alias for generation responses
pub(crate) type GenerateStreamResponse = (
    OwnedSemaphorePermit,
    u32, // input_length
    UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
);

/// Hub type
#[derive(Clone, Debug, Deserialize)]
pub struct HubModelInfo {
    #[serde(rename(deserialize = "id"))]
    pub model_id: String,
    pub sha: Option<String>,
    pub pipeline_tag: Option<String>,
}

#[derive(Clone, Deserialize, Default)]
pub struct HubTokenizerConfig {
    pub chat_template: Option<String>,
    // `deserialize_with` on its own turns a missing key into a hard error for the
    // whole config, so `default` is needed to accept files that omit these tokens.
    #[serde(default, deserialize_with = "token_serde::deserialize")]
    pub bos_token: Option<String>,
    #[serde(default, deserialize_with = "token_serde::deserialize")]
    pub eos_token: Option<String>,
}

impl HubTokenizerConfig {
    /// Panics if the file cannot be read; malformed JSON falls back to the
    /// default (empty) config.
    pub fn from_file(filename: &std::path::Path) -> Self {
        let content = std::fs::read_to_string(filename).unwrap();
        serde_json::from_str(&content).unwrap_or_default()
    }
}
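// Usage sketch (hypothetical path, not part of the original file):
//
//     let config = HubTokenizerConfig::from_file(std::path::Path::new("tokenizer_config.json"));
//     let chat_template = config.chat_template.unwrap_or_default();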

/// Accepts either a plain string token or the structured `AddedToken` object
/// found in some `tokenizer_config.json` files, keeping only its `content`.
mod token_serde {
    use super::*;
    use serde::de;
    use serde::Deserializer;
    use serde_json::Value;

    pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<String>, D::Error>
    where
        D: Deserializer<'de>,
    {
        let value = Value::deserialize(deserializer)?;

        match value {
            Value::String(s) => Ok(Some(s)),
            Value::Object(map) => {
                if let Some(content) = map.get("content").and_then(|v| v.as_str()) {
                    Ok(Some(content.to_string()))
                } else {
                    Err(de::Error::custom(
                        "content key not found in structured token",
                    ))
                }
            }
            _ => Err(de::Error::custom("invalid token format")),
        }
    }
}

#[derive(Clone, Debug, Serialize, ToSchema)]
pub struct Info {
    /// Model info
    #[schema(example = "bigscience/bloom-560m")]
    pub model_id: String,
    #[schema(nullable = true, example = "e985a63cdc139290c5f700ff1929f0b5942cced2")]
    pub model_sha: Option<String>,
    #[schema(example = "torch.float16")]
    pub model_dtype: String,
    #[schema(example = "cuda")]
    pub model_device_type: String,
    #[schema(nullable = true, example = "text-generation")]
    pub model_pipeline_tag: Option<String>,
    /// Router Parameters
    #[schema(example = "128")]
    pub max_concurrent_requests: usize,
    #[schema(example = "2")]
    pub max_best_of: usize,
    #[schema(example = "4")]
    pub max_stop_sequences: usize,
    #[schema(example = "1024")]
    pub max_input_length: usize,
    #[schema(example = "2048")]
    pub max_total_tokens: usize,
    #[schema(example = "1.2")]
    pub waiting_served_ratio: f32,
    #[schema(example = "32000")]
    pub max_batch_total_tokens: u32,
    #[schema(example = "20")]
    pub max_waiting_tokens: usize,
    #[schema(nullable = true, example = "null")]
    pub max_batch_size: Option<usize>,
    #[schema(example = "2")]
    pub validation_workers: usize,
    /// Router Info
    #[schema(example = "0.5.0")]
    pub version: &'static str,
    #[schema(nullable = true, example = "null")]
    pub sha: Option<&'static str>,
    #[schema(nullable = true, example = "null")]
    pub docker_label: Option<&'static str>,
}

#[derive(Clone, Debug, Deserialize, ToSchema)]
pub(crate) struct GenerateParameters {
    #[serde(default)]
    #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 1)]
    pub best_of: Option<usize>,
    #[serde(default)]
    #[schema(
        exclusive_minimum = 0.0,
        nullable = true,
        default = "null",
        example = 0.5
    )]
    pub temperature: Option<f32>,
    #[serde(default)]
    #[schema(
        exclusive_minimum = 0.0,
        nullable = true,
        default = "null",
        example = 1.03
    )]
    pub repetition_penalty: Option<f32>,
    #[serde(default)]
    #[schema(
        exclusive_minimum = -2.0,
        nullable = true,
        default = "null",
        example = 0.1
    )]
    pub frequency_penalty: Option<f32>,
    #[serde(default)]
    #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 10)]
    pub top_k: Option<i32>,
    #[serde(default)]
    #[schema(
        exclusive_minimum = 0.0,
        maximum = 1.0,
        nullable = true,
        default = "null",
        example = 0.95
    )]
    pub top_p: Option<f32>,
    #[serde(default)]
    #[schema(
        exclusive_minimum = 0.0,
        maximum = 1.0,
        nullable = true,
        default = "null",
        example = 0.95
    )]
    pub typical_p: Option<f32>,
    #[serde(default)]
    #[schema(default = "false", example = true)]
    pub do_sample: bool,
    #[serde(default = "default_max_new_tokens")]
    #[schema(nullable = true, default = "100", example = "20")]
    pub max_new_tokens: Option<u32>,
    #[serde(default)]
    #[schema(nullable = true, default = "null", example = false)]
    pub return_full_text: Option<bool>,
    #[serde(default)]
    #[schema(inline, max_items = 4, example = json!(["photographer"]))]
    pub stop: Vec<String>,
    #[serde(default)]
    #[schema(nullable = true, default = "null", example = "null")]
    pub truncate: Option<usize>,
    #[serde(default)]
    #[schema(default = "false", example = true)]
    pub watermark: bool,
    #[serde(default)]
    #[schema(default = "true")]
    pub details: bool,
    #[serde(default)]
    #[schema(default = "true")]
    pub decoder_input_details: bool,
    #[serde(default)]
    #[schema(
        exclusive_minimum = 0,
        nullable = true,
        default = "null",
        example = "null"
    )]
    pub seed: Option<u64>,
    #[serde(default)]
    #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)]
    pub top_n_tokens: Option<u32>,
}

fn default_max_new_tokens() -> Option<u32> {
    Some(100)
}

fn default_parameters() -> GenerateParameters {
    GenerateParameters {
        best_of: None,
        temperature: None,
        repetition_penalty: None,
        frequency_penalty: None,
        top_k: None,
        top_p: None,
        typical_p: None,
        do_sample: true,
        max_new_tokens: default_max_new_tokens(),
        return_full_text: None,
        stop: Vec::new(),
        truncate: None,
        watermark: false,
        details: false,
        decoder_input_details: false,
        seed: None,
        top_n_tokens: None,
    }
}
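// A minimal sketch, not part of the original file, showing how `default_parameters`
// and `default_max_new_tokens` behave through serde: a `GenerateRequest` that omits
// `parameters` picks up these defaults. The JSON payload is hypothetical.
#[cfg(test)]
mod default_parameters_tests {
    use super::*;

    #[test]
    fn omitted_parameters_fall_back_to_defaults() {
        // `parameters` is absent, so serde calls `default_parameters`.
        let req: GenerateRequest = serde_json::from_str(r#"{"inputs": "test"}"#).unwrap();
        assert_eq!(req.parameters.max_new_tokens, Some(100));
        assert!(req.parameters.do_sample);
        assert!(req.parameters.stop.is_empty());
    }
}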

#[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletion {
    pub id: String,
    pub object: String,
    #[schema(example = "1706270835")]
    pub created: u64,
    #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
    pub model: String,
    pub system_fingerprint: String,
    pub choices: Vec<ChatCompletionComplete>,
    pub usage: Usage,
}

#[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletionComplete {
    pub index: u32,
    pub message: Message,
    pub logprobs: Option<ChatCompletionLogprobs>,
    pub finish_reason: String,
}

#[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletionLogprobs {
    content: Vec<ChatCompletionLogprob>,
}

/// Logprobs for a single generated token and its top alternatives.
impl From<(Token, Vec<Token>)> for ChatCompletionLogprobs {
    fn from(value: (Token, Vec<Token>)) -> Self {
        let (token, top_tokens) = value;

        Self {
            content: vec![ChatCompletionLogprob {
                token: token.text,
                logprob: token.logprob,
                top_logprobs: top_tokens
                    .into_iter()
                    .map(|t| ChatCompletionTopLogprob {
                        token: t.text,
                        logprob: t.logprob,
                    })
                    .collect(),
            }],
        }
    }
}

/// Logprobs for a full token sequence, each token paired with its top alternatives.
impl From<(Vec<Token>, Vec<Vec<Token>>)> for ChatCompletionLogprobs {
    fn from(value: (Vec<Token>, Vec<Vec<Token>>)) -> Self {
        let (tokens, top_tokens) = value;
        Self {
            content: tokens
                .into_iter()
                .zip(top_tokens)
                .map(|(t, top_t)| ChatCompletionLogprob {
                    token: t.text,
                    logprob: t.logprob,
                    top_logprobs: top_t
                        .into_iter()
                        .map(|t| ChatCompletionTopLogprob {
                            token: t.text,
                            logprob: t.logprob,
                        })
                        .collect(),
                })
                .collect(),
        }
    }
}

#[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletionLogprob {
    token: String,
    logprob: f32,
    top_logprobs: Vec<ChatCompletionTopLogprob>,
}

#[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletionTopLogprob {
    token: String,
    logprob: f32,
}

#[derive(Clone, Deserialize, Serialize)]
pub(crate) struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

impl ChatCompletion {
    pub(crate) fn new(
        model: String,
        system_fingerprint: String,
        output: String,
        created: u64,
        details: Details,
        return_logprobs: bool,
    ) -> Self {
        Self {
            id: String::new(),
            object: "text_completion".into(),
            created,
            model,
            system_fingerprint,
            choices: vec![ChatCompletionComplete {
                index: 0,
                message: Message {
                    role: "assistant".into(),
                    content: output,
                },
                logprobs: return_logprobs
                    .then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))),
                finish_reason: details.finish_reason.to_string(),
            }],
            usage: Usage {
                prompt_tokens: details.prefill.len() as u32,
                completion_tokens: details.generated_tokens,
                total_tokens: details.prefill.len() as u32 + details.generated_tokens,
            },
        }
    }
}

#[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletionChunk {
    pub id: String,
    pub object: String,
    #[schema(example = "1706270978")]
    pub created: u64,
    #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
    pub model: String,
    pub system_fingerprint: String,
    pub choices: Vec<ChatCompletionChoice>,
}

#[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletionChoice {
    pub index: u32,
    pub delta: ChatCompletionDelta,
    pub logprobs: Option<ChatCompletionLogprobs>,
    pub finish_reason: Option<String>,
}

#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletionDelta {
    #[schema(example = "user")]
    pub role: String,
    #[schema(example = "What is Deep Learning?")]
    pub content: String,
}

impl ChatCompletionChunk {
    pub(crate) fn new(
        model: String,
        system_fingerprint: String,
        delta: String,
        created: u64,
        index: u32,
        logprobs: Option<ChatCompletionLogprobs>,
        finish_reason: Option<String>,
    ) -> Self {
        Self {
            id: String::new(),
            object: "text_completion".to_string(),
            created,
            model,
            system_fingerprint,
            choices: vec![ChatCompletionChoice {
                index,
                delta: ChatCompletionDelta {
                    role: "assistant".to_string(),
                    content: delta,
                },
                logprobs,
                finish_reason,
            }],
        }
    }
}

fn default_request_messages() -> Vec<Message> {
    vec![Message {
        role: "user".to_string(),
        content: "My name is David and I".to_string(),
    }]
}

#[derive(Clone, Deserialize, ToSchema, Serialize)]
pub(crate) struct ChatRequest {
    /// UNUSED
    /// ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
    #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
    pub model: String, /* NOTE: UNUSED */

    /// A list of messages comprising the conversation so far.
    #[serde(default = "default_request_messages")]
    pub messages: Vec<Message>,

    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,
    /// decreasing the model's likelihood to repeat the same line verbatim.
    #[serde(default)]
    #[schema(example = "1.0")]
    pub frequency_penalty: Option<f32>,

    /// UNUSED
    /// Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens
    /// (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,
    /// the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,
    /// but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should
    /// result in a ban or exclusive selection of the relevant token.
    #[serde(default)]
    pub logit_bias: Option<Vec<f32>>,

    /// Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each
    /// output token returned in the content of message.
    #[serde(default)]
    #[schema(example = "false")]
    pub logprobs: Option<bool>,

    /// An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with
    /// an associated log probability. logprobs must be set to true if this parameter is used.
    #[serde(default)]
    #[schema(example = "5")]
    pub top_logprobs: Option<u32>,

    /// The maximum number of tokens that can be generated in the chat completion.
    #[serde(default)]
    #[schema(example = "32")]
    pub max_tokens: Option<u32>,

    /// UNUSED
    /// How many chat completion choices to generate for each input message. Note that you will be charged based on the
    /// number of generated tokens across all of the choices. Keep n as 1 to minimize costs.
    #[serde(default)]
    #[schema(nullable = true, example = "2")]
    pub n: Option<u32>,

    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,
    /// increasing the model's likelihood to talk about new topics
    #[serde(default)]
    #[schema(nullable = true, example = 0.1)]
    pub presence_penalty: Option<f32>,

    #[serde(default = "bool::default")]
    pub stream: bool,

    #[schema(nullable = true, example = 42)]
    pub seed: Option<u64>,

    /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while
    /// lower values like 0.2 will make it more focused and deterministic.
    ///
    /// We generally recommend altering this or `top_p` but not both.
    #[serde(default)]
    #[schema(nullable = true, example = 1.0)]
    pub temperature: Option<f32>,

    /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the
    /// tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
    #[serde(default)]
    #[schema(nullable = true, example = 0.95)]
    pub top_p: Option<f32>,
}
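// A hypothetical OpenAI-style payload this struct accepts; omitted fields fall
// back to their serde defaults:
//
//     {
//       "model": "mistralai/Mistral-7B-Instruct-v0.2",
//       "messages": [{"role": "user", "content": "What is Deep Learning?"}],
//       "max_tokens": 32,
//       "stream": true
//     }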

#[derive(Clone, Serialize, Deserialize)]
pub(crate) struct ChatTemplateInputs<'a> {
    messages: Vec<Message>,
    bos_token: Option<&'a str>,
    eos_token: Option<&'a str>,
    add_generation_prompt: bool,
}

#[derive(Clone, Deserialize, ToSchema, Serialize)]
pub(crate) struct Message {
    #[schema(example = "user")]
    pub role: String,
    #[schema(example = "My name is David and I")]
    pub content: String,
}

#[derive(Clone, Debug, Deserialize, ToSchema)]
pub(crate) struct GenerateRequest {
    #[schema(example = "My name is Olivier and I")]
    pub inputs: String,
    #[serde(default = "default_parameters")]
    pub parameters: GenerateParameters,
}

#[derive(Clone, Debug, Deserialize, ToSchema)]
pub(crate) struct CompatGenerateRequest {
    #[schema(example = "My name is Olivier and I")]
    pub inputs: String,
    #[serde(default = "default_parameters")]
    pub parameters: GenerateParameters,
    #[serde(default)]
    #[schema(default = "false")]
    pub stream: bool,
}

impl From<CompatGenerateRequest> for GenerateRequest {
    fn from(req: CompatGenerateRequest) -> Self {
        Self {
            inputs: req.inputs,
            parameters: req.parameters,
        }
    }
}
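// `CompatGenerateRequest::stream` is not carried into `GenerateRequest`; presumably
// the compat route inspects it before converting (hypothetical usage):
//
//     let request: GenerateRequest = compat_request.into();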

#[derive(Debug, Serialize, ToSchema)]
pub struct PrefillToken {
    #[schema(example = 0)]
    id: u32,
    #[schema(example = "test")]
    text: String,
    #[schema(nullable = true, example = -0.34)]
    logprob: f32,
}

#[derive(Debug, Serialize, ToSchema, Clone)]
pub struct Token {
    #[schema(example = 0)]
    id: u32,
    #[schema(example = "test")]
    text: String,
    #[schema(nullable = true, example = -0.34)]
    logprob: f32,
    #[schema(example = "false")]
    special: bool,
}

#[derive(Debug, Serialize, ToSchema)]
pub struct SimpleToken {
    #[schema(example = 0)]
    id: u32,
    #[schema(example = "test")]
    text: String,
    #[schema(example = 0)]
    start: usize,
    #[schema(example = 2)]
    stop: usize,
}

#[derive(Serialize, ToSchema)]
#[serde(rename_all(serialize = "snake_case"))]
#[schema(example = "Length")]
pub(crate) enum FinishReason {
    #[schema(rename = "length")]
    Length,
    #[serde(rename = "eos_token")]
    #[schema(rename = "eos_token")]
    EndOfSequenceToken,
    #[schema(rename = "stop_sequence")]
    StopSequence,
}

impl std::fmt::Display for FinishReason {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            FinishReason::Length => write!(f, "length"),
            FinishReason::EndOfSequenceToken => write!(f, "eos_token"),
            FinishReason::StopSequence => write!(f, "stop_sequence"),
        }
    }
}

#[derive(Serialize, ToSchema)]
pub(crate) struct BestOfSequence {
    #[schema(example = "test")]
    pub generated_text: String,
    #[schema(example = "length")]
    pub finish_reason: FinishReason,
    #[schema(example = 1)]
    pub generated_tokens: u32,
    #[schema(nullable = true, example = 42)]
    pub seed: Option<u64>,
    pub prefill: Vec<PrefillToken>,
    pub tokens: Vec<Token>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub top_tokens: Vec<Vec<Token>>,
}

#[derive(Serialize, ToSchema)]
pub(crate) struct Details {
    #[schema(example = "length")]
    pub finish_reason: FinishReason,
    #[schema(example = 1)]
    pub generated_tokens: u32,
    #[schema(nullable = true, example = 42)]
    pub seed: Option<u64>,
    pub prefill: Vec<PrefillToken>,
    pub tokens: Vec<Token>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub best_of_sequences: Option<Vec<BestOfSequence>>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub top_tokens: Vec<Vec<Token>>,
}

#[derive(Serialize, ToSchema)]
pub(crate) struct GenerateResponse {
    #[schema(example = "test")]
    pub generated_text: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub details: Option<Details>,
}

#[derive(Serialize, ToSchema)]
#[serde(transparent)]
pub(crate) struct TokenizeResponse(Vec<SimpleToken>);

#[derive(Serialize, ToSchema)]
pub(crate) struct StreamDetails {
    #[schema(example = "length")]
    pub finish_reason: FinishReason,
    #[schema(example = 1)]
    pub generated_tokens: u32,
    #[schema(nullable = true, example = 42)]
    pub seed: Option<u64>,
}

#[derive(Serialize, ToSchema)]
pub(crate) struct StreamResponse {
    pub index: u32,
    pub token: Token,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub top_tokens: Vec<Token>,
    #[schema(nullable = true, default = "null", example = "test")]
    pub generated_text: Option<String>,
    #[schema(nullable = true, default = "null")]
    pub details: Option<StreamDetails>,
}

#[derive(Serialize, ToSchema)]
pub(crate) struct ErrorResponse {
    pub error: String,
    pub error_type: String,
}

#[cfg(test)]
mod tests {
    use super::*;

    use tokenizers::Tokenizer;

    pub(crate) async fn get_tokenizer() -> Tokenizer {
        let api = hf_hub::api::sync::Api::new().unwrap();
        let repo = api.model("gpt2".to_string());
        let filename = repo.get("tokenizer.json").unwrap();
        Tokenizer::from_file(filename).unwrap()
    }

    #[test]
    fn test_hub_nested_tokens_tokenizer_config() {
        // this is a subset of a tokenizer_config.json file
        // in this case we expect the tokens to be encoded as simple strings
        let json_content = r#"{
            "chat_template": "test",
            "bos_token": "<|begin▁of▁sentence|>",
            "eos_token": "<|end▁of▁sentence|>"
        }"#;

        let config: HubTokenizerConfig = serde_json::from_str(json_content).unwrap();

        // check that we successfully parsed the tokens
        assert_eq!(config.chat_template, Some("test".to_string()));
        assert_eq!(
            config.bos_token,
            Some("<|begin▁of▁sentence|>".to_string())
        );
        assert_eq!(config.eos_token, Some("<|end▁of▁sentence|>".to_string()));

        // in this case we expect the tokens to be encoded as structured tokens
        // we want the content of the structured token
        let json_content = r#"{
            "chat_template": "test",
            "bos_token": {
              "__type": "AddedToken",
              "content": "<|begin▁of▁sentence|>",
              "lstrip": false,
              "normalized": true,
              "rstrip": false,
              "single_word": false
            },
            "eos_token": {
              "__type": "AddedToken",
              "content": "<|end▁of▁sentence|>",
              "lstrip": false,
              "normalized": true,
              "rstrip": false,
              "single_word": false
            }
        }"#;

        let config: HubTokenizerConfig = serde_json::from_str(json_content).unwrap();

        // check that we successfully parsed the tokens
        assert_eq!(config.chat_template, Some("test".to_string()));
        assert_eq!(
            config.bos_token,
            Some("<|begin▁of▁sentence|>".to_string())
        );
        assert_eq!(config.eos_token, Some("<|end▁of▁sentence|>".to_string()));
    }
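
    // Additional sketch, not in the original tests: `token_serde` rejects any value
    // that is neither a string nor an object with a `content` key, so parsing the
    // whole config should fail here.
    #[test]
    fn test_hub_invalid_token_format() {
        let json_content = r#"{
            "chat_template": "test",
            "bos_token": 42
        }"#;

        assert!(serde_json::from_str::<HubTokenizerConfig>(json_content).is_err());
    }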
}