/// Text Generation Inference webserver entrypoint
use axum::http::HeaderValue;
use clap::Parser;
use opentelemetry::sdk::propagation::TraceContextPropagator;
use opentelemetry::sdk::trace;
use opentelemetry::sdk::trace::Sampler;
use opentelemetry::sdk::Resource;
use opentelemetry::{global, KeyValue};
use opentelemetry_otlp::WithExportConfig;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::Path;
use std::time::Duration;
use text_generation_client::ShardedClient;
use text_generation_router::{server, HubModelInfo};
use tokenizers::{FromPretrainedParameters, Tokenizer};
use tower_http::cors::AllowOrigin;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{EnvFilter, Layer};

/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    #[clap(default_value = "128", long, env)]
    max_concurrent_requests: usize,
    #[clap(default_value = "2", long, env)]
    max_best_of: usize,
    #[clap(default_value = "4", long, env)]
    max_stop_sequences: usize,
    #[clap(default_value = "1000", long, env)]
    max_input_length: usize,
    #[clap(default_value = "1512", long, env)]
    max_total_tokens: usize,
    // Deprecated: superseded by `max_batch_total_tokens` (see the override in `main`)
    #[clap(long, env)]
    max_batch_size: Option<usize>,
    #[clap(default_value = "1.2", long, env)]
    waiting_served_ratio: f32,
    #[clap(default_value = "32000", long, env)]
    max_batch_total_tokens: u32,
    #[clap(default_value = "20", long, env)]
    max_waiting_tokens: usize,
    #[clap(default_value = "3000", long, short, env)]
    port: u16,
    #[clap(default_value = "/tmp/text-generation-server-0", long, env)]
    master_shard_uds_path: String,
    #[clap(default_value = "bigscience/bloom", long, env)]
    tokenizer_name: String,
    #[clap(default_value = "main", long, env)]
    revision: String,
    #[clap(default_value = "2", long, env)]
    validation_workers: usize,
    #[clap(long, env)]
    json_output: bool,
    #[clap(long, env)]
    otlp_endpoint: Option<String>,
    #[clap(long, env)]
    cors_allow_origin: Option<Vec<String>>,
}
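// Example invocation (clap derives the kebab-case flags and SCREAMING_CASE env vars from the
// field names above; the binary name here is illustrative):
//   text-generation-router --port 8080 --tokenizer-name bigscience/bloom --json-output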

fn main() -> Result<(), std::io::Error> {
    // Get args
    let args = Args::parse();
    // Pattern match configuration
    let Args {
        max_concurrent_requests,
        max_best_of,
        max_stop_sequences,
        max_input_length,
        max_total_tokens,
        max_batch_size,
        waiting_served_ratio,
        mut max_batch_total_tokens,
        max_waiting_tokens,
        port,
        master_shard_uds_path,
        tokenizer_name,
        revision,
        validation_workers,
        json_output,
        otlp_endpoint,
        cors_allow_origin,
    } = args;

    if validation_workers == 0 {
        panic!("validation_workers must be > 0");
    }

    // CORS allowed origins
    // Map into the Option, parse each String origin into a HeaderValue,
    // then collect the values into an AllowOrigin list
    let cors_allow_origin: Option<AllowOrigin> = cors_allow_origin.map(|cors_allow_origin| {
        AllowOrigin::list(
            cors_allow_origin
                .iter()
                .map(|origin| origin.parse::<HeaderValue>().unwrap()),
        )
    });
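    // NOTE: each configured origin must parse as a valid HeaderValue
    // (e.g. "https://example.com"); an invalid one panics at startup via the unwrap above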

    // Parse Hugging Face Hub token
    let authorization_token = std::env::var("HUGGING_FACE_HUB_TOKEN").ok();

    // Tokenizer instance
    // This will only be used to validate payloads
    let local_path = Path::new(&tokenizer_name);
    let local_model = local_path.exists() && local_path.is_dir();
    let tokenizer = if local_model {
        // Load local tokenizer
        Tokenizer::from_file(local_path.join("tokenizer.json")).ok()
    } else {
        // Download and instantiate tokenizer
        // We need to download it outside of the Tokio runtime
        let params = FromPretrainedParameters {
            revision: revision.clone(),
            auth_token: authorization_token.clone(),
            ..Default::default()
        };
        Tokenizer::from_pretrained(tokenizer_name.clone(), Some(params)).ok()
    };
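    // NOTE: `.ok()` swallows tokenizer load/download failures; a `None` tokenizer disables
    // the Rust-side input validation (see the warnings emitted after the runtime starts)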

    // Launch Tokio runtime
    tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        .build()
        .unwrap()
        .block_on(async {
            init_logging(otlp_endpoint, json_output);
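            // Logging is initialized inside the runtime because the OTLP batch exporter
            // (`install_batch(opentelemetry::runtime::Tokio)`) needs an active Tokio runtime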

            // Back-compat: derive `max_batch_total_tokens` from the deprecated `max_batch_size`
            if let Some(max_batch_size) = max_batch_size {
                tracing::warn!("`max-batch-size` is deprecated. Use `max-batch-total-tokens` instead");
                max_batch_total_tokens = (max_batch_size * max_total_tokens) as u32;
                tracing::warn!("Overriding `max-batch-total-tokens` value with `max-batch-size` * `max-total-tokens` = {max_batch_total_tokens}");
            }
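            // e.g. a deprecated --max-batch-size 32 with the default max_total_tokens = 1512
            // yields max_batch_total_tokens = 48384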

            if tokenizer.is_none() {
                tracing::warn!(
                    "Could not find a fast tokenizer implementation for {tokenizer_name}"
                );
                tracing::warn!("Rust input length validation and truncation is disabled");
            }

            // Get Model info
            let model_info = match local_model {
                true => HubModelInfo {
                    model_id: tokenizer_name.clone(),
                    sha: None,
                    pipeline_tag: None,
                },
                false => get_model_info(&tokenizer_name, &revision, authorization_token)
                    .await
                    .unwrap_or_else(|| {
                        tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
                        HubModelInfo {
                            model_id: tokenizer_name.to_string(),
                            sha: None,
                            pipeline_tag: None,
                        }
                    }),
            };

            // If pipeline-tag == text-generation we default to return_full_text = true
            let compat_return_full_text = match &model_info.pipeline_tag {
                None => {
                    tracing::warn!("no pipeline tag found for model {tokenizer_name}");
                    false
                }
                Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation",
            };

            // Instantiate sharded client from the master unix socket
            let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
                .await
                .expect("Could not connect to server");
            // Clear the cache; useful if the webserver rebooted
            sharded_client
                .clear_cache(None)
                .await
                .expect("Unable to clear cache");
            // Get info from the shard
            let shard_info = sharded_client
                .info()
                .await
                .expect("Unable to get shard info");
            tracing::info!("Connected");
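            // At this point every shard is reachable and starts from a cleared cache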

            // Bind on all interfaces (0.0.0.0), not just localhost
            let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port);

            // Run server
            server::run(
                model_info,
                shard_info,
                compat_return_full_text,
                max_concurrent_requests,
                max_best_of,
                max_stop_sequences,
                max_input_length,
                max_total_tokens,
                waiting_served_ratio,
                max_batch_total_tokens,
                max_waiting_tokens,
                sharded_client,
                tokenizer,
                validation_workers,
                addr,
                cors_allow_origin,
            )
            .await;
            Ok(())
        })
}

/// Init logging using the LOG_LEVEL env variable and the given arguments:
///     - otlp_endpoint is an optional URL to an OpenTelemetry collector
///     - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (defaults to INFO)
///     - json_output switches the stdout layer from text to flattened JSON
fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
    let mut layers = Vec::new();

    // STDOUT/STDERR layer
    let fmt_layer = tracing_subscriber::fmt::layer()
        .with_file(true)
        .with_line_number(true);

    let fmt_layer = match json_output {
        true => fmt_layer.json().flatten_event(true).boxed(),
        false => fmt_layer.boxed(),
    };
    layers.push(fmt_layer);

    // OpenTelemetry tracing layer
    if let Some(otlp_endpoint) = otlp_endpoint {
        global::set_text_map_propagator(TraceContextPropagator::new());

        let tracer = opentelemetry_otlp::new_pipeline()
            .tracing()
            .with_exporter(
                opentelemetry_otlp::new_exporter()
                    .tonic()
                    .with_endpoint(otlp_endpoint),
            )
            .with_trace_config(
                trace::config()
                    .with_resource(Resource::new(vec![KeyValue::new(
                        "service.name",
                        "text-generation-inference.router",
                    )]))
                    .with_sampler(Sampler::AlwaysOn),
            )
            .install_batch(opentelemetry::runtime::Tokio);

        if let Ok(tracer) = tracer {
            layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed());
            axum_tracing_opentelemetry::init_propagator().unwrap();
        };
    }

    // Filter events with LOG_LEVEL
    let env_filter =
        EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info"));

    tracing_subscriber::registry()
        .with(env_filter)
        .with(layers)
        .init();
}
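// LOG_LEVEL accepts full EnvFilter directives as well as plain levels,
// e.g. LOG_LEVEL=debug or LOG_LEVEL="text_generation_router=debug" (crate name assumed)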

/// Get model info from the Hugging Face Hub
pub async fn get_model_info(
    model_id: &str,
    revision: &str,
    token: Option<String>,
) -> Option<HubModelInfo> {
    let client = reqwest::Client::new();
    // Poor man's urlencode
    let revision = revision.replace('/', "%2F");
    let url = format!("https://huggingface.co/api/models/{model_id}/revision/{revision}");
    let mut builder = client.get(url).timeout(Duration::from_secs(5));
    if let Some(token) = token {
        builder = builder.bearer_auth(token);
    }

    let response = builder.send().await.ok()?;

    if response.status().is_success() {
        return serde_json::from_str(&response.text().await.ok()?).ok();
    }
    None
}
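// NOTE: every failure path above (timeout, non-2xx status, malformed JSON) collapses to None,
// so `main` falls back to a locally constructed HubModelInfo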