main.rs 2.85 KB
Newer Older
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
1
/// Text Generation Inference webserver entrypoint
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
2
use bloom_inference_client::ShardedClient;
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
3
use clap::Parser;
Olivier Dehaene's avatar
Olivier Dehaene committed
4
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
5
use std::time::Duration;
Olivier Dehaene's avatar
Olivier Dehaene committed
6
use text_generation_router::server;
Olivier Dehaene's avatar
Olivier Dehaene committed
7
use tokenizers::Tokenizer;
Olivier Dehaene's avatar
Olivier Dehaene committed
8
9
10
11
12

/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    // --- Request / batching limits (forwarded unchanged to `server::run`) ---
    #[clap(long, env, default_value = "128")]
    max_concurrent_requests: usize,
    #[clap(long, env, default_value = "1000")]
    max_input_length: usize,
    #[clap(long, env, default_value = "32")]
    max_batch_size: usize,
    // Whole seconds: `main` turns this into a `Duration` via `Duration::from_secs`.
    #[clap(long, env, default_value = "5")]
    max_waiting_time: u64,

    // --- Networking ---
    // Listening port; `short` also exposes a single-letter flag for it.
    #[clap(long, short, env, default_value = "3000")]
    port: u16,
    // Unix domain socket path handed to `ShardedClient::connect_uds` in `main`.
    #[clap(long, env, default_value = "/tmp/bloom-inference-0")]
    master_shard_uds_path: String,

    // --- Validation ---
    // Identifier passed to `Tokenizer::from_pretrained`; the tokenizer is only
    // used to validate incoming payloads.
    #[clap(long, env, default_value = "bigscience/bloom")]
    tokenizer_name: String,
    // Size of the validation worker pool, forwarded to `server::run`.
    #[clap(long, env, default_value = "2")]
    validation_workers: usize,
}
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
30

Olivier Dehaene's avatar
Olivier Dehaene committed
31
fn main() -> Result<(), std::io::Error> {
Olivier Dehaene's avatar
Olivier Dehaene committed
32
33
    // Get args
    let args = Args::parse();
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
34
    // Pattern match configuration
Olivier Dehaene's avatar
Olivier Dehaene committed
35
    let Args {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
36
37
        max_concurrent_requests,
        max_input_length,
Olivier Dehaene's avatar
Olivier Dehaene committed
38
        max_batch_size,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
39
        max_waiting_time,
Olivier Dehaene's avatar
Olivier Dehaene committed
40
        port,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
41
        master_shard_uds_path,
Olivier Dehaene's avatar
Olivier Dehaene committed
42
        tokenizer_name,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
43
        validation_workers,
Olivier Dehaene's avatar
Olivier Dehaene committed
44
45
    } = args;

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
46
47
48
49
50
    if validation_workers == 1 {
        panic!("validation_workers must be > 0");
    }

    let max_waiting_time = Duration::from_secs(max_waiting_time);
Olivier Dehaene's avatar
Olivier Dehaene committed
51

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
52
53
54
55
    // Download and instantiate tokenizer
    // This will only be used to validate payloads
    //
    // We need to download it outside of the Tokio runtime
Olivier Dehaene's avatar
Olivier Dehaene committed
56
    let tokenizer = Tokenizer::from_pretrained(tokenizer_name, None).unwrap();
Olivier Dehaene's avatar
Olivier Dehaene committed
57

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
58
    // Launch Tokio runtime
Olivier Dehaene's avatar
Olivier Dehaene committed
59
60
61
62
63
64
    tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        .build()
        .unwrap()
        .block_on(async {
            tracing_subscriber::fmt::init();
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
65

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
66
67
            // Instantiate sharded client from the master unix socket
            let sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
Olivier Dehaene's avatar
Olivier Dehaene committed
68
69
                .await
                .expect("Could not connect to server");
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
70
            // Clear the cache; useful if the webserver rebooted
Olivier Dehaene's avatar
Olivier Dehaene committed
71
72
73
74
75
            sharded_client
                .clear_cache()
                .await
                .expect("Unable to clear cache");
            tracing::info!("Connected");
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
76

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
77
            // Binds on localhost
Olivier Dehaene's avatar
Olivier Dehaene committed
78
            let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port);
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
79

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
80
81
82
83
84
85
86
87
88
89
90
91
            // Run server
            server::run(
                max_concurrent_requests,
                max_input_length,
                max_batch_size,
                max_waiting_time,
                sharded_client,
                tokenizer,
                validation_workers,
                addr,
            )
            .await;
Olivier Dehaene's avatar
Olivier Dehaene committed
92
            Ok(())
Olivier Dehaene's avatar
Olivier Dehaene committed
93
        })
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
94
}