lib.rs 14.6 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
use std::{future::Future, pin::Pin};
5
use std::{io::Read, sync::Arc, time::Duration};
6

7
use anyhow::Context;
8
use dynamo_llm::{backend::ExecutionContext, engines::StreamingEngine, local_model::LocalModel};
9
10
use dynamo_runtime::protocols::Endpoint as EndpointId;
use dynamo_runtime::slug::Slug;
11
use dynamo_runtime::{CancellationToken, DistributedRuntime};
12

13
14
mod flags;
pub use flags::Flags;
15
16
mod input;
mod opt;
17
pub use dynamo_llm::request_template::RequestTemplate;
18
pub use opt::{Input, Output};
19
mod subprocess;
20

21
22
/// How long an engine sub-process gets to exit after SIGTERM before it is killed.
const CHILD_STOP_TIMEOUT: Duration = Duration::from_secs(2);

/// Default size of a KV cache block. Override with --kv-cache-block-size
const DEFAULT_KV_CACHE_BLOCK_SIZE: usize = 16;

pub enum EngineConfig {
27
28
    /// Remote networked engines
    Dynamic,
29

30
31
    /// A Full service engine does it's own tokenization and prompt formatting.
    StaticFull {
32
        engine: Arc<dyn StreamingEngine>,
33
        model: Box<LocalModel>,
34
    },
35
36
37
38

    /// A core engine expects to be wrapped with pre/post processors that handle tokenization.
    StaticCore {
        engine: ExecutionContext,
39
        model: Box<LocalModel>,
40
    },
41
42
}

43
44
45
46
47
48
49
50
/// True when requests arrive through a distributed endpoint (`in=` an endpoint).
fn is_in_dynamic(in_opt: &Input) -> bool {
    matches!(*in_opt, Input::Endpoint(_))
}

/// True when the user explicitly asked for a remote (`out=dyn`) engine; `None` is not dynamic.
fn is_out_dynamic(out_opt: &Option<Output>) -> bool {
    matches!(out_opt.as_ref(), Some(Output::Dynamic))
}

