main.rs 7.36 KB
Newer Older
1
/// Text Generation Inference webserver entrypoint
2
use axum::http::HeaderValue;
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
3
use clap::Parser;
4
5
6
7
8
9
use opentelemetry::sdk::propagation::TraceContextPropagator;
use opentelemetry::sdk::trace;
use opentelemetry::sdk::trace::Sampler;
use opentelemetry::sdk::Resource;
use opentelemetry::{global, KeyValue};
use opentelemetry_otlp::WithExportConfig;
Olivier Dehaene's avatar
Olivier Dehaene committed
10
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
11
use std::path::Path;
12
use text_generation_client::ShardedClient;
Olivier Dehaene's avatar
Olivier Dehaene committed
13
use text_generation_router::server;
Olivier Dehaene's avatar
Olivier Dehaene committed
14
use tokenizers::Tokenizer;
15
use tower_http::cors::AllowOrigin;
16
17
18
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{EnvFilter, Layer};
Olivier Dehaene's avatar
Olivier Dehaene committed
19
20
21
22
23

/// App Configuration
///
/// Every field is exposed both as a `--long-flag` and as an environment
/// variable (via `env` on the clap attribute). Doc comments double as the
/// `--help` text.
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    /// Maximum number of requests handled concurrently by the webserver
    #[clap(default_value = "128", long, env)]
    max_concurrent_requests: usize,
    /// Maximum number of stop sequences allowed in a single request
    #[clap(default_value = "4", long, env)]
    max_stop_sequences: usize,
    /// Maximum number of input tokens accepted per request
    #[clap(default_value = "1000", long, env)]
    max_input_length: usize,
    /// Maximum total (input + generated) tokens per request
    #[clap(default_value = "1512", long, env)]
    max_total_tokens: usize,
    /// Maximum number of requests batched together for a shard forward pass
    #[clap(default_value = "32", long, env)]
    max_batch_size: usize,
    /// Token budget controlling how long requests may wait before being batched
    // NOTE(review): exact semantics live in the queue/batching logic, which is
    // not visible in this file — confirm wording against that implementation.
    #[clap(default_value = "20", long, env)]
    max_waiting_tokens: usize,
    /// Port the HTTP server listens on
    #[clap(default_value = "3000", long, short, env)]
    port: u16,
    /// Unix domain socket path of the master shard
    #[clap(default_value = "/tmp/text-generation-0", long, env)]
    master_shard_uds_path: String,
    /// Hugging Face hub id (or local directory containing tokenizer.json)
    /// of the tokenizer used for payload validation
    #[clap(default_value = "bigscience/bloom", long, env)]
    tokenizer_name: String,
    /// Number of tokenizer workers used for payload validation; must be > 0
    #[clap(default_value = "2", long, env)]
    validation_workers: usize,
    /// Emit logs as JSON instead of human-readable text
    #[clap(long, env)]
    json_output: bool,
    /// Optional OpenTelemetry OTLP collector endpoint to export traces to
    #[clap(long, env)]
    otlp_endpoint: Option<String>,
    /// Optional list of allowed CORS origins; omit to use the server default
    #[clap(long, env)]
    cors_allow_origin: Option<Vec<String>>,
}
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
51

