//! Text Generation Inference webserver entrypoint
use axum::http::HeaderValue;
use clap::Parser;
use opentelemetry::sdk::propagation::TraceContextPropagator;
use opentelemetry::sdk::trace;
use opentelemetry::sdk::trace::Sampler;
use opentelemetry::sdk::Resource;
use opentelemetry::{global, KeyValue};
use opentelemetry_otlp::WithExportConfig;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::Path;
use std::time::Duration;
use text_generation_client::ShardedClient;
use text_generation_router::{server, HubModelInfo};
use tokenizers::{FromPretrainedParameters, Tokenizer};
use tower_http::cors::AllowOrigin;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{EnvFilter, Layer};

/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
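    /// Maximum number of concurrent requests accepted by the router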
    #[clap(default_value = "128", long, env)]
    max_concurrent_requests: usize,
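    /// Maximum `best_of` value a request may ask for (parallel generations)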
    #[clap(default_value = "2", long, env)]
    max_best_of: usize,
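    /// Maximum number of stop sequences per request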
    #[clap(default_value = "4", long, env)]
    max_stop_sequences: usize,
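    /// Maximum number of input tokens per request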
    #[clap(default_value = "1024", long, env)]
    max_input_length: usize,
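    /// Maximum number of total tokens (input + generated) per request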
    #[clap(default_value = "2048", long, env)]
    max_total_tokens: usize,
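    /// Ratio of waiting queries to running queries above which waiting queries are folded into the next batch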
    #[clap(default_value = "1.2", long, env)]
    waiting_served_ratio: f32,
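    /// Upper bound on the number of tokens in a single prefill batch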
    #[clap(default_value = "4096", long, env)]
    max_batch_prefill_tokens: u32,
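    /// Upper bound on the total number of tokens (prefill + decode) in a batch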
    #[clap(default_value = "16000", long, env)]
    max_batch_total_tokens: u32,
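    /// How many generation steps may pass before waiting queries are forced into the batch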
    #[clap(default_value = "20", long, env)]
    max_waiting_tokens: usize,
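    /// IP address to listen on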
    #[clap(default_value = "0.0.0.0", long, env)]
    hostname: String,
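    /// Port to listen on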
    #[clap(default_value = "3000", long, short, env)]
    port: u16,
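    /// Path to the unix domain socket of the master shard (shard 0)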
    #[clap(default_value = "/tmp/text-generation-server-0", long, env)]
    master_shard_uds_path: String,
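    /// Hub model id, or local path, of the tokenizer used to validate payloads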
    #[clap(default_value = "bigscience/bloom", long, env)]
    tokenizer_name: String,
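    /// Hub revision of the tokenizer to download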
    #[clap(default_value = "main", long, env)]
    revision: String,
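    /// Number of tokenizer workers used for payload validation and truncation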
    #[clap(default_value = "2", long, env)]
    validation_workers: usize,
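    /// Emit logs as flattened JSON instead of text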
    #[clap(long, env)]
    json_output: bool,
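    /// Optional OTLP endpoint of an OpenTelemetry collector to export traces to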
    #[clap(long, env)]
    otlp_endpoint: Option<String>,
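    /// Origins allowed by the CORS layer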
    #[clap(long, env)]
    cors_allow_origin: Option<Vec<String>>,
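    /// Expose the server through an ngrok tunnel (configured by the `ngrok_*` options below)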
    #[clap(long, env)]
    ngrok: bool,
    #[clap(long, env)]
    ngrok_authtoken: Option<String>,
    #[clap(long, env)]
    ngrok_domain: Option<String>,
    #[clap(long, env)]
    ngrok_username: Option<String>,
    #[clap(long, env)]
    ngrok_password: Option<String>,
}
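
// Example invocation (illustrative only; assumes the binary built from this
// crate keeps its default `text-generation-router` name):
//
//   text-generation-router \
//       --port 8080 \
//       --tokenizer-name bigscience/bloom \
//       --max-input-length 1024 \
//       --max-total-tokens 2048
//
// Thanks to the `env` attribute on every field, each flag can also be set
// through the matching environment variable (e.g. `PORT=8080`).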

fn main() -> Result<(), std::io::Error> {
    // Get args
    let args = Args::parse();
    // Pattern match configuration
    let Args {
        max_concurrent_requests,
        max_best_of,
        max_stop_sequences,
        max_input_length,
        max_total_tokens,
        waiting_served_ratio,
        max_batch_prefill_tokens,
        max_batch_total_tokens,
        max_waiting_tokens,
        hostname,
        port,
        master_shard_uds_path,
        tokenizer_name,
        revision,
        validation_workers,
        json_output,
        otlp_endpoint,
        cors_allow_origin,
        ngrok,
        ngrok_authtoken,
        ngrok_domain,
        ngrok_username,
        ngrok_password,
    } = args;

    // Validate args
    if max_input_length as u32 > max_batch_prefill_tokens {
        panic!("`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {max_batch_prefill_tokens} and {max_input_length}");
    }
    if max_batch_prefill_tokens > max_batch_total_tokens {
        panic!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}");
    }
    if max_total_tokens as u32 > max_batch_total_tokens {
        panic!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}");
    }
    if validation_workers == 0 {
        panic!("`validation_workers` must be > 0");
    }
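
    // Net effect of the checks above:
    //   max_input_length <= max_batch_prefill_tokens <= max_batch_total_tokens
    //   max_total_tokens <= max_batch_total_tokens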

    // CORS allowed origins
    // Map inside the Option, parse each String into a HeaderValue,
    // then collect the values into an AllowOrigin list
    let cors_allow_origin: Option<AllowOrigin> = cors_allow_origin.map(|cors_allow_origin| {
        AllowOrigin::list(
            cors_allow_origin
                .iter()
                .map(|origin| origin.parse::<HeaderValue>().expect("Invalid CORS allow origin")),
        )
    });
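    // Note: with clap, list-valued flags like this one are passed by repeating
    // the flag, e.g. --cors-allow-origin https://a.example --cors-allow-origin https://b.example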

    // Parse Hugging Face Hub token
    let authorization_token = std::env::var("HUGGING_FACE_HUB_TOKEN").ok();

    // Tokenizer instance
    // This will only be used to validate payloads
    let local_path = Path::new(&tokenizer_name);
    let local_model = local_path.exists() && local_path.is_dir();
    let tokenizer = if local_model {
        // Load local tokenizer
        Tokenizer::from_file(local_path.join("tokenizer.json")).ok()
    } else {
        // Download and instantiate tokenizer
        // We need to download it outside of the Tokio runtime
        let params = FromPretrainedParameters {
            revision: revision.clone(),
            auth_token: authorization_token.clone(),
            ..Default::default()
        };
        Tokenizer::from_pretrained(tokenizer_name.clone(), Some(params)).ok()
    };

    // Launch Tokio runtime
    tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        .build()
        .unwrap()
        .block_on(async {
            init_logging(otlp_endpoint, json_output);

            if tokenizer.is_none() {
                tracing::warn!(
                    "Could not find a fast tokenizer implementation for {tokenizer_name}"
                );
                tracing::warn!("Rust input length validation and truncation is disabled");
            }

            // Get Model info
            let model_info = match local_model {
                true => HubModelInfo {
                    model_id: tokenizer_name.clone(),
                    sha: None,
                    pipeline_tag: None,
                },
                false => get_model_info(&tokenizer_name, &revision, authorization_token)
                    .await
                    .unwrap_or_else(|| {
                        tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
                        HubModelInfo {
                            model_id: tokenizer_name.to_string(),
                            sha: None,
                            pipeline_tag: None,
                        }
                    }),
            };

            // if pipeline-tag == text-generation we default to return_full_text = true
            let compat_return_full_text = match &model_info.pipeline_tag {
                None => {
                    tracing::warn!("no pipeline tag found for model {tokenizer_name}");
                    false
                }
                Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation",
            };

            // Instantiate sharded client from the master unix socket
            let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
                .await
                .expect("Could not connect to server");
            // Clear the cache; useful if the webserver rebooted
            sharded_client
                .clear_cache(None)
                .await
                .expect("Unable to clear cache");
            // Get info from the shard
            let shard_info = sharded_client
                .info()
                .await
                .expect("Unable to get shard info");

            // Warmup model
            tracing::info!("Warming up model");
            sharded_client
                .warmup(
                    max_input_length as u32,
                    max_batch_prefill_tokens,
                    max_batch_total_tokens,
                )
                .await
                .expect("Unable to warmup model");
            tracing::info!("Connected");

            let addr = match hostname.parse() {
                Ok(ip) => SocketAddr::new(ip, port),
                Err(_) => {
                    tracing::warn!("Invalid hostname, defaulting to 0.0.0.0");
                    SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port)
                }
            };

            // Run server
            server::run(
                model_info,
                shard_info,
                compat_return_full_text,
                max_concurrent_requests,
                max_best_of,
                max_stop_sequences,
                max_input_length,
                max_total_tokens,
                waiting_served_ratio,
                max_batch_prefill_tokens,
                max_batch_total_tokens,
                max_waiting_tokens,
                sharded_client,
                tokenizer,
                validation_workers,
                addr,
                cors_allow_origin,
                ngrok,
                ngrok_authtoken,
                ngrok_domain,
                ngrok_username,
                ngrok_password,
            )
            .await;
            Ok(())
        })
}

/// Init logging using the LOG_LEVEL env variable and the flags passed in:
///     - otlp_endpoint is an optional URL to an Open Telemetry collector
///     - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (defaults to INFO)
///     - json_output switches the stdout layer from text to flattened JSON
fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
    let mut layers = Vec::new();

    // STDOUT/STDERR layer
    let fmt_layer = tracing_subscriber::fmt::layer()
        .with_file(true)
        .with_line_number(true);

    let fmt_layer = match json_output {
        true => fmt_layer.json().flatten_event(true).boxed(),
        false => fmt_layer.boxed(),
    };
    layers.push(fmt_layer);

    // OpenTelemetry tracing layer
    if let Some(otlp_endpoint) = otlp_endpoint {
        global::set_text_map_propagator(TraceContextPropagator::new());

        let tracer = opentelemetry_otlp::new_pipeline()
            .tracing()
            .with_exporter(
                opentelemetry_otlp::new_exporter()
                    .tonic()
                    .with_endpoint(otlp_endpoint),
            )
            .with_trace_config(
                trace::config()
                    .with_resource(Resource::new(vec![KeyValue::new(
                        "service.name",
                        "text-generation-inference.router",
                    )]))
                    .with_sampler(Sampler::AlwaysOn),
            )
            .install_batch(opentelemetry::runtime::Tokio);

        if let Ok(tracer) = tracer {
            layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed());
            axum_tracing_opentelemetry::init_propagator().unwrap();
        };
    }

    // Filter events with LOG_LEVEL
    let env_filter =
        EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info"));

    tracing_subscriber::registry()
        .with(env_filter)
        .with(layers)
        .init();
}
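
// Example launch with logging configured through the environment (illustrative;
// LOG_LEVEL accepts any `tracing_subscriber::EnvFilter` directive list):
//
//   LOG_LEVEL=info,text_generation_router=debug text-generation-router --json-output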

/// Get model info from the Hugging Face Hub
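/// Returns `None` if the request fails, times out, or the body cannot be
/// deserialized. Example call (illustrative): `get_model_info("bigscience/bloom", "main", None).await`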
pub async fn get_model_info(
    model_id: &str,
    revision: &str,
    token: Option<String>,
) -> Option<HubModelInfo> {
    let client = reqwest::Client::new();
    // Poor man's urlencode
    let revision = revision.replace('/', "%2F");
    let url = format!("https://huggingface.co/api/models/{model_id}/revision/{revision}");
    let mut builder = client.get(url).timeout(Duration::from_secs(5));
    if let Some(token) = token {
        builder = builder.bearer_auth(token);
    }

    let response = builder.send().await.ok()?;

    if response.status().is_success() {
        return serde_json::from_str(&response.text().await.ok()?).ok();
    }
    None
}