51
pub async fn run(
Neelay Shah's avatar
Neelay Shah committed
52
    runtime: dynamo_runtime::Runtime,
53
    in_opt: Input,
54
    out_opt: Option<Output>,
55
56
    flags: Flags,
) -> anyhow::Result<()> {
57
    if is_in_dynamic(&in_opt) && is_out_dynamic(&out_opt) {
58
59
60
        anyhow::bail!("Cannot use endpoint for both in and out");
    }

61
    let cancel_token = runtime.primary_token();
62
    let maybe_path = flags
63
        .model_path_pos
64
        .clone()
65
        .or(flags.model_path_flag.clone());
66

67
    let mut local_model: LocalModel = if is_out_dynamic(&out_opt) {
68
        // If output is dynamic we are ingress and don't have a local model, but making an
69
        // empty one cleans up the code.
70
71
        Default::default()
    } else {
72
        // All other output types have a local model
73
74
75
76
77
78
79
80
81
82
83
84
85
86
        match &maybe_path {
            Some(model_path) => {
                LocalModel::prepare(
                    model_path.to_str().context("Invalid UTF-8 in model path")?,
                    flags.model_config.as_deref(),
                    flags.model_name.clone(),
                )
                .await?
            }
            None => {
                // echo_full engine doesn't need a path
                match &flags.model_name {
                    Some(name) => LocalModel::with_name_only(name),
                    None => Default::default(),
87
88
                }
            }
Graham King's avatar
Graham King committed
89
        }
90
    };
91
92
93
94
95
96
97
98
99
100
101

    // Only set if user provides. Usually loaded from tokenizer_config.json
    if let Some(context_length) = flags.context_length {
        local_model.set_context_length(context_length);
    }
    // Always set, there is no engine provided default
    local_model.set_kv_cache_block_size(
        flags
            .kv_cache_block_size
            .unwrap_or(DEFAULT_KV_CACHE_BLOCK_SIZE),
    );
102

103
    let mut extra: Option<Pin<Box<dyn Future<Output = ()> + Send>>> = None; // vllm and sglang sub-process
104

105
106
107
108
109
110
111
112
    let template = if let Some(path) = flags.request_template.as_ref() {
        let template = RequestTemplate::load(path)?;
        tracing::debug!("Using request template: {template:?}");
        Some(template)
    } else {
        None
    };

113
114
115
    // We may need it later
    let card = local_model.card().clone();

116
117
    let out_opt = out_opt.unwrap_or_else(|| {
        let default_engine = if card.is_gguf() {
118
            gguf_default()
119
        } else {
120
            safetensors_default()
121
122
123
124
125
126
127
128
129
        };
        tracing::info!(
            "Using default engine: {default_engine}. Use out=<engine> to specify one of {}",
            Output::available_engines().join(", ")
        );
        default_engine
    });
    print_cuda(&out_opt);

130
131
    // Create the engine matching `out`
    let engine_config = match out_opt {
132
133
134
135
136
137
138
139
140
141
        Output::Dynamic => {
            // Sanity check - TODO probably make a general sanity check at start of method
            if flags.context_length.is_some() {
                anyhow::bail!("'--content-length' flag should only be used on the worker node, not on the ingress");
            }
            if flags.kv_cache_block_size.is_some() {
                anyhow::bail!("'--kv-cache-block-size' flag should only be used on the worker node, not on the ingress");
            }
            EngineConfig::Dynamic
        }
142
143
144
145
        Output::EchoFull => EngineConfig::StaticFull {
            model: Box::new(local_model),
            engine: dynamo_llm::engines::make_engine_full(),
        },
146
        Output::EchoCore => {
147
148
            let card = local_model.card();
            if !card.has_tokenizer() {
149
150
151
152
153
                anyhow::bail!(
                    "out=echo_core need to find the tokenizer. Pass flag --model-path <path>"
                );
            };
            EngineConfig::StaticCore {
154
                engine: dynamo_llm::engines::make_engine_core(),
155
                model: Box::new(local_model),
156
157
            }
        }
158
        #[cfg(feature = "mistralrs")]
159
        Output::MistralRs => EngineConfig::StaticFull {
160
            engine: dynamo_engine_mistralrs::make_engine(&local_model).await?,
161
162
            model: Box::new(local_model),
        },
163
        Output::SgLang => {
164
165
166
167
            if !local_model.path().is_dir() {
                // TODO Does sglang support GGUF? Can we make it work?
                anyhow::bail!("`--model-path should point at a HuggingFace repo checkout");
            }
168
169
170
171
172

            // If `in=dyn` we want the sglang subprocess to listen on that endpoint.
            // If not, then the endpoint isn't exposed so we invent an internal one.
            let endpoint = match &in_opt {
                Input::Endpoint(path) => path.parse()?,
173
                _ => internal_endpoint("sglang"),
174
175
            };

176
177
178
179
180
181
            let multi_node_conf = dynamo_llm::engines::MultiNodeConfig {
                num_nodes: flags.num_nodes,
                node_rank: flags.node_rank,
                leader_addr: flags.leader_addr.clone().unwrap_or_default(),
            };
            let (py_script, child) = match subprocess::start(
182
                subprocess::sglang::PY,
183
                &local_model,
184
                &endpoint,
185
                flags.clone(),
186
187
188
189
190
                if flags.num_nodes <= 1 {
                    None
                } else {
                    Some(multi_node_conf)
                },
191
192
193
194
195
196
197
198
199
200
201
202
203
204
            )
            .await
            {
                Ok(x) => x,
                Err(err) => {
                    anyhow::bail!("Failed starting sglang sub-process: {err}");
                }
            };
            let cancel_token = cancel_token.clone();

            // Sub-process cleanup
            extra = Some(Box::pin(async move {
                stopper(cancel_token, child, py_script).await;
            }));
205
            EngineConfig::Dynamic
206
207
208
209
210
        }
        Output::Vllm => {
            if flags.base_gpu_id != 0 {
                anyhow::bail!("vllm does not support base_gpu_id. Set environment variable CUDA_VISIBLE_DEVICES instead.");
            }
211
212
213
214
215

            // If `in=dyn` we want the vllm subprocess to listen on that endpoint.
            // If not, then the endpoint isn't exposed so we invent an internal one.
            let endpoint = match &in_opt {
                Input::Endpoint(path) => path.parse()?,
216
                _ => internal_endpoint("vllm"),
217
218
            };

219
            let (py_script, child) = match subprocess::start(
220
                subprocess::vllm::PY,
221
                &local_model,
222
                &endpoint,
223
                flags.clone(),
224
                None, // multi-node config. vllm uses `ray`, see guide
225
226
227
228
229
230
231
232
233
234
235
236
237
238
            )
            .await
            {
                Ok(x) => x,
                Err(err) => {
                    anyhow::bail!("Failed starting vllm sub-process: {err}");
                }
            };
            let cancel_token = cancel_token.clone();

            // Sub-process cleanup
            extra = Some(Box::pin(async move {
                stopper(cancel_token, child, py_script).await;
            }));
239
            EngineConfig::Dynamic
240
        }
241
242
243
244
245
246
247
248
249
        Output::Trtllm => {
            if flags.base_gpu_id != 0 {
                anyhow::bail!("TRTLLM does not support base_gpu_id. Set environment variable CUDA_VISIBLE_DEVICES instead.");
            }

            // If `in=dyn` we want the trtllm subprocess to listen on that endpoint.
            // If not, then the endpoint isn't exposed so we invent an internal one.
            let endpoint = match &in_opt {
                Input::Endpoint(path) => path.parse()?,
250
                _ => internal_endpoint("trtllm"),
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
            };

            let (py_script, child) = match subprocess::start(
                subprocess::trtllm::PY,
                &local_model,
                &endpoint,
                flags.clone(),
                None, // multi-node config. trtlllm uses `mpi`, see guide
            )
            .await
            {
                Ok(x) => x,
                Err(err) => {
                    anyhow::bail!("Failed starting trtllm sub-process: {err}");
                }
            };
            let cancel_token = cancel_token.clone();

            // Sub-process cleanup
            extra = Some(Box::pin(async move {
                stopper(cancel_token, child, py_script).await;
            }));
            EngineConfig::Dynamic
        }
275

276
277
        #[cfg(feature = "llamacpp")]
        Output::LlamaCpp => {
278
            if !local_model.path().is_file() {
279
280
                anyhow::bail!("--model-path should refer to a GGUF file. llama_cpp does not support safetensors.");
            }
281
            let engine =
282
                dynamo_engine_llamacpp::make_engine(cancel_token.clone(), &local_model).await?;
283
            EngineConfig::StaticCore {
284
                engine,
285
                model: Box::new(local_model),
Graham King's avatar
Graham King committed
286
287
            }
        }
288
289
290
291
    };

    match in_opt {
        Input::Http => {
292
            crate::input::http::run(runtime.clone(), flags, engine_config, template).await?;
293
294
        }
        Input::Text => {
295
            crate::input::text::run(runtime.clone(), flags, None, engine_config, template).await?;
296
297
298
299
        }
        Input::Stdin => {
            let mut prompt = String::new();
            std::io::stdin().read_to_string(&mut prompt).unwrap();
300
301
302
303
304
305
306
307
            crate::input::text::run(
                runtime.clone(),
                flags,
                Some(prompt),
                engine_config,
                template,
            )
            .await?;
308
        }
309
        Input::Batch(path) => {
310
311
            crate::input::batch::run(runtime.clone(), flags, card, path, engine_config, template)
                .await?;
312
        }
313
        Input::Endpoint(path) => {
314
315
            let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;
            crate::input::endpoint::run(distributed_runtime, path, engine_config).await?;
316
317
318
319
        }
    }

    // Allow engines to ask main thread to wait on an extra future.
320
    // We use this to stop the vllm and sglang sub-process
321
    if let Some(extra) = extra {
322
        extra.await;
323
324
325
326
    }

    Ok(())
}
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365

