lib.rs 6.94 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
// SPDX-License-Identifier: Apache-2.0

4
5
use anyhow::Context as _;
use dynamo_llm::entrypoint::EngineConfig;
6
use dynamo_llm::entrypoint::input::Input;
7
use dynamo_llm::local_model::{LocalModel, LocalModelBuilder};
8
use dynamo_runtime::distributed::{DistributedConfig, RequestPlaneMode};
9
use dynamo_runtime::storage::kv;
10
use dynamo_runtime::transports::nats;
11
use dynamo_runtime::{DistributedRuntime, Runtime};
12

13
14
mod flags;
pub use flags::Flags;
15
mod opt;
16
pub use dynamo_llm::request_template::RequestTemplate;
17
pub use opt::Output;
18

19
pub async fn run(
20
    runtime: Runtime,
21
    in_opt: Input,
22
    out_opt: Option<Output>,
Graham King's avatar
Graham King committed
23
    mut flags: Flags,
24
) -> anyhow::Result<()> {
25
26
27
28
29
30
31
32
    //
    // Download
    //

    let maybe_remote_repo = flags
        .model_path_pos
        .clone()
        .or_else(|| flags.model_path_flag.clone());
33
34
35
36

    // Preserve the original model identifier before downloading (for default model name)
    let original_model_identifier = maybe_remote_repo.as_ref().map(|p| p.display().to_string());

37
38
39
40
41
42
43
44
    let model_path = match maybe_remote_repo {
        None => None,
        Some(p) if p.exists() => {
            // Already a local path
            Some(p)
        }
        Some(p) => {
            // model_path might be an HF repo, not a local path. Resolve it by downloading.
45
46
47
            // Mocker only needs tokenizer, not weights
            let ignore_weights = matches!(out_opt, Some(Output::Mocker));
            Some(LocalModel::fetch(&p.display().to_string(), ignore_weights).await?)
48
49
50
        }
    };

51
52
53
54
55
56
    //
    // Configure
    //

    let mut builder = LocalModelBuilder::default();
    builder
57
        .model_name(flags.model_name.clone().or(original_model_identifier))
58
59
60
        .kv_cache_block_size(flags.kv_cache_block_size)
        // Only set if user provides. Usually loaded from tokenizer_config.json
        .context_length(flags.context_length)
Graham King's avatar
Graham King committed
61
62
63
        .http_port(flags.http_port)
        .tls_cert_path(flags.tls_cert_path.take())
        .tls_key_path(flags.tls_key_path.take())
64
        .router_config(Some(flags.router_config()))
65
        .request_template(flags.request_template.clone())
66
67
        .migration_limit(flags.migration_limit)
        .is_mocker(matches!(out_opt, Some(Output::Mocker)));
68

69
70
71
72
73
    // Only the worker has a model path
    if let Some(model_path) = model_path {
        builder.model_path(model_path);
    }

74
    // TODO: old, address this later:
75
76
77
    // If `in=dyn` we want the trtllm/sglang/vllm subprocess to listen on that endpoint.
    // If not, then the endpoint isn't exposed so we let LocalModel invent one.
    if let Input::Endpoint(path) = &in_opt {
78
        builder.endpoint_id(Some(path.parse().with_context(|| path.clone())?));
79
    }
80
81
82
83
84
    let dst_config = if is_process_local(&in_opt, &out_opt) {
        // We are both the frontend and backend, no networking
        DistributedConfig::process_local()
    } else {
        // Normal case
85
        let selected_store: kv::Selector = flags.store_kv.parse()?;
86
87
88
89
90
91
92
93
94
95
96
        let request_plane: RequestPlaneMode = flags.request_plane.parse()?;
        DistributedConfig {
            store_backend: selected_store,
            // We only need NATS here to monitor it's metrics, so only if it's our request plane.
            nats_config: if request_plane.is_nats() {
                Some(nats::ClientOptions::default())
            } else {
                None
            },
            request_plane,
        }
97
    };
98
    let distributed_runtime = DistributedRuntime::new(runtime.clone(), dst_config).await?;
99
    let local_model = builder.build().await?;
100

101
102
103
    //
    // Create an engine
    //
104

105
    let out_opt = out_opt.unwrap_or_else(|| default_engine_for(&local_model));
106
107
    print_cuda(&out_opt);

108
    // Now that we know the output we're targeting, check if we expect it to work
109
    flags.validate(&out_opt)?;
110

111
    // Make an engine from the local_model, flags and output.
112
113
114
115
116
117
118
    let engine_config = engine_for(
        out_opt,
        flags.clone(),
        local_model,
        distributed_runtime.clone(),
    )
    .await?;
119
120

    // Run it from an input
121
    dynamo_llm::entrypoint::input::run_input(distributed_runtime, in_opt, engine_config).await?;
122

123
124
    Ok(())
}
125

126
127
128
129
130
131
132
133
134
135
136
137
/// Returns `true` when the input is an `Input::Endpoint`, i.e. a dynamic input.
pub fn is_in_dynamic(in_opt: &Input) -> bool {
    matches!(*in_opt, Input::Endpoint(..))
}

/// Returns `true` when the output is `Some(Output::Auto)`, i.e. backends are auto-discovered.
pub fn is_out_dynamic(out_opt: &Option<Output>) -> bool {
    matches!(out_opt.as_ref(), Some(Output::Auto))
}

/// Returns `true` when the frontend and the backend both live in this process.
///
/// This holds when neither the input nor the output is dynamic, so no networking is needed.
fn is_process_local(in_opt: &Input, out_opt: &Option<Output>) -> bool {
    let any_side_dynamic = is_in_dynamic(in_opt) || is_out_dynamic(out_opt);
    !any_side_dynamic
}

138
139
140
141
142
143
/// Create the engine matching `out_opt`
/// Note validation happens in Flags::validate. In here assume everything is going to work.
async fn engine_for(
    out_opt: Output,
    flags: Flags,
    local_model: LocalModel,
144
    drt: DistributedRuntime,
145
) -> anyhow::Result<EngineConfig> {
146
    match out_opt {
147
148
        Output::Auto => {
            // Auto-discover backends
149
150
151
152
            Ok(EngineConfig::Dynamic {
                model: Box::new(local_model),
                engine_factory: None,
            })
153
        }
154
        Output::Echo => Ok(EngineConfig::InProcessText {
155
            model: Box::new(local_model),
156
            engine: dynamo_llm::engines::make_echo_engine(),
157
        }),
158
        #[cfg(feature = "mistralrs")]
159
        Output::MistralRs => Ok(EngineConfig::InProcessText {
160
161
162
            engine: dynamo_engine_mistralrs::make_engine(&local_model).await?,
            model: Box::new(local_model),
        }),
163
164
165
166
167
168
169
        Output::Mocker => {
            let args = flags.mocker_config();
            let endpoint = local_model.endpoint_id().clone();

            let engine =
                dynamo_llm::mocker::engine::make_mocker_engine(drt, endpoint, args).await?;

170
            Ok(EngineConfig::InProcessTokens {
171
172
                engine,
                model: Box::new(local_model),
173
                is_prefill: false,
174
            })
175
176
177
        }
    }
}
178

179
/// If the user will benefit from CUDA or Metal, remind them to build with it.
/// If they have it, celebrate!
///
/// Only logs when `output` is `Output::MistralRs`; every other output returns immediately.
// Only mistralrs needs to be built with CUDA.
// The Python engines only need it at runtime.
#[cfg(feature = "mistralrs")]
fn print_cuda(output: &Output) {
    // These engines may be compiled in, but are they the chosen one?
    // (This whole function only exists with `mistralrs`, so the variant is always available.)
    if !matches!(output, Output::MistralRs) {
        return;
    }

    #[cfg(feature = "cuda")]
    tracing::info!("CUDA on");
    #[cfg(feature = "metal")]
    tracing::info!("Metal on");
    #[cfg(not(any(feature = "cuda", feature = "metal")))]
    tracing::info!("CPU mode. Rebuild with `--features cuda|metal` for better performance");
}

206
#[cfg(not(feature = "mistralrs"))]
207
fn print_cuda(_output: &Output) {}
208

209
210
fn default_engine_for(_local_model: &LocalModel) -> Output {
    safetensors_default()
211
212
213
214
}

/// Compile-time default output engine.
///
/// Returns `Output::MistralRs` when built with the `mistralrs` feature, otherwise `Output::Echo`.
fn safetensors_default() -> Output {
    // `cfg!` cannot be used here: `Output::MistralRs` only exists with the feature enabled,
    // so the unused branch must be removed at compile time.
    #[cfg(feature = "mistralrs")]
    {
        Output::MistralRs
    }

    #[cfg(not(feature = "mistralrs"))]
    {
        Output::Echo
    }
}