mod backend;
pub mod block_allocator;
mod client;
mod queue;
pub mod radix;

use crate::client::{ClientError, ShardedClient};
pub(crate) use backend::BackendV3;
use serde::Serialize;
use thiserror::Error;
use utoipa::ToSchema;

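/// Static description of the connected backend: model device and dtype plus
/// the batching limits it was configured with, exposed in the OpenAPI schema
/// via `utoipa`.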
#[derive(Clone, Debug, Serialize, ToSchema)]
pub struct BackendInfo {
    /// Mandatory
    #[schema(example = "cuda")]
    pub model_device_type: String,
    #[schema(example = "torch.float16")]
    pub model_dtype: String,

    /// Backend parameters
    #[schema(example = "1")]
    pub speculate: usize,
    #[schema(example = "1.2")]
    pub waiting_served_ratio: f32,
    #[schema(example = "32000")]
    pub max_batch_total_tokens: u32,
    #[schema(example = "20")]
    pub max_waiting_tokens: usize,
    #[schema(nullable = true, example = "null")]
    pub max_batch_size: Option<usize>,
    #[schema(example = "false")]
    pub support_chunking: bool,
    #[schema(example = "false")]
    pub prefix_caching: bool,
    #[schema(example = "flashinfer")]
    pub attention_impl: String,
    #[schema(example = "1")]
    pub block_size: u32,

    #[schema(example = "30000")]
    pub max_input_tokens: usize,
    #[schema(example = "32000")]
    pub max_total_tokens: usize,
}

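/// Connect to the sharded model server over the given unix-domain-socket path,
/// clear any leftover cache, fetch the shard info, warm the model up to resolve
/// the effective batch limits, and return the ready `BackendV3` together with
/// its [`BackendInfo`].
///
/// A minimal, illustrative call (the socket path and limits below are
/// placeholders, not recommended defaults):
///
/// ```ignore
/// let (backend, info) = connect_backend(
///     Some(1024),                                  // max_input_tokens
///     Some(2048),                                  // max_total_tokens
///     "/tmp/text-generation-server-0".to_string(), // master_shard_uds_path
///     1.2,                                          // waiting_served_ratio
///     4096,                                         // max_batch_prefill_tokens
///     None,                                         // max_batch_total_tokens (inferred when supported)
///     20,                                           // max_waiting_tokens
///     None,                                         // max_batch_size
/// )
/// .await?;
/// ```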
#[allow(clippy::too_many_arguments)]
pub async fn connect_backend(
    max_input_tokens: Option<usize>,
    max_total_tokens: Option<usize>,
    master_shard_uds_path: String,
    waiting_served_ratio: f32,
    max_batch_prefill_tokens: u32,
    max_batch_total_tokens: Option<u32>,
    max_waiting_tokens: usize,
    max_batch_size: Option<usize>,
) -> Result<(BackendV3, BackendInfo), V3Error> {
    // Helper closure: validate the CLI token limits against the shard limits
    // and resolve the effective max batch total tokens
    let check_max_batch_total_tokens = |(
        max_supported_batch_total_tokens,
        shard_max_input_tokens,
        shard_max_total_tokens,
    ): (Option<u32>, u32, u32)|
     -> Result<(u32, usize, usize), V3Error> {
        if let Some(max_input_tokens) = max_input_tokens {
            assert_eq!(max_input_tokens as u32, shard_max_input_tokens);
        }
        if let Some(max_total_tokens) = max_total_tokens {
            assert_eq!(max_total_tokens as u32, shard_max_total_tokens);
        }
        match max_supported_batch_total_tokens {
            // Older models do not support automatic max-batch-total-tokens
            None => {
                let max_batch_total_tokens = max_batch_total_tokens.unwrap_or(
                    16000
                        .max(shard_max_total_tokens)
                        .max(max_batch_prefill_tokens),
                );
                tracing::warn!("Model does not support automatic max batch total tokens");
                Ok((
                    max_batch_total_tokens,
                    shard_max_input_tokens as usize,
                    shard_max_total_tokens as usize,
                ))
            }
            // Flash attention models return their max supported total tokens
            Some(max_supported_batch_total_tokens) => {
                // Warn if the user set their own max-batch-total-tokens, as we will ignore it
                if max_batch_total_tokens.is_some() {
                    tracing::warn!(
                        "`--max-batch-total-tokens` is deprecated for Flash \
                        Attention models."
                    );
                    tracing::warn!(
                        "Inferred max batch total tokens: {max_supported_batch_total_tokens}"
                    );
                }
                if shard_max_total_tokens > max_supported_batch_total_tokens {
                    return Err(V3Error::NotEnoughMemory(shard_max_total_tokens as usize));
                }

                Ok((
                    max_supported_batch_total_tokens,
                    shard_max_input_tokens as usize,
                    shard_max_total_tokens as usize,
                ))
            }
        }
    };

    let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
        .await
        .map_err(V3Error::Connection)?;

    // server is running on v3
    // Clear the cache; useful if the webserver rebooted
    sharded_client
        .clear_cache(None)
        .await
        .map_err(V3Error::Cache)?;
    // Get info from the shard
    let shard_info = sharded_client.info().await.map_err(V3Error::Info)?;

    // Warmup model
    tracing::info!("Warming up model");
    let answer = sharded_client
        .warmup(
            max_input_tokens.map(|p| p as u32),
            max_batch_prefill_tokens,
            max_total_tokens.map(|p| p as u32),
            max_batch_size,
        )
        .await
        .map_err(V3Error::Warmup)?;
    let (max_batch_total_tokens, max_input_tokens, max_total_tokens) =
        check_max_batch_total_tokens(answer)?;
    tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
    metrics::gauge!("tgi_batch_max_total_tokens").set(max_batch_total_tokens);

    let backend_info = BackendInfo {
        waiting_served_ratio,
        max_batch_total_tokens,
        max_input_tokens,
        max_total_tokens,
        max_waiting_tokens,
        max_batch_size,
        model_device_type: shard_info.device_type.clone(),
        model_dtype: shard_info.dtype.clone(),
        speculate: shard_info.speculate as usize,
        support_chunking: shard_info.support_chunking,
        prefix_caching: shard_info.use_prefix_caching,
        attention_impl: shard_info.attention_impl.clone(),
        block_size: shard_info.block_size,
    };

    let backend = BackendV3::new(
        sharded_client,
        waiting_served_ratio,
        max_batch_prefill_tokens,
        max_batch_total_tokens,
        max_waiting_tokens,
        max_batch_size,
        shard_info,
    );

    tracing::info!("Using backend V3");

    Ok((backend, backend_info))
}

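/// Errors returned while connecting to, querying, or warming up the Python model shards.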
#[derive(Debug, Error)]
pub enum V3Error {
    #[error("Unable to clear the Python model shards cache: {0}")]
    Cache(ClientError),
    #[error("Unable to connect to the Python model shards: {0}")]
    Connection(ClientError),
    #[error("Unable to get the Python model shards info: {0}")]
    Info(ClientError),
    #[error("Unable to warmup the Python model shards: {0}")]
    Warmup(ClientError),
    #[error("Not enough memory to handle `max_total_tokens={0}`")]
    NotEnoughMemory(usize),
}