main.rs 2.93 KB
Newer Older
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
1
use clap::Parser;
Olivier Dehaene's avatar
Olivier Dehaene committed
2
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
3
4
/// Text Generation Inference webserver entrypoint
use text_generation_client::ShardedClient;
Olivier Dehaene's avatar
Olivier Dehaene committed
5
use text_generation_router::server;
Olivier Dehaene's avatar
Olivier Dehaene committed
6
use tokenizers::Tokenizer;
Olivier Dehaene's avatar
Olivier Dehaene committed
7
8
9
10
11

/// App Configuration
///
/// Every field is a `--long` command-line flag (with a default) that can also
/// be supplied through the environment variable clap derives from its name.
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    /// Maximum number of concurrent requests
    #[clap(default_value = "128", long, env)]
    max_concurrent_requests: usize,
    /// Maximum input length of a request
    #[clap(default_value = "1000", long, env)]
    max_input_length: usize,
    /// Maximum batch size
    #[clap(default_value = "32", long, env)]
    max_batch_size: usize,
    /// Maximum number of waiting tokens
    #[clap(default_value = "20", long, env)]
    max_waiting_tokens: usize,
    /// Port to listen on; the server binds it on every IPv4 interface (0.0.0.0)
    #[clap(default_value = "3000", long, short, env)]
    port: u16,
    /// Path of the unix domain socket used to connect to the master shard
    #[clap(default_value = "/tmp/text-generation-0", long, env)]
    master_shard_uds_path: String,
    /// Name of the pretrained tokenizer used to validate request payloads
    #[clap(default_value = "bigscience/bloom", long, env)]
    tokenizer_name: String,
    /// Number of payload validation workers; must be greater than 0
    #[clap(default_value = "2", long, env)]
    validation_workers: usize,
    /// Emit logs as JSON instead of the compact text format
    #[clap(long, env)]
    json_output: bool,
}
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
31

Olivier Dehaene's avatar
Olivier Dehaene committed
32
fn main() -> Result<(), std::io::Error> {
Olivier Dehaene's avatar
Olivier Dehaene committed
33
34
    // Get args
    let args = Args::parse();
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
35
    // Pattern match configuration
Olivier Dehaene's avatar
Olivier Dehaene committed
36
    let Args {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
37
38
        max_concurrent_requests,
        max_input_length,
Olivier Dehaene's avatar
Olivier Dehaene committed
39
        max_batch_size,
40
        max_waiting_tokens,
Olivier Dehaene's avatar
Olivier Dehaene committed
41
        port,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
42
        master_shard_uds_path,
Olivier Dehaene's avatar
Olivier Dehaene committed
43
        tokenizer_name,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
44
        validation_workers,
45
        json_output,
Olivier Dehaene's avatar
Olivier Dehaene committed
46
47
    } = args;

48
49
50
51
52
    if json_output {
        tracing_subscriber::fmt().json().init();
    } else {
        tracing_subscriber::fmt().compact().init();
    }
53

54
    if validation_workers == 0 {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
55
56
57
58
59
60
61
        panic!("validation_workers must be > 0");
    }

    // Download and instantiate tokenizer
    // This will only be used to validate payloads
    //
    // We need to download it outside of the Tokio runtime
Olivier Dehaene's avatar
Olivier Dehaene committed
62
    let tokenizer = Tokenizer::from_pretrained(tokenizer_name, None).unwrap();
Olivier Dehaene's avatar
Olivier Dehaene committed
63

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
64
    // Launch Tokio runtime
Olivier Dehaene's avatar
Olivier Dehaene committed
65
66
67
68
69
    tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        .build()
        .unwrap()
        .block_on(async {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
70
            // Instantiate sharded client from the master unix socket
71
            let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
Olivier Dehaene's avatar
Olivier Dehaene committed
72
73
                .await
                .expect("Could not connect to server");
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
74
            // Clear the cache; useful if the webserver rebooted
Olivier Dehaene's avatar
Olivier Dehaene committed
75
76
77
78
79
            sharded_client
                .clear_cache()
                .await
                .expect("Unable to clear cache");
            tracing::info!("Connected");
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
80

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
81
            // Binds on localhost
Olivier Dehaene's avatar
Olivier Dehaene committed
82
            let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port);
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
83

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
84
85
86
87
88
            // Run server
            server::run(
                max_concurrent_requests,
                max_input_length,
                max_batch_size,
89
                max_waiting_tokens,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
90
91
92
93
94
95
                sharded_client,
                tokenizer,
                validation_workers,
                addr,
            )
            .await;
Olivier Dehaene's avatar
Olivier Dehaene committed
96
            Ok(())
Olivier Dehaene's avatar
Olivier Dehaene committed
97
        })
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
98
}