lib.rs 9.07 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
use std::time::Duration;
5
use std::{future::Future, pin::Pin};
6

7
8
9
10
11
use anyhow::Context as _;
use dynamo_llm::entrypoint::input::Input;
use dynamo_llm::entrypoint::EngineConfig;
use dynamo_llm::local_model::{LocalModel, LocalModelBuilder};
use dynamo_runtime::CancellationToken;
12

13
14
mod flags;
pub use flags::Flags;
15
mod opt;
16
pub use dynamo_llm::request_template::RequestTemplate;
17
pub use opt::Output;
18
mod subprocess;
19

20
21
/// How long `stopper` waits for the engine sub-process to exit after sending it SIGTERM
/// before it force-kills the process.
const CHILD_STOP_TIMEOUT: Duration = Duration::from_secs(2);

22
pub async fn run(
Neelay Shah's avatar
Neelay Shah committed
23
    runtime: dynamo_runtime::Runtime,
24
    in_opt: Input,
25
    out_opt: Option<Output>,
26
27
    flags: Flags,
) -> anyhow::Result<()> {
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
    //
    // Configure
    //

    let mut builder = LocalModelBuilder::default();
    builder
        .model_path(
            flags
                .model_path_pos
                .clone()
                .or(flags.model_path_flag.clone()),
        )
        .model_name(flags.model_name.clone())
        .kv_cache_block_size(flags.kv_cache_block_size)
        // Only set if user provides. Usually loaded from tokenizer_config.json
        .context_length(flags.context_length)
        .http_port(flags.http_port)
        .router_config(flags.router_config())
        .request_template(flags.request_template.clone());

    // If `in=dyn` we want the trtllm/sglang/vllm subprocess to listen on that endpoint.
    // If not, then the endpoint isn't exposed so we let LocalModel invent one.
    if let Input::Endpoint(path) = &in_opt {
        builder.endpoint_id(path.parse().with_context(|| path.clone())?);
52
    };
53

54
    let local_model = builder.build().await?;
55

56
57
58
    //
    // Create an engine
    //
59

60
    let out_opt = out_opt.unwrap_or_else(|| default_engine_for(&local_model));
61
62
    print_cuda(&out_opt);

63
64
    // Now that we know the output we're targeting, check if we expect it to work
    flags.validate(&local_model, &out_opt)?;
65

66
67
68
    // Make an engine from the local_model, flags and output.
    let (engine_config, extra) =
        engine_for(runtime.primary_token(), out_opt, flags.clone(), local_model).await?;
69

70
71
72
    //
    // Run in from an input
    //
73

74
    dynamo_llm::entrypoint::input::run_input(in_opt, runtime, engine_config).await?;
75

76
77
78
79
80
    // Allow engines to ask main thread to wait on an extra future.
    // We use this to stop the vllm and sglang sub-process
    if let Some(extra) = extra {
        extra.await;
    }
81

82
83
    Ok(())
}
84

85
/// Boxed, `Send` future with no output. An engine may return one alongside its config so that
/// `run` awaits it after the input has finished; `shell` uses it to stop the engine
/// sub-process.
type ExtraFuture = Pin<Box<dyn Future<Output = ()> + Send>>;
86

87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
/// Create the engine matching `out_opt`
/// Note validation happens in Flags::validate. In here assume everything is going to work.
async fn engine_for(
    cancel_token: CancellationToken,
    out_opt: Output,
    flags: Flags,
    local_model: LocalModel,
) -> anyhow::Result<(EngineConfig, Option<ExtraFuture>)> {
    match out_opt {
        Output::Dynamic => Ok((EngineConfig::Dynamic(Box::new(local_model)), None)),
        Output::EchoFull => Ok((
            EngineConfig::StaticFull {
                model: Box::new(local_model),
                engine: dynamo_llm::engines::make_engine_full(),
            },
            None,
        )),
        Output::EchoCore => Ok((
            EngineConfig::StaticCore {
                engine: dynamo_llm::engines::make_engine_core(),
                model: Box::new(local_model),
            },
            None,
        )),
        #[cfg(feature = "mistralrs")]
        Output::MistralRs => Ok((
            EngineConfig::StaticFull {
                engine: dynamo_engine_mistralrs::make_engine(&local_model).await?,
                model: Box::new(local_model),
            },
            None,
        )),
119
        #[cfg(feature = "llamacpp")]
120
        Output::LlamaCpp => Ok((
121
            EngineConfig::StaticCore {
122
                engine: dynamo_engine_llamacpp::make_engine(cancel_token, &local_model).await?,
123
                model: Box::new(local_model),
124
125
126
127
128
129
130
131
132
133
134
            },
            None,
        )),
        // For multi-node config. vllm uses `ray`, see guide
        Output::Vllm => shell(subprocess::vllm::PY, cancel_token, local_model, flags, None).await,
        // For multi-node config. trtlllm uses `mpi`, see guide
        Output::Trtllm => {
            shell(
                subprocess::trtllm::PY,
                cancel_token,
                local_model,
135
                flags,
136
                None,
137
            )
138
            .await
139
        }
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
        Output::SgLang => {
            let multi_node_config = if flags.num_nodes > 1 {
                Some(dynamo_llm::engines::MultiNodeConfig {
                    num_nodes: flags.num_nodes,
                    node_rank: flags.node_rank,
                    leader_addr: flags.leader_addr.clone().unwrap_or_default(),
                })
            } else {
                None
            };
            shell(
                subprocess::sglang::PY,
                cancel_token,
                local_model,
                flags,
                multi_node_config,
            )
            .await
158
159
        }
    }
160
}
161