/// Wait for cancel_token to be cancelled, then stop the child as gracefully as possible.
/// Keeps the TempPath alive until the child is stopped.
async fn stopper(
    cancel_token: CancellationToken,
    mut child: tokio::process::Child,
    py_script: tempfile::TempPath,
) {
    cancel_token.cancelled().await;

    // Ask subprocess to stop gracefully
    if let Some(pid) = child.id() {
        unsafe { libc::kill(pid as i32, libc::SIGTERM) };
    }

    tokio::select! {
        exit = child.wait() => {
            tracing::trace!("vllm sub-process graceful exit");
            match exit {
                Ok(exit_status) if exit_status.success() => {}
                Ok(exit_status) => {
                    // This is nearly always 15 (SIGTERM)
                    tracing::trace!("vllm sub-process non-0 exit: {exit_status}");
                }
                Err(err) => {
                    tracing::warn!("vllm sub-process error getting exit status: {err}");
                }
            }
        }
        _ = tokio::time::sleep(CHILD_STOP_TIMEOUT) => {
            // It didn't stop in time, kill it
            child.kill().await.expect("Failed killing vllm subprocess");
            let _ = child.wait().await;
        }
    }
    // This temporary file contains the python script running the engine. It deletes on drop.
    // Keep it alive until the engine has stopped.
    drop(py_script);
}
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401

/// If the user will benefit from CUDA/Metal/Vulkan, remind them to build with it.
/// If they have it, celebrate!
// Only mistralrs and llamacpp need to be built with CUDA.
// The Python engines only need it at runtime.
#[cfg(any(feature = "mistralrs", feature = "llamacpp"))]
fn print_cuda(output: &Output) {
    // These engines maybe be compiled in, but are they the chosen one?
    match output {
        #[cfg(feature = "mistralrs")]
        Output::MistralRs => {}
        #[cfg(feature = "llamacpp")]
        Output::LlamaCpp => {}
        _ => {
            return;
        }
    }

    #[cfg(feature = "cuda")]
    {
        tracing::info!("CUDA on");
    }
    #[cfg(feature = "metal")]
    {
        tracing::info!("Metal on");
    }
    #[cfg(feature = "vulkan")]
    {
        tracing::info!("Vulkan on");
    }
    #[cfg(not(any(feature = "cuda", feature = "metal", feature = "vulkan")))]
    tracing::info!("CPU mode. Rebuild with `--features cuda|metal|vulkan` for better performance");
}

#[cfg(not(any(feature = "mistralrs", feature = "llamacpp")))]
fn print_cuda(_output: &Output) {}
402
403
404

/// Engine used for a GGUF model when no `out=` is given.
///
/// Preference order: llama.cpp, then mistral.rs, then the echo engine when neither native
/// engine is compiled in. Exactly one of the `cfg` blocks below survives compilation and
/// becomes the function's value.
fn gguf_default() -> Output {
    #[cfg(feature = "llamacpp")]
    {
        Output::LlamaCpp
    }

    #[cfg(all(feature = "mistralrs", not(feature = "llamacpp")))]
    {
        Output::MistralRs
    }

    #[cfg(not(any(feature = "mistralrs", feature = "llamacpp")))]
    {
        Output::EchoFull
    }
}

/// Engine used for a non-GGUF (e.g. safetensors) model when no `out=` is given.
///
/// mistral.rs when it is compiled in, otherwise the echo engine. Exactly one of the `cfg`
/// blocks below survives compilation and becomes the function's value.
fn safetensors_default() -> Output {
    #[cfg(feature = "mistralrs")]
    {
        Output::MistralRs
    }

    #[cfg(not(feature = "mistralrs"))]
    {
        Output::EchoFull
    }
}

/// A random endpoint to use for internal communication
/// We can't hard code because we may be running several on the same machine (GPUs 0-3 and 4-7)
fn internal_endpoint(engine: &str) -> EndpointId {
    EndpointId {
        namespace: Slug::slugify(&uuid::Uuid::new_v4().to_string()).to_string(),
        component: engine.to_string(),
        name: "generate".to_string(),
    }
440
}