Olivier Dehaene's avatar
Olivier Dehaene committed
52
fn main() -> Result<(), std::io::Error> {
Olivier Dehaene's avatar
Olivier Dehaene committed
53
54
    // Get args
    let args = Args::parse();
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
55
    // Pattern match configuration
Olivier Dehaene's avatar
Olivier Dehaene committed
56
    let Args {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
57
        max_concurrent_requests,
58
        max_stop_sequences,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
59
        max_input_length,
60
        max_total_tokens,
Olivier Dehaene's avatar
Olivier Dehaene committed
61
        max_batch_size,
62
        max_waiting_tokens,
Olivier Dehaene's avatar
Olivier Dehaene committed
63
        port,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
64
        master_shard_uds_path,
Olivier Dehaene's avatar
Olivier Dehaene committed
65
        tokenizer_name,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
66
        validation_workers,
67
        json_output,
68
        otlp_endpoint,
69
        cors_allow_origin,
Olivier Dehaene's avatar
Olivier Dehaene committed
70
71
    } = args;

72
    if validation_workers == 0 {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
73
74
75
        panic!("validation_workers must be > 0");
    }

76
77
78
79
80
81
82
83
84
85
86
    // CORS allowed origins
    // map to go inside the option and then map to parse from String to HeaderValue
    // Finally, convert to AllowOrigin
    let cors_allow_origin: Option<AllowOrigin> = cors_allow_origin.map(|cors_allow_origin| {
        AllowOrigin::list(
            cors_allow_origin
                .iter()
                .map(|origin| origin.parse::<HeaderValue>().unwrap()),
        )
    });

87
    // Tokenizer instance
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
88
    // This will only be used to validate payloads
89
90
91
92
93
94
95
96
97
98
99
    let local_path = Path::new(&tokenizer_name);
    let tokenizer =
        if local_path.exists() && local_path.is_dir() && local_path.join("tokenizer.json").exists()
        {
            // Load local tokenizer
            Tokenizer::from_file(local_path.join("tokenizer.json")).unwrap()
        } else {
            // Download and instantiate tokenizer
            // We need to download it outside of the Tokio runtime
            Tokenizer::from_pretrained(tokenizer_name.clone(), None).unwrap()
        };
Olivier Dehaene's avatar
Olivier Dehaene committed
100

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
101
    // Launch Tokio runtime
Olivier Dehaene's avatar
Olivier Dehaene committed
102
103
104
105
106
    tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        .build()
        .unwrap()
        .block_on(async {
107
108
            init_logging(otlp_endpoint, json_output);

109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
            // Get pipeline tag
            let model_info = reqwest::get(format!(
                "https://huggingface.co/api/models/{tokenizer_name}"
            ))
            .await
            .expect("Could not connect to hf.co")
            .text()
            .await
            .expect("error when retrieving model info from hf.co");
            let model_info: serde_json::Value =
                serde_json::from_str(&model_info).expect("unable to parse model info");

            // if pipeline-tag == text-generation we default to return_full_text = true
            let compat_return_full_text = match model_info.get("pipeline_tag") {
                None => {
                    tracing::warn!("no pipeline tag found for model {tokenizer_name}");
                    false
                }
                Some(pipeline_tag) => pipeline_tag.as_str() == Some("text-generation"),
            };

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
130
            // Instantiate sharded client from the master unix socket
131
            let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
Olivier Dehaene's avatar
Olivier Dehaene committed
132
133
                .await
                .expect("Could not connect to server");
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
134
            // Clear the cache; useful if the webserver rebooted
Olivier Dehaene's avatar
Olivier Dehaene committed
135
136
137
138
139
            sharded_client
                .clear_cache()
                .await
                .expect("Unable to clear cache");
            tracing::info!("Connected");
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
140

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
141
            // Binds on localhost
Olivier Dehaene's avatar
Olivier Dehaene committed
142
            let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port);
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
143

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
144
145
            // Run server
            server::run(
146
                compat_return_full_text,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
147
                max_concurrent_requests,
148
                max_stop_sequences,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
149
                max_input_length,
150
                max_total_tokens,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
151
                max_batch_size,
152
                max_waiting_tokens,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
153
154
155
156
                sharded_client,
                tokenizer,
                validation_workers,
                addr,
157
                cors_allow_origin,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
158
159
            )
            .await;
Olivier Dehaene's avatar
Olivier Dehaene committed
160
            Ok(())
Olivier Dehaene's avatar
Olivier Dehaene committed
161
        })
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
162
}
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217

/// Init logging using env variables LOG_LEVEL and LOG_FORMAT:
///     - otlp_endpoint is an optional URL to an Open Telemetry collector
///     - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (default to INFO)
///     - LOG_FORMAT may be TEXT or JSON (default to TEXT)
fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
    // Layers are composed in push order: stdout first, OTLP second.
    let mut layers = Vec::new();

    // STDOUT/STDERR layer, optionally switched to flattened JSON events
    let stdout_layer = tracing_subscriber::fmt::layer()
        .with_file(true)
        .with_line_number(true);
    let stdout_layer = if json_output {
        stdout_layer.json().flatten_event(true).boxed()
    } else {
        stdout_layer.boxed()
    };
    layers.push(stdout_layer);

    // OpenTelemetry tracing layer, only when a collector endpoint was given
    if let Some(endpoint) = otlp_endpoint {
        global::set_text_map_propagator(TraceContextPropagator::new());

        let exporter = opentelemetry_otlp::new_exporter()
            .tonic()
            .with_endpoint(endpoint);
        let trace_config = trace::config()
            .with_resource(Resource::new(vec![KeyValue::new(
                "service.name",
                "text-generation-inference.router",
            )]))
            .with_sampler(Sampler::AlwaysOn);

        // A failed install simply leaves the OTLP layer out
        if let Ok(tracer) = opentelemetry_otlp::new_pipeline()
            .tracing()
            .with_exporter(exporter)
            .with_trace_config(trace_config)
            .install_batch(opentelemetry::runtime::Tokio)
        {
            layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed());
            axum_tracing_opentelemetry::init_propagator().unwrap();
        }
    }

    // Filter events with LOG_LEVEL, defaulting to "info" when unset/invalid
    let env_filter =
        EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info"));

    tracing_subscriber::registry()
        .with(env_filter)
        .with(layers)
        .init();
}