162
163
164
165
166
167
168
169
170
171
172
173
174
175
async fn shell(
    py_script: &'static str,
    cancel_token: CancellationToken,
    local_model: LocalModel,
    flags: Flags,
    multi_node_config: Option<dynamo_llm::engines::MultiNodeConfig>,
) -> anyhow::Result<(EngineConfig, Option<ExtraFuture>)> {
    let (py_script, child) =
        match subprocess::start(py_script, &local_model, flags.clone(), multi_node_config).await {
            Ok(x) => x,
            Err(err) => {
                anyhow::bail!("Failed starting engine sub-process: {err}");
            }
        };
176

177
178
179
180
181
    // Sub-process cleanup
    let extra: ExtraFuture = Box::pin(async move {
        stopper(cancel_token, child, py_script).await;
    });
    Ok((EngineConfig::Dynamic(Box::new(local_model)), Some(extra)))
182
}
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199

/// Wait for cancel_token to be cancelled, then stop the child as gracefully as possible.
/// Keeps the TempPath alive until the child is stopped.
async fn stopper(
    cancel_token: CancellationToken,
    mut child: tokio::process::Child,
    py_script: tempfile::TempPath,
) {
    cancel_token.cancelled().await;

    // Ask subprocess to stop gracefully
    if let Some(pid) = child.id() {
        unsafe { libc::kill(pid as i32, libc::SIGTERM) };
    }

    tokio::select! {
        exit = child.wait() => {
200
            tracing::trace!("engine sub-process graceful exit");
201
202
203
204
            match exit {
                Ok(exit_status) if exit_status.success() => {}
                Ok(exit_status) => {
                    // This is nearly always 15 (SIGTERM)
205
                    tracing::trace!("engine sub-process non-0 exit: {exit_status}");
206
207
                }
                Err(err) => {
208
                    tracing::warn!("engine sub-process error getting exit status: {err}");
209
210
211
212
213
                }
            }
        }
        _ = tokio::time::sleep(CHILD_STOP_TIMEOUT) => {
            // It didn't stop in time, kill it
214
            child.kill().await.expect("Failed killing engine subprocess");
215
216
217
218
219
220
221
            let _ = child.wait().await;
        }
    }
    // This temporary file contains the python script running the engine. It deletes on drop.
    // Keep it alive until the engine has stopped.
    drop(py_script);
}
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257

/// Log which hardware acceleration (CUDA, Metal or Vulkan) this binary was built with, or a
/// hint to rebuild with one of them when none is enabled.
///
/// Nothing is logged unless the selected `output` is mistralrs or llamacpp.
// Only mistralrs and llamacpp need to be built with CUDA.
// The Python engines only need it at runtime.
#[cfg(any(feature = "mistralrs", feature = "llamacpp"))]
fn print_cuda(output: &Output) {
    // Being compiled in is not enough: the engine also has to be the one that was chosen.
    let is_native_engine = match output {
        #[cfg(feature = "mistralrs")]
        Output::MistralRs => true,
        #[cfg(feature = "llamacpp")]
        Output::LlamaCpp => true,
        _ => false,
    };
    if !is_native_engine {
        return;
    }

    if cfg!(feature = "cuda") {
        tracing::info!("CUDA on");
    }
    if cfg!(feature = "metal") {
        tracing::info!("Metal on");
    }
    if cfg!(feature = "vulkan") {
        tracing::info!("Vulkan on");
    }
    if !cfg!(any(feature = "cuda", feature = "metal", feature = "vulkan")) {
        tracing::info!("CPU mode. Rebuild with `--features cuda|metal|vulkan` for better performance");
    }
}

/// Neither mistralrs nor llamacpp is compiled in, so there is nothing to report.
#[cfg(not(any(feature = "mistralrs", feature = "llamacpp")))]
fn print_cuda(_output: &Output) {}
258

259
260
261
262
263
264
265
266
267
268
269
270
271
fn default_engine_for(local_model: &LocalModel) -> Output {
    let default_engine = if local_model.card().is_gguf() {
        gguf_default()
    } else {
        safetensors_default()
    };
    tracing::info!(
        "Using default engine: {default_engine}. Use out=<engine> to specify one of {}",
        Output::available_engines().join(", ")
    );
    default_engine
}

272
273
/// Default engine for GGUF models, decided at compile time from the enabled features:
/// llamacpp if available, otherwise mistralrs, otherwise the echo engine.
fn gguf_default() -> Output {
    // Exactly one of the three `let` statements below survives conditional compilation.
    #[cfg(feature = "llamacpp")]
    let engine = Output::LlamaCpp;

    #[cfg(all(feature = "mistralrs", not(feature = "llamacpp")))]
    let engine = Output::MistralRs;

    #[cfg(not(any(feature = "mistralrs", feature = "llamacpp")))]
    let engine = Output::EchoFull;

    engine
}

/// Default engine for non-GGUF models, decided at compile time from the enabled features:
/// mistralrs if available, otherwise the echo engine.
fn safetensors_default() -> Output {
    // Exactly one of the two `let` statements below survives conditional compilation.
    #[cfg(feature = "mistralrs")]
    let engine = Output::MistralRs;

    #[cfg(not(feature = "mistralrs"))]
    let engine = Output::EchoFull;

    engine
}