main.rs 1.59 KB
Newer Older
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
1
use bloom_inference_client::ShardedClient;
Olivier Dehaene's avatar
Olivier Dehaene committed
2
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
Olivier Dehaene's avatar
Olivier Dehaene committed
3
use text_generation_router::server;
Olivier Dehaene's avatar
Olivier Dehaene committed
4
use tokenizers::Tokenizer;
Olivier Dehaene's avatar
Olivier Dehaene committed
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
use clap::Parser;

/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    /// Maximum batch size handed to the router server
    #[clap(default_value = "32", long, short, env)]
    max_batch_size: usize,
    /// Port of the 0.0.0.0 socket address passed to the router server
    #[clap(default_value = "3000", long, short, env)]
    port: u16,
    /// Path given to the sharded client to connect to the model shards
    #[clap(default_value = "/tmp/bloom-inference-0", long, env)]
    shard_uds_path: String,
    /// Identifier of the pretrained tokenizer to load
    #[clap(default_value = "bigscience/bloom", long, env)]
    tokenizer_name: String,
}
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
20

Olivier Dehaene's avatar
Olivier Dehaene committed
21
fn main() -> Result<(), std::io::Error> {
Olivier Dehaene's avatar
Olivier Dehaene committed
22
23
24
25
26
27
28
29
30
31
32
33
    // Get args
    let args = Args::parse();
// Pattern match configuration
    let Args {
        max_batch_size,
        port,
        shard_uds_path,
        tokenizer_name,
    } = args;


    let tokenizer = Tokenizer::from_pretrained(tokenizer_name, None).unwrap();
Olivier Dehaene's avatar
Olivier Dehaene committed
34
35
36
37
38
39
40

    tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        .build()
        .unwrap()
        .block_on(async {
            tracing_subscriber::fmt::init();
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
41

Olivier Dehaene's avatar
Olivier Dehaene committed
42
            let sharded_client = ShardedClient::connect_uds(shard_uds_path)
Olivier Dehaene's avatar
Olivier Dehaene committed
43
44
                .await
                .expect("Could not connect to server");
Olivier Dehaene's avatar
Olivier Dehaene committed
45
46
47
48
49
            sharded_client
                .clear_cache()
                .await
                .expect("Unable to clear cache");
            tracing::info!("Connected");
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
50

Olivier Dehaene's avatar
Olivier Dehaene committed
51
            let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port);
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
52

Olivier Dehaene's avatar
Olivier Dehaene committed
53
            server::run(max_batch_size, sharded_client, tokenizer, addr).await;
Olivier Dehaene's avatar
Olivier Dehaene committed
54
            Ok(())
Olivier Dehaene's avatar
Olivier Dehaene committed
55
        })
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
56
}