server.rs 4.76 KB
Newer Older
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
1
2
3
use crate::{
    Batcher, GenerateParameters, GenerateRequest, GenerateResponse, GeneratedText, Validation,
};
Olivier Dehaene's avatar
Olivier Dehaene committed
4
use axum::extract::Extension;
Olivier Dehaene's avatar
Olivier Dehaene committed
5
use axum::http::StatusCode;
Olivier Dehaene's avatar
Olivier Dehaene committed
6
use axum::routing::{get, post};
Olivier Dehaene's avatar
Olivier Dehaene committed
7
use axum::{Json, Router};
Olivier Dehaene's avatar
Olivier Dehaene committed
8
use bloom_inference_client::ShardedClient;
Olivier Dehaene's avatar
Olivier Dehaene committed
9
use std::net::SocketAddr;
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
10
11
use std::sync::Arc;
use std::time::Duration;
Olivier Dehaene's avatar
Olivier Dehaene committed
12
use tokenizers::Tokenizer;
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
13
14
use tokio::signal;
use tokio::sync::Semaphore;
Olivier Dehaene's avatar
Olivier Dehaene committed
15
16
17
use tokio::time::Instant;
use tracing::instrument;

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
18
19
20
21
22
23
/// Server shared state
///
/// Derives `Clone` because a copy is handed to the router through an `Extension` layer.
/// The semaphore lives behind an `Arc`, so every clone draws from the same permit pool.
#[derive(Clone)]
struct ServerState {
    /// Validates incoming requests (via `validate`) before they are sent to inference.
    validation: Validation,
    /// Entry point for inference (via `infer`).
    batcher: Batcher,
    /// Caps the number of in-flight requests: handlers `try_acquire` a permit and
    /// reject the request when none is available.
    limit_concurrent_requests: Arc<Semaphore>,
}

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
26
/// Health check method
Olivier Dehaene's avatar
Olivier Dehaene committed
27
#[instrument(skip(state), fields(time, time_per_token))]
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
28
29
30
31
32
33
34
35
36
37
38
39
40
41
async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, String)> {
    // TODO: while this is the best health check we can do, it is a bit on the heavy side and might
    //       be a bit too slow for a health check.
    //       What we should do instead if check if the gRPC channels are still healthy.

    // Limit concurrent requests by acquiring a permit from the semaphore
    let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| {
        (
            StatusCode::TOO_MANY_REQUESTS,
            "Model is overloaded".to_string(),
        )
    })?;

    // Send a small inference request
Olivier Dehaene's avatar
Olivier Dehaene committed
42
    state
Olivier Dehaene's avatar
Olivier Dehaene committed
43
        .batcher
Olivier Dehaene's avatar
Olivier Dehaene committed
44
45
46
47
48
49
50
51
52
53
54
55
56
        .infer(
            1,
            GenerateRequest {
                inputs: "liveness".to_string(),
                parameters: GenerateParameters {
                    temperature: 1.0,
                    top_k: 0,
                    top_p: 1.0,
                    do_sample: false,
                    max_new_tokens: 1,
                },
            },
        )
Olivier Dehaene's avatar
Olivier Dehaene committed
57
58
        .await?;
    Ok(())
Olivier Dehaene's avatar
Olivier Dehaene committed
59
60
}

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
61
/// Generate method
Olivier Dehaene's avatar
Olivier Dehaene committed
62
#[instrument(skip(state), fields(time, time_per_token))]
Olivier Dehaene's avatar
Olivier Dehaene committed
63
async fn generate(
Olivier Dehaene's avatar
Olivier Dehaene committed
64
    state: Extension<ServerState>,
Olivier Dehaene's avatar
Olivier Dehaene committed
65
    req: Json<GenerateRequest>,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
66
) -> Result<Json<GenerateResponse>, (StatusCode, String)> {
Olivier Dehaene's avatar
Olivier Dehaene committed
67
    let start = Instant::now();
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
68
69
70
71
72
73
74
    // Limit concurrent requests by acquiring a permit from the semaphore
    let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| {
        (
            StatusCode::TOO_MANY_REQUESTS,
            "Model is overloaded".to_string(),
        )
    })?;
Olivier Dehaene's avatar
Olivier Dehaene committed
75

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
76
    // Validate request
Olivier Dehaene's avatar
Olivier Dehaene committed
77
    let (input_length, validated_request) = state
Olivier Dehaene's avatar
Olivier Dehaene committed
78
        .validation
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
79
        // FIXME: can't we get rid of the cloning here??
Olivier Dehaene's avatar
Olivier Dehaene committed
80
        .validate(GenerateRequest {
Olivier Dehaene's avatar
Olivier Dehaene committed
81
82
83
            inputs: req.inputs.clone(),
            parameters: req.parameters.clone(),
        })
Olivier Dehaene's avatar
Olivier Dehaene committed
84
85
        .await?;

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
86
    // Inference
Olivier Dehaene's avatar
Olivier Dehaene committed
87
    let generated_text = state.batcher.infer(input_length, validated_request).await?;
Olivier Dehaene's avatar
Olivier Dehaene committed
88

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
89
    // Tracing metadata
Olivier Dehaene's avatar
Olivier Dehaene committed
90
91
92
93
94
95
96
    tracing::Span::current().record("time", format!("{:?}", start.elapsed()));
    tracing::Span::current().record(
        "time_per_token",
        format!("{:?}", start.elapsed() / req.parameters.max_new_tokens),
    );
    tracing::info!("response: {}", generated_text);

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
97
98
99
    // Send response
    let response = vec![GeneratedText { generated_text }];
    Ok(Json(response))
Olivier Dehaene's avatar
Olivier Dehaene committed
100
101
}

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
/// Serving method
#[allow(clippy::too_many_arguments)]
pub async fn run(
    max_concurrent_requests: usize,
    max_input_length: usize,
    max_batch_size: usize,
    max_waiting_time: Duration,
    client: ShardedClient,
    tokenizer: Tokenizer,
    validation_workers: usize,
    addr: SocketAddr,
) {
    // Create state
    let batcher = Batcher::new(client, max_batch_size, max_waiting_time);
    let validation = Validation::new(validation_workers, tokenizer, max_input_length);
    let shared_state = ServerState {
        validation,
        batcher,
        limit_concurrent_requests: Arc::new(Semaphore::new(max_concurrent_requests)),
    };

    // Create router
Olivier Dehaene's avatar
Olivier Dehaene committed
124
125
126
    let app = Router::new()
        .route("/generate", post(generate))
        .layer(Extension(shared_state.clone()))
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
127
        .route("/health", get(health))
Olivier Dehaene's avatar
Olivier Dehaene committed
128
        .layer(Extension(shared_state.clone()));
Olivier Dehaene's avatar
Olivier Dehaene committed
129

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
130
    // Run server
Olivier Dehaene's avatar
Olivier Dehaene committed
131
    axum::Server::bind(&addr)
Olivier Dehaene's avatar
Olivier Dehaene committed
132
        .serve(app.into_make_service())
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
133
134
        // Wait until all requests are finished to shut down
        .with_graceful_shutdown(shutdown_signal())
Olivier Dehaene's avatar
Olivier Dehaene committed
135
136
        .await
        .unwrap();
Olivier Dehaene's avatar
Olivier Dehaene committed
137
}
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164

/// Shutdown signal handler
///
/// Completes as soon as the process receives Ctrl+C or, on Unix, a terminate signal
/// (`SignalKind::terminate()`). On other platforms only Ctrl+C is observed. A log line is
/// emitted right before returning.
///
/// # Panics
///
/// Panics if either signal handler cannot be installed.
async fn shutdown_signal() {
    // Interactive interrupt
    let interrupt = async {
        signal::ctrl_c()
            .await
            .expect("failed to install Ctrl+C handler");
    };

    // Termination request; only available on Unix
    #[cfg(unix)]
    let termination = async {
        let mut stream = signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install signal handler");
        stream.recv().await;
    };

    // Elsewhere this branch never completes, leaving Ctrl+C as the only trigger
    #[cfg(not(unix))]
    let termination = std::future::pending::<()>();

    // Whichever signal arrives first wins
    tokio::select! {
        _ = interrupt => {}
        _ = termination => {}
    }

    tracing::info!("signal received, starting graceful shutdown");
}