lib.rs
mod backend;
pub mod block_allocator;
mod client;
mod queue;
pub mod radix;

use crate::client::{ClientError, ShardedClient};
pub(crate) use backend::BackendV3;
use serde::Serialize;
use thiserror::Error;
use utoipa::ToSchema;

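/// Runtime information about the model shards and the batching configuration,
/// serialized for the HTTP API (`serde`) and exposed in its OpenAPI schema (`utoipa`).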
#[derive(Clone, Debug, Serialize, ToSchema)]
pub struct BackendInfo {
    /// Mandatory
    #[schema(example = "cuda")]
    pub model_device_type: String,
    #[schema(example = "torch.float16")]
    pub model_dtype: String,

    /// Backend parameters
    #[schema(example = "1")]
    pub speculate: usize,
    #[schema(example = "1.2")]
    pub waiting_served_ratio: f32,
    #[schema(example = "32000")]
    pub max_batch_total_tokens: u32,
    #[schema(example = "20")]
    pub max_waiting_tokens: usize,
    #[schema(nullable = true, example = "null")]
    pub max_batch_size: Option<usize>,
    #[schema(example = "false")]
    pub support_chunking: bool,
    #[schema(example = "false")]
    pub prefix_caching: bool,
    #[schema(example = "flashinfer")]
    pub attention_impl: String,
    #[schema(example = "1")]
    pub block_size: u32,
}

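/// Connect to the Python model shards over the given Unix-domain socket,
/// clear their cache, warm the model up and return the ready `BackendV3`
/// together with its `BackendInfo`.
///
/// The snippet below is a minimal sketch of how a caller might invoke this;
/// the socket path and token budgets are illustrative assumptions, not
/// defaults defined by this crate.
///
/// ```ignore
/// let (backend, info) = connect_backend(
///     1024,                                  // max_input_tokens
///     2048,                                  // max_total_tokens
///     "/tmp/text-generation-server".into(),  // master_shard_uds_path (example path)
///     1.2,                                   // waiting_served_ratio
///     4096,                                  // max_batch_prefill_tokens
///     None,                                  // max_batch_total_tokens: let warmup infer it
///     20,                                    // max_waiting_tokens
///     None,                                  // max_batch_size: unbounded
/// )
/// .await?;
/// tracing::info!("v3 backend ready on {}", info.model_device_type);
/// ```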
#[allow(clippy::too_many_arguments)]
pub async fn connect_backend(
    max_input_tokens: usize,
    max_total_tokens: usize,
    master_shard_uds_path: String,
    waiting_served_ratio: f32,
    max_batch_prefill_tokens: u32,
    max_batch_total_tokens: Option<u32>,
    max_waiting_tokens: usize,
    max_batch_size: Option<usize>,
) -> Result<(BackendV3, BackendInfo), V3Error> {
    // Helper: resolve the effective `max_batch_total_tokens` from what warmup reports
    let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| {
        match max_supported_batch_total_tokens {
            // Older models do not support automatic max-batch-total-tokens
            None => {
                let max_batch_total_tokens = max_batch_total_tokens
                    .unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)));
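                // Illustrative numbers (not defaults from this crate): with
                // max_total_tokens = 2048 and max_batch_prefill_tokens = 4096 this
                // falls back to max(16000, max(2048, 4096)) = 16000 when the caller
                // did not pass an explicit max_batch_total_tokens.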
                tracing::warn!("Model does not support automatic max batch total tokens");
                Ok(max_batch_total_tokens)
            }
            // Flash attention models return their max supported total tokens
            Some(max_supported_batch_total_tokens) => {
                // Warn if the user set their own max-batch-total-tokens, as it will be ignored
                if max_batch_total_tokens.is_some() {
                    tracing::warn!(
                        "`--max-batch-total-tokens` is deprecated for Flash \
                        Attention models."
                    );
                    tracing::warn!(
                        "Inferred max batch total tokens: {max_supported_batch_total_tokens}"
                    );
                }
                if max_total_tokens as u32 > max_supported_batch_total_tokens {
                    return Err(V3Error::NotEnoughMemory(max_total_tokens));
                }

                Ok(max_supported_batch_total_tokens)
            }
        }
    };

    let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
        .await
        .map_err(V3Error::Connection)?;

    // The model server is running on v3.
    // Clear the cache; useful if the webserver rebooted.
    sharded_client
        .clear_cache(None)
        .await
        .map_err(V3Error::Cache)?;
    // Get info from the shard
    let shard_info = sharded_client.info().await.map_err(V3Error::Info)?;

    // Warmup model
    tracing::info!("Warming up model");
    let max_batch_total_tokens = check_max_batch_total_tokens(
        sharded_client
            .warmup(
                max_input_tokens as u32,
                max_batch_prefill_tokens,
                max_total_tokens as u32,
                max_batch_size,
            )
            .await
            .map_err(V3Error::Warmup)?,
    )?;
    tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
    metrics::gauge!("tgi_batch_max_total_tokens").set(max_batch_total_tokens);

    let backend_info = BackendInfo {
        waiting_served_ratio,
        max_batch_total_tokens,
        max_waiting_tokens,
        max_batch_size,
        model_device_type: shard_info.device_type.clone(),
        model_dtype: shard_info.dtype.clone(),
        speculate: shard_info.speculate as usize,
        support_chunking: shard_info.support_chunking,
        prefix_caching: shard_info.use_prefix_caching,
        attention_impl: shard_info.attention_impl.clone(),
        block_size: shard_info.block_size,
    };

    let backend = BackendV3::new(
        sharded_client,
        waiting_served_ratio,
        max_batch_prefill_tokens,
        max_batch_total_tokens,
        max_waiting_tokens,
        max_batch_size,
        shard_info,
    );

    tracing::info!("Using backend V3");

    Ok((backend, backend_info))
}

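/// Errors returned while connecting to, querying or warming up the Python model shards.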
#[derive(Debug, Error)]
pub enum V3Error {
    #[error("Unable to clear the Python model shards cache: {0}")]
    Cache(ClientError),
    #[error("Unable to connect to the Python model shards: {0}")]
    Connection(ClientError),
    #[error("Unable to get the Python model shards info: {0}")]
    Info(ClientError),
    #[error("Unable to warmup the Python model shards: {0}")]
    Warmup(ClientError),
    #[error("Not enough memory to handle `max_total_tokens={0}`")]
    NotEnoughMemory(usize),
}