server.rs 3.9 KB
Newer Older
Olivier Dehaene's avatar
Olivier Dehaene committed
1
use crate::{Batcher, ShardedClient, Validation};
Olivier Dehaene's avatar
Olivier Dehaene committed
2
use axum::extract::Extension;
Olivier Dehaene's avatar
Olivier Dehaene committed
3
use axum::http::StatusCode;
Olivier Dehaene's avatar
Olivier Dehaene committed
4
use axum::routing::post;
Olivier Dehaene's avatar
Olivier Dehaene committed
5
use axum::{Json, Router};
Olivier Dehaene's avatar
Olivier Dehaene committed
6
use serde::Deserialize;
Olivier Dehaene's avatar
Olivier Dehaene committed
7
use std::net::SocketAddr;
Olivier Dehaene's avatar
Olivier Dehaene committed
8
use tokenizers::Tokenizer;
Olivier Dehaene's avatar
Olivier Dehaene committed
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
use tokio::time::Instant;
use tracing::instrument;

/// Generation parameters carried in the `parameters` object of a
/// [`GenerateRequest`].
///
/// Every field is optional in the incoming JSON: when a field is missing,
/// serde fills it from the matching `default_*` function declared next to
/// this struct.
#[derive(Clone, Debug, Deserialize)]
pub(crate) struct GenerateParameters {
    /// Defaults to `default_temperature()` (1.0).
    /// Presumably the sampling temperature; consumed downstream — TODO confirm.
    #[serde(default = "default_temperature")]
    pub temperature: f32,
    /// Defaults to `default_top_k()` (0).
    /// NOTE(review): 0 presumably means "no top-k filtering" — confirm with the consumer.
    #[serde(default = "default_top_k")]
    pub top_k: u32,
    /// Defaults to `default_top_p()` (1.0).
    #[serde(default = "default_top_p")]
    pub top_p: f32,
    /// Defaults to `default_do_sample()` (false).
    #[serde(default = "default_do_sample")]
    pub do_sample: bool,
    /// Defaults to `default_max_new_tokens()` (20). The `generate` handler
    /// also divides its elapsed time by this value for the `time_per_token` metric.
    #[serde(default = "default_max_new_tokens")]
    pub max_new_tokens: u32,
}

/// Serde default for `GenerateParameters::temperature` when the field is absent.
fn default_temperature() -> f32 {
    1_f32
}

/// Serde default for `GenerateParameters::top_k` when the field is absent.
fn default_top_k() -> u32 {
    0_u32
}

/// Serde default for `GenerateParameters::top_p` when the field is absent.
fn default_top_p() -> f32 {
    1_f32
}

/// Serde default for `GenerateParameters::do_sample` when the field is absent.
fn default_do_sample() -> bool {
    bool::default()
}

/// Serde default for `GenerateParameters::max_new_tokens` when the field is absent.
fn default_max_new_tokens() -> u32 {
    20_u32
}

/// Builds a `GenerateParameters` with every field set from its `default_*`
/// function. Used by serde when a request omits `parameters` entirely.
fn default_parameters() -> GenerateParameters {
    let temperature = default_temperature();
    let top_k = default_top_k();
    let top_p = default_top_p();
    let do_sample = default_do_sample();
    let max_new_tokens = default_max_new_tokens();

    GenerateParameters {
        temperature,
        top_k,
        top_p,
        do_sample,
        max_new_tokens,
    }
}

/// JSON body of a generation request (the payload extracted by the
/// `generate` handler).
///
/// `parameters` may be omitted entirely, in which case `default_parameters()`
/// supplies every field.
#[derive(Clone, Debug, Deserialize)]
pub(crate) struct GenerateRequest {
    /// Prompt text to generate from.
    pub inputs: String,
    #[serde(default = "default_parameters")]
    pub parameters: GenerateParameters,
}

Olivier Dehaene's avatar
Olivier Dehaene committed
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
/// Health handler: runs a single-token generation on a fixed prompt.
///
/// Returns `Ok(())` when the batcher completes the probe and
/// `500 Internal Server Error` when it fails. The request is sent straight to
/// the batcher without going through `Validation`.
#[instrument(skip(state), fields(time, time_per_token))]
async fn liveness(state: Extension<ServerState>) -> Result<(), StatusCode> {
    let probe = GenerateRequest {
        inputs: String::from("liveness"),
        parameters: GenerateParameters {
            temperature: 1.0,
            top_k: 0,
            top_p: 1.0,
            do_sample: false,
            // Only one token is needed to prove the inference path works.
            max_new_tokens: 1,
        },
    };

    // NOTE(review): the first argument is passed where `generate` passes the
    // validated input length; presumably a token count for the probe — confirm.
    state
        .infer
        .infer(1, probe)
        .await
        .map(|_| ())
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
}

Olivier Dehaene's avatar
Olivier Dehaene committed
88
#[instrument(skip(state), fields(time, time_per_token))]
Olivier Dehaene's avatar
Olivier Dehaene committed
89
async fn generate(
Olivier Dehaene's avatar
Olivier Dehaene committed
90
    state: Extension<ServerState>,
Olivier Dehaene's avatar
Olivier Dehaene committed
91
    req: Json<GenerateRequest>,
Olivier Dehaene's avatar
Olivier Dehaene committed
92
) -> Result<Json<serde_json::Value>, StatusCode> {
Olivier Dehaene's avatar
Olivier Dehaene committed
93
94
    let start = Instant::now();

Olivier Dehaene's avatar
Olivier Dehaene committed
95
96
    let (input_length, validated_request) = match state
        .validation
Olivier Dehaene's avatar
Olivier Dehaene committed
97
        .validate(GenerateRequest {
Olivier Dehaene's avatar
Olivier Dehaene committed
98
99
100
            inputs: req.inputs.clone(),
            parameters: req.parameters.clone(),
        })
Olivier Dehaene's avatar
Olivier Dehaene committed
101
102
        .await
    {
103
        Ok(result) => result,
Olivier Dehaene's avatar
Olivier Dehaene committed
104
        Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
105
    };
Olivier Dehaene's avatar
Olivier Dehaene committed
106

Olivier Dehaene's avatar
Olivier Dehaene committed
107
    let output = state.infer.infer(input_length, validated_request).await;
Olivier Dehaene's avatar
Olivier Dehaene committed
108
109
110
111
112
113
114
115
116
117
118
119
120
121

    match output {
        Ok(generated_text) => {
            tracing::Span::current().record("time", format!("{:?}", start.elapsed()));
            tracing::Span::current().record(
                "time_per_token",
                format!("{:?}", start.elapsed() / req.parameters.max_new_tokens),
            );
            tracing::info!("response: {}", generated_text);

            Ok(Json(serde_json::json!({
                "generated_text": generated_text,
            })))
        }
Olivier Dehaene's avatar
Olivier Dehaene committed
122
        Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
Olivier Dehaene's avatar
Olivier Dehaene committed
123
124
125
    }
}

Olivier Dehaene's avatar
Olivier Dehaene committed
126
127
128
129
130
131
/// State shared with the request handlers.
///
/// Installed on the router as an axum `Extension` and extracted by both
/// `generate` and `liveness`.
#[derive(Clone)]
struct ServerState {
    /// Request validator, built from the `Tokenizer` in `run`.
    validation: Validation,
    /// Inference batcher, built from the `ShardedClient` in `run`.
    infer: Batcher,
}

Olivier Dehaene's avatar
Olivier Dehaene committed
132
pub async fn run(client: ShardedClient, tokenizer: Tokenizer, addr: SocketAddr) {
Olivier Dehaene's avatar
Olivier Dehaene committed
133
    client.clear_cache().await.expect("Unable to clear cache");
Olivier Dehaene's avatar
Olivier Dehaene committed
134
135
136
137
    tracing::info!("Connected");

    let infer = Batcher::new(client);

Olivier Dehaene's avatar
Olivier Dehaene committed
138
139
    let validation = Validation::new(tokenizer);

Olivier Dehaene's avatar
Olivier Dehaene committed
140
    let shared_state = ServerState { validation, infer };
Olivier Dehaene's avatar
Olivier Dehaene committed
141

Olivier Dehaene's avatar
Olivier Dehaene committed
142
143
144
145
146
    let app = Router::new()
        .route("/generate", post(generate))
        .layer(Extension(shared_state.clone()))
        .route("/health", post(liveness))
        .layer(Extension(shared_state.clone()));
Olivier Dehaene's avatar
Olivier Dehaene committed
147

Olivier Dehaene's avatar
Olivier Dehaene committed
148
    axum::Server::bind(&addr)
Olivier Dehaene's avatar
Olivier Dehaene committed
149
150
151
        .serve(app.into_make_service())
        .await
        .unwrap();
Olivier Dehaene's avatar
Olivier Dehaene committed
152
}