main.rs 9.55 KB
Newer Older
1
/// Text Generation Inference webserver entrypoint
2
use axum::http::HeaderValue;
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
3
use clap::Parser;
4
5
6
7
8
9
use opentelemetry::sdk::propagation::TraceContextPropagator;
use opentelemetry::sdk::trace;
use opentelemetry::sdk::trace::Sampler;
use opentelemetry::sdk::Resource;
use opentelemetry::{global, KeyValue};
use opentelemetry_otlp::WithExportConfig;
Olivier Dehaene's avatar
Olivier Dehaene committed
10
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
11
use std::path::Path;
12
use text_generation_client::ShardedClient;
13
use text_generation_router::{server, HubModelInfo};
14
use tokenizers::{FromPretrainedParameters, Tokenizer};
15
use tower_http::cors::AllowOrigin;
16
17
18
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{EnvFilter, Layer};
Olivier Dehaene's avatar
Olivier Dehaene committed
19
20
21
22
23

/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    // Hard cap on in-flight requests; extras are rejected early.
    #[clap(default_value = "128", long, env)]
    max_concurrent_requests: usize,
    // Upper bound for the per-request `best_of` sampling parameter.
    #[clap(default_value = "2", long, env)]
    max_best_of: usize,
    // Maximum number of stop sequences a single request may define.
    #[clap(default_value = "4", long, env)]
    max_stop_sequences: usize,
    // Maximum number of input tokens accepted per request.
    #[clap(default_value = "1000", long, env)]
    max_input_length: usize,
    // Maximum input + generated tokens per request.
    #[clap(default_value = "1512", long, env)]
    max_total_tokens: usize,
    // Deprecated: when set, overrides `max_batch_total_tokens`
    // (see the warning emitted in `main`).
    #[clap(long, env)]
    max_batch_size: Option<usize>,
    // Ratio of waiting to running requests above which a new batch is formed.
    #[clap(default_value = "1.2", long, env)]
    waiting_served_ratio: f32,
    // Token budget (input + output) for a whole batch across all its requests.
    #[clap(default_value = "32000", long, env)]
    max_batch_total_tokens: u32,
    // How many scheduler iterations a request may wait before being batched.
    #[clap(default_value = "20", long, env)]
    max_waiting_tokens: usize,
    // TCP port the HTTP server listens on.
    #[clap(default_value = "3000", long, short, env)]
    port: u16,
    // Unix domain socket of the master shard (gRPC backend).
    #[clap(default_value = "/tmp/text-generation-server-0", long, env)]
    master_shard_uds_path: String,
    // Hub model id, or a local directory containing `tokenizer.json`.
    #[clap(default_value = "bigscience/bloom", long, env)]
    tokenizer_name: String,
    // Hub revision (branch, tag or commit) used for tokenizer/model-info lookup.
    #[clap(default_value = "main", long, env)]
    revision: String,
    // Number of tokenizer workers used for payload validation; must be > 0.
    #[clap(default_value = "2", long, env)]
    validation_workers: usize,
    // Emit logs as JSON instead of human-readable text.
    #[clap(long, env)]
    json_output: bool,
    // Optional OpenTelemetry collector URL; enables OTLP span export.
    #[clap(long, env)]
    otlp_endpoint: Option<String>,
    // Allowed CORS origins; each entry must parse as an HTTP header value.
    #[clap(long, env)]
    cors_allow_origin: Option<Vec<String>>,
}
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
59

Olivier Dehaene's avatar
Olivier Dehaene committed
60
fn main() -> Result<(), std::io::Error> {
Olivier Dehaene's avatar
Olivier Dehaene committed
61
62
    // Get args
    let args = Args::parse();
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
63
    // Pattern match configuration
Olivier Dehaene's avatar
Olivier Dehaene committed
64
    let Args {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
65
        max_concurrent_requests,
66
        max_best_of,
67
        max_stop_sequences,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
68
        max_input_length,
69
        max_total_tokens,
Olivier Dehaene's avatar
Olivier Dehaene committed
70
        max_batch_size,
71
72
        waiting_served_ratio,
        mut max_batch_total_tokens,
73
        max_waiting_tokens,
Olivier Dehaene's avatar
Olivier Dehaene committed
74
        port,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
75
        master_shard_uds_path,
Olivier Dehaene's avatar
Olivier Dehaene committed
76
        tokenizer_name,
77
        revision,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
78
        validation_workers,
79
        json_output,
80
        otlp_endpoint,
81
        cors_allow_origin,
Olivier Dehaene's avatar
Olivier Dehaene committed
82
83
    } = args;

84
    if validation_workers == 0 {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
85
86
87
        panic!("validation_workers must be > 0");
    }

88
89
90
91
92
93
94
95
96
97
98
    // CORS allowed origins
    // map to go inside the option and then map to parse from String to HeaderValue
    // Finally, convert to AllowOrigin
    let cors_allow_origin: Option<AllowOrigin> = cors_allow_origin.map(|cors_allow_origin| {
        AllowOrigin::list(
            cors_allow_origin
                .iter()
                .map(|origin| origin.parse::<HeaderValue>().unwrap()),
        )
    });

99
100
101
    // Parse Huggingface hub token
    let authorization_token = std::env::var("HUGGING_FACE_HUB_TOKEN").ok();

102
    // Tokenizer instance
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
103
    // This will only be used to validate payloads
104
    let local_path = Path::new(&tokenizer_name);
105
106
107
108
109
110
111
112
113
    let local_model = local_path.exists() && local_path.is_dir();
    let tokenizer = if local_model {
        // Load local tokenizer
        Tokenizer::from_file(local_path.join("tokenizer.json")).ok()
    } else {
        // Download and instantiate tokenizer
        // We need to download it outside of the Tokio runtime
        let params = FromPretrainedParameters {
            revision: revision.clone(),
114
            auth_token: authorization_token.clone(),
115
            ..Default::default()
116
        };
117
118
        Tokenizer::from_pretrained(tokenizer_name.clone(), Some(params)).ok()
    };
Olivier Dehaene's avatar
Olivier Dehaene committed
119

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
120
    // Launch Tokio runtime
Olivier Dehaene's avatar
Olivier Dehaene committed
121
122
123
124
125
    tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        .build()
        .unwrap()
        .block_on(async {
OlivierDehaene's avatar
OlivierDehaene committed
126
127
            init_logging(otlp_endpoint, json_output);

128
129
130
131
132
133
            if let Some(max_batch_size) = max_batch_size{
                tracing::warn!("`max-batch-size` is deprecated. Use `max-batch-total-tokens` instead");
                max_batch_total_tokens = (max_batch_size * max_total_tokens) as u32;
                tracing::warn!("Overriding `max-batch-total-tokens` value with `max-batch-size` * `max-total-tokens` = {max_batch_total_tokens}");
            }

134
135
136
137
138
139
140
            if tokenizer.is_none() {
                tracing::warn!(
                    "Could not find a fast tokenizer implementation for {tokenizer_name}"
                );
                tracing::warn!("Rust input length validation and truncation is disabled");
            }

141
142
            // Get Model info
            let model_info = match local_model {
143
                true => HubModelInfo {
144
145
146
147
                    model_id: tokenizer_name.clone(),
                    sha: None,
                    pipeline_tag: None,
                },
148
                false => get_model_info(&tokenizer_name, &revision, authorization_token).await,
149
            };
150
151

            // if pipeline-tag == text-generation we default to return_full_text = true
152
            let compat_return_full_text = match &model_info.pipeline_tag {
153
154
155
156
                None => {
                    tracing::warn!("no pipeline tag found for model {tokenizer_name}");
                    false
                }
157
                Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation",
158
159
            };

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
160
            // Instantiate sharded client from the master unix socket
161
            let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
Olivier Dehaene's avatar
Olivier Dehaene committed
162
163
                .await
                .expect("Could not connect to server");
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
164
            // Clear the cache; useful if the webserver rebooted
Olivier Dehaene's avatar
Olivier Dehaene committed
165
            sharded_client
166
                .clear_cache(None)
Olivier Dehaene's avatar
Olivier Dehaene committed
167
168
                .await
                .expect("Unable to clear cache");
169
170
171
172
173
            // Get info from the shard
            let shard_info = sharded_client
                .info()
                .await
                .expect("Unable to get shard info");
Olivier Dehaene's avatar
Olivier Dehaene committed
174
            tracing::info!("Connected");
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
175

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
176
            // Binds on localhost
Olivier Dehaene's avatar
Olivier Dehaene committed
177
            let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port);
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
178

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
179
180
            // Run server
            server::run(
181
                model_info,
182
                shard_info,
183
                compat_return_full_text,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
184
                max_concurrent_requests,
185
                max_best_of,
186
                max_stop_sequences,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
187
                max_input_length,
188
                max_total_tokens,
189
190
                waiting_served_ratio,
                max_batch_total_tokens,
191
                max_waiting_tokens,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
192
193
194
195
                sharded_client,
                tokenizer,
                validation_workers,
                addr,
196
                cors_allow_origin,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
197
198
            )
            .await;
Olivier Dehaene's avatar
Olivier Dehaene committed
199
            Ok(())
Olivier Dehaene's avatar
Olivier Dehaene committed
200
        })
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
201
}
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256

/// Initialize the logging/tracing stack.
///
/// - `otlp_endpoint`: optional URL of an OpenTelemetry collector; when set,
///   spans are batch-exported to it over OTLP (tonic/gRPC).
/// - `json_output`: emit JSON-formatted log events instead of plain text.
/// - The `LOG_LEVEL` env variable (TRACE, DEBUG, INFO, WARN or ERROR) filters
///   events; defaults to INFO when unset or invalid.
fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
    let mut layers = Vec::new();

    // STDOUT/STDERR layer, annotated with source file and line number.
    let base_fmt = tracing_subscriber::fmt::layer()
        .with_file(true)
        .with_line_number(true);

    let stdout_layer = if json_output {
        base_fmt.json().flatten_event(true).boxed()
    } else {
        base_fmt.boxed()
    };
    layers.push(stdout_layer);

    // OpenTelemetry tracing layer (only when a collector endpoint is configured).
    if let Some(otlp_endpoint) = otlp_endpoint {
        // Propagate trace context on outgoing requests via W3C headers.
        global::set_text_map_propagator(TraceContextPropagator::new());

        let exporter = opentelemetry_otlp::new_exporter()
            .tonic()
            .with_endpoint(otlp_endpoint);

        let trace_config = trace::config()
            .with_resource(Resource::new(vec![KeyValue::new(
                "service.name",
                "text-generation-inference.router",
            )]))
            .with_sampler(Sampler::AlwaysOn);

        let tracer = opentelemetry_otlp::new_pipeline()
            .tracing()
            .with_exporter(exporter)
            .with_trace_config(trace_config)
            .install_batch(opentelemetry::runtime::Tokio);

        // A failed pipeline install simply disables OTLP export.
        if let Ok(tracer) = tracer {
            layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed());
            axum_tracing_opentelemetry::init_propagator().unwrap();
        };
    }

    // Filter events with LOG_LEVEL, defaulting to "info".
    let env_filter =
        EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info"));

    tracing_subscriber::registry()
        .with(env_filter)
        .with(layers)
        .init();
}
257
258

/// get model info from the Huggingface Hub
259
pub async fn get_model_info(model_id: &str, revision: &str, token: Option<String>) -> HubModelInfo {
260
261
    let client = reqwest::Client::new();
    let mut builder = client.get(format!(
262
        "https://huggingface.co/api/models/{model_id}/revision/{revision}"
263
264
265
266
267
268
269
270
271
272
273
274
    ));
    if let Some(token) = token {
        builder = builder.bearer_auth(token);
    }

    let model_info = builder
        .send()
        .await
        .expect("Could not connect to hf.co")
        .text()
        .await
        .expect("error when retrieving model info from hf.co");
275
276
    serde_json::from_str(&model_info).expect("unable to parse model info")
}