main.rs 12.4 KB
Newer Older
1
/// Text Generation Inference webserver entrypoint
2
use axum::http::HeaderValue;
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
3
use clap::Parser;
4
5
6
7
8
9
use opentelemetry::sdk::propagation::TraceContextPropagator;
use opentelemetry::sdk::trace;
use opentelemetry::sdk::trace::Sampler;
use opentelemetry::sdk::Resource;
use opentelemetry::{global, KeyValue};
use opentelemetry_otlp::WithExportConfig;
Olivier Dehaene's avatar
Olivier Dehaene committed
10
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
11
use std::path::Path;
12
use std::time::Duration;
13
use text_generation_client::{ClientError, ShardedClient};
14
use text_generation_router::{server, HubModelInfo};
15
use thiserror::Error;
16
use tokenizers::{FromPretrainedParameters, Tokenizer};
17
use tower_http::cors::AllowOrigin;
18
19
20
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{EnvFilter, Layer};
Olivier Dehaene's avatar
Olivier Dehaene committed
21
22
23
24
25

/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    // Hard cap on in-flight requests handled by the webserver at once.
    #[clap(default_value = "128", long, env)]
    max_concurrent_requests: usize,
    // Upper bound on the `best_of` sampling parameter a single request may use.
    #[clap(default_value = "2", long, env)]
    max_best_of: usize,
    // Maximum number of stop sequences allowed per request.
    #[clap(default_value = "4", long, env)]
    max_stop_sequences: usize,
    // Maximum number of input tokens; must be < `max_total_tokens` and
    // <= `max_batch_prefill_tokens` (both checked at startup in `main`).
    #[clap(default_value = "1024", long, env)]
    max_input_length: usize,
    // Maximum input + generated tokens per request; must be
    // <= `max_batch_total_tokens` (checked at startup in `main`).
    #[clap(default_value = "2048", long, env)]
    max_total_tokens: usize,
    // Scheduling ratio forwarded to `server::run`; exact semantics live in the
    // router scheduling code (not visible here).
    #[clap(default_value = "1.2", long, env)]
    waiting_served_ratio: f32,
    // Token budget for a single prefill batch; must be >= `max_input_length`
    // and <= `max_batch_total_tokens` (checked at startup in `main`).
    // Also passed to model warmup.
    #[clap(default_value = "4096", long, env)]
    max_batch_prefill_tokens: u32,
    // Overall token budget for a batch; passed to model warmup and `server::run`.
    #[clap(default_value = "16000", long, env)]
    max_batch_total_tokens: u32,
    // Scheduler knob forwarded to `server::run`; see router queue logic for
    // the exact meaning.
    #[clap(default_value = "20", long, env)]
    max_waiting_tokens: usize,
    // Interface to bind; values that fail to parse as an IP fall back to
    // 0.0.0.0 in `main`.
    #[clap(default_value = "0.0.0.0", long, env)]
    hostname: String,
    // TCP port for the webserver.
    #[clap(default_value = "3000", long, short, env)]
    port: u16,
    // Unix socket path of the master model shard.
    #[clap(default_value = "/tmp/text-generation-server-0", long, env)]
    master_shard_uds_path: String,
    // Hub model id, or a local directory containing `tokenizer.json`.
    #[clap(default_value = "bigscience/bloom", long, env)]
    tokenizer_name: String,
    // Hub revision (branch/tag/commit) used for tokenizer and model-info downloads.
    #[clap(default_value = "main", long, env)]
    revision: String,
    // Number of payload-validation workers; must be > 0 (checked in `main`).
    #[clap(default_value = "2", long, env)]
    validation_workers: usize,
    // Emit logs as flattened JSON instead of plain text.
    #[clap(long, env)]
    json_output: bool,
    // Optional OpenTelemetry collector endpoint; enables OTLP trace export.
    #[clap(long, env)]
    otlp_endpoint: Option<String>,
    // Allowed CORS origins; each entry must parse as an HTTP header value
    // (invalid entries currently panic in `main` — see note there).
    #[clap(long, env)]
    cors_allow_origin: Option<Vec<String>>,
    // ngrok tunnel configuration; all five fields are forwarded to
    // `server::run` as-is.
    #[clap(long, env)]
    ngrok: bool,
    #[clap(long, env)]
    ngrok_authtoken: Option<String>,
    #[clap(long, env)]
    ngrok_domain: Option<String>,
    #[clap(long, env)]
    ngrok_username: Option<String>,
    #[clap(long, env)]
    ngrok_password: Option<String>,
}
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
73

74
/// Webserver entrypoint: parses and validates CLI/env configuration, loads the
/// tokenizer, then starts a Tokio runtime that connects to the model shards,
/// warms them up, and runs the axum server until it exits.
fn main() -> Result<(), RouterError> {
    // Get args
    let args = Args::parse();
    // Pattern match configuration
    // Destructure so every field is moved out once and unused fields would be
    // a compile error.
    let Args {
        max_concurrent_requests,
        max_best_of,
        max_stop_sequences,
        max_input_length,
        max_total_tokens,
        waiting_served_ratio,
        max_batch_prefill_tokens,
        max_batch_total_tokens,
        max_waiting_tokens,
        hostname,
        port,
        master_shard_uds_path,
        tokenizer_name,
        revision,
        validation_workers,
        json_output,
        otlp_endpoint,
        cors_allow_origin,
        ngrok,
        ngrok_authtoken,
        ngrok_domain,
        ngrok_username,
        ngrok_password,
    } = args;

    // Validate args
    // Cross-field invariants: input < total, input fits in a prefill batch,
    // prefill budget fits in the total batch budget, and any single request
    // fits in a batch.
    if max_input_length >= max_total_tokens {
        return Err(RouterError::ArgumentValidation(
            "`max_input_length` must be < `max_total_tokens`".to_string(),
        ));
    }
    if max_input_length as u32 > max_batch_prefill_tokens {
        return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {max_batch_prefill_tokens} and {max_input_length}")));
    }
    if max_batch_prefill_tokens > max_batch_total_tokens {
        return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
    }
    if max_total_tokens as u32 > max_batch_total_tokens {
        return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
    }
    if validation_workers == 0 {
        return Err(RouterError::ArgumentValidation(
            "`validation_workers` must be > 0".to_string(),
        ));
    }

    // CORS allowed origins
    // map to go inside the option and then map to parse from String to HeaderValue
    // Finally, convert to AllowOrigin
    // NOTE(review): `.unwrap()` panics on a malformed origin string; consider
    // surfacing this as RouterError::ArgumentValidation instead — confirm
    // whether a panic at startup is the intended behavior here.
    let cors_allow_origin: Option<AllowOrigin> = cors_allow_origin.map(|cors_allow_origin| {
        AllowOrigin::list(
            cors_allow_origin
                .iter()
                .map(|origin| origin.parse::<HeaderValue>().unwrap()),
        )
    });

    // Parse Huggingface hub token
    let authorization_token = std::env::var("HUGGING_FACE_HUB_TOKEN").ok();

    // Tokenizer instance
    // This will only be used to validate payloads
    let local_path = Path::new(&tokenizer_name);
    // A local model is any existing directory at `tokenizer_name`; otherwise
    // the name is treated as a Hub model id.
    let local_model = local_path.exists() && local_path.is_dir();
    let tokenizer = if local_model {
        // Load local tokenizer
        Tokenizer::from_file(local_path.join("tokenizer.json")).ok()
    } else {
        // Download and instantiate tokenizer
        // We need to download it outside of the Tokio runtime
        let params = FromPretrainedParameters {
            revision: revision.clone(),
            auth_token: authorization_token.clone(),
            ..Default::default()
        };
        Tokenizer::from_pretrained(tokenizer_name.clone(), Some(params)).ok()
    };

    // Launch Tokio runtime
    tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        .build()?
        .block_on(async {
            // Logging is initialized inside the runtime because the OTLP batch
            // exporter uses the Tokio runtime.
            init_logging(otlp_endpoint, json_output);

            // A missing tokenizer is non-fatal: the server still runs, but
            // Rust-side length validation/truncation is skipped.
            if tokenizer.is_none() {
                tracing::warn!(
                    "Could not find a fast tokenizer implementation for {tokenizer_name}"
                );
                tracing::warn!("Rust input length validation and truncation is disabled");
            }

            // Get Model info
            // Local models have no Hub metadata; Hub lookups fall back to a
            // stub on failure instead of aborting startup.
            let model_info = match local_model {
                true => HubModelInfo {
                    model_id: tokenizer_name.clone(),
                    sha: None,
                    pipeline_tag: None,
                },
                false => get_model_info(&tokenizer_name, &revision, authorization_token)
                    .await
                    .unwrap_or_else(|| {
                        tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
                        HubModelInfo {
                            model_id: tokenizer_name.to_string(),
                            sha: None,
                            pipeline_tag: None,
                        }
                    }),
            };

            // if pipeline-tag == text-generation we default to return_full_text = true
            let compat_return_full_text = match &model_info.pipeline_tag {
                None => {
                    tracing::warn!("no pipeline tag found for model {tokenizer_name}");
                    false
                }
                Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation",
            };

            // Instantiate sharded client from the master unix socket
            let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
                .await
                .map_err(RouterError::Connection)?;
            // Clear the cache; useful if the webserver rebooted
            sharded_client
                .clear_cache(None)
                .await
                .map_err(RouterError::Cache)?;
            // Get info from the shard
            let shard_info = sharded_client.info().await.map_err(RouterError::Info)?;

            // Warmup model
            // Runs a synthetic batch at the configured limits before accepting
            // traffic; failure here aborts startup.
            tracing::info!("Warming up model");
            sharded_client
                .warmup(
                    max_input_length as u32,
                    max_batch_prefill_tokens,
                    max_batch_total_tokens,
                )
                .await
                .map_err(RouterError::Warmup)?;
            tracing::info!("Connected");

            // An unparseable hostname is downgraded to a warning + bind on all
            // interfaces rather than a hard error.
            let addr = match hostname.parse() {
                Ok(ip) => SocketAddr::new(ip, port),
                Err(_) => {
                    tracing::warn!("Invalid hostname, defaulting to 0.0.0.0");
                    SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port)
                }
            };

            // Run server
            // Blocks until the server shuts down; argument order must match
            // the `server::run` signature exactly.
            server::run(
                model_info,
                shard_info,
                compat_return_full_text,
                max_concurrent_requests,
                max_best_of,
                max_stop_sequences,
                max_input_length,
                max_total_tokens,
                waiting_served_ratio,
                max_batch_prefill_tokens,
                max_batch_total_tokens,
                max_waiting_tokens,
                sharded_client,
                tokenizer,
                validation_workers,
                addr,
                cors_allow_origin,
                ngrok,
                ngrok_authtoken,
                ngrok_domain,
                ngrok_username,
                ngrok_password,
            )
            .await?;
            Ok(())
        })
}
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314

/// Init logging using env variables LOG_LEVEL and LOG_FORMAT:
///     - otlp_endpoint is an optional URL to an Open Telemetry collector
///     - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (default to INFO)
///     - LOG_FORMAT may be TEXT or JSON (default to TEXT)
fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
    // STDOUT/STDERR layer: file + line info, rendered either as flattened
    // JSON events or as plain text depending on `json_output`.
    let base_layer = tracing_subscriber::fmt::layer()
        .with_file(true)
        .with_line_number(true);
    let stdout_layer = if json_output {
        base_layer.json().flatten_event(true).boxed()
    } else {
        base_layer.boxed()
    };

    let mut layers = vec![stdout_layer];

    // OpenTelemetry tracing layer — only installed when a collector endpoint
    // was provided.
    if let Some(endpoint) = otlp_endpoint {
        global::set_text_map_propagator(TraceContextPropagator::new());

        let exporter = opentelemetry_otlp::new_exporter()
            .tonic()
            .with_endpoint(endpoint);
        let trace_config = trace::config()
            .with_resource(Resource::new(vec![KeyValue::new(
                "service.name",
                "text-generation-inference.router",
            )]))
            .with_sampler(Sampler::AlwaysOn);

        // Batch export rides on the Tokio runtime; a failed install is
        // silently skipped so logging still works without a collector.
        let tracer = opentelemetry_otlp::new_pipeline()
            .tracing()
            .with_exporter(exporter)
            .with_trace_config(trace_config)
            .install_batch(opentelemetry::runtime::Tokio);

        if let Ok(tracer) = tracer {
            layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed());
            axum_tracing_opentelemetry::init_propagator().unwrap();
        };
    }

    // Filter events with LOG_LEVEL; fall back to "info" when unset or invalid.
    let env_filter =
        EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info"));

    tracing_subscriber::registry()
        .with(env_filter)
        .with(layers)
        .init();
}
315
316

/// get model info from the Huggingface Hub
317
318
319
320
321
pub async fn get_model_info(
    model_id: &str,
    revision: &str,
    token: Option<String>,
) -> Option<HubModelInfo> {
322
    let client = reqwest::Client::new();
323
    // Poor man's urlencode
324
    let revision = revision.replace('/', "%2F");
325
    let url = format!("https://huggingface.co/api/models/{model_id}/revision/{revision}");
326
    let mut builder = client.get(url).timeout(Duration::from_secs(5));
327
328
329
330
    if let Some(token) = token {
        builder = builder.bearer_auth(token);
    }

331
332
333
334
335
336
    let response = builder.send().await.ok()?;

    if response.status().is_success() {
        return serde_json::from_str(&response.text().await.ok()?).ok();
    }
    None
337
}
338
339
340

/// Top-level error type returned by `main`; each variant corresponds to a
/// distinct startup/runtime failure and carries its `Display` message via
/// `thiserror`.
#[derive(Debug, Error)]
enum RouterError {
    /// A CLI/env cross-field check at the top of `main` failed.
    #[error("Argument validation error: {0}")]
    ArgumentValidation(String),
    /// `ShardedClient::connect_uds` on the master shard socket failed.
    #[error("Unable to connect to the Python model shards: {0}")]
    Connection(ClientError),
    /// `clear_cache` on the shards failed after connecting.
    #[error("Unable to clear the Python model shards cache: {0}")]
    Cache(ClientError),
    /// `info()` request to the shards failed.
    #[error("Unable to get the Python model shards info: {0}")]
    Info(ClientError),
    /// The pre-traffic `warmup` batch failed.
    #[error("Unable to warmup the Python model shards: {0}")]
    Warmup(ClientError),
    /// The multi-thread Tokio runtime could not be built (`?` on `build()`).
    #[error("Tokio runtime failed to start: {0}")]
    Tokio(#[from] std::io::Error),
    /// `server::run` returned an error while serving.
    #[error("Axum webserver failed: {0}")]
    Axum(#[from] axum::BoxError),
}