lib.rs 7.11 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
5
6
7
use anyhow::Context as _;
use dynamo_llm::entrypoint::input::Input;
use dynamo_llm::entrypoint::EngineConfig;
use dynamo_llm::local_model::{LocalModel, LocalModelBuilder};
8
use dynamo_runtime::distributed::DistributedConfig;
9
use dynamo_runtime::CancellationToken;
10
use dynamo_runtime::{DistributedRuntime, Runtime};
11

12
mod flags;
13
use either::Either;
14
pub use flags::Flags;
15
mod opt;
16
pub use dynamo_llm::request_template::RequestTemplate;
17
pub use opt::Output;
18

19
pub async fn run(
20
    runtime: Runtime,
21
    in_opt: Input,
22
    out_opt: Option<Output>,
Graham King's avatar
Graham King committed
23
    mut flags: Flags,
24
) -> anyhow::Result<()> {
25
26
27
28
29
30
31
32
33
34
35
36
37
    //
    // Configure
    //

    let mut builder = LocalModelBuilder::default();
    builder
        .model_path(
            flags
                .model_path_pos
                .clone()
                .or(flags.model_path_flag.clone()),
        )
        .model_name(flags.model_name.clone())
38
        .model_config(flags.model_config.clone())
39
40
41
        .kv_cache_block_size(flags.kv_cache_block_size)
        // Only set if user provides. Usually loaded from tokenizer_config.json
        .context_length(flags.context_length)
Graham King's avatar
Graham King committed
42
43
44
        .http_port(flags.http_port)
        .tls_cert_path(flags.tls_cert_path.take())
        .tls_key_path(flags.tls_key_path.take())
45
        .router_config(Some(flags.router_config()))
46
        .request_template(flags.request_template.clone())
47
48
        .migration_limit(flags.migration_limit)
        .is_mocker(matches!(out_opt, Some(Output::Mocker)));
49

50
    // TODO: old, address this later:
51
52
    // If `in=dyn` we want the trtllm/sglang/vllm subprocess to listen on that endpoint.
    // If not, then the endpoint isn't exposed so we let LocalModel invent one.
53
    let mut rt = Either::Left(runtime.clone());
54
    if let Input::Endpoint(path) = &in_opt {
55
56
        builder.endpoint_id(Some(path.parse().with_context(|| path.clone())?));

57
58
        let dst_config = DistributedConfig::from_settings(flags.static_worker);
        let distributed_runtime = DistributedRuntime::new(runtime.clone(), dst_config).await?;
59
        rt = Either::Right(distributed_runtime);
60
    };
61
62
63
    if let Some(Output::Static(path)) = &out_opt {
        builder.endpoint_id(Some(path.parse().with_context(|| path.clone())?));
    }
64

65
    let local_model = builder.build().await?;
66

67
68
69
    //
    // Create an engine
    //
70

71
    let out_opt = out_opt.unwrap_or_else(|| default_engine_for(&local_model));
72
73
    print_cuda(&out_opt);

74
    // Now that we know the output we're targeting, check if we expect it to work
75
    flags.validate(&local_model, &in_opt, &out_opt)?;
76

77
    // Make an engine from the local_model, flags and output.
78
    let engine_config = engine_for(
79
80
81
82
83
84
85
        runtime.primary_token(),
        out_opt,
        flags.clone(),
        local_model,
        rt.clone(),
    )
    .await?;
86

87
88
89
    //
    // Run in from an input
    //
90
    dynamo_llm::entrypoint::input::run_input(rt, in_opt, engine_config).await?;
91

92
93
    Ok(())
}
94

95
96
97
98
99
100
101
/// Create the engine matching `out_opt`
/// Note validation happens in Flags::validate. In here assume everything is going to work.
async fn engine_for(
    cancel_token: CancellationToken,
    out_opt: Output,
    flags: Flags,
    local_model: LocalModel,
102
    rt: Either<Runtime, DistributedRuntime>,
103
) -> anyhow::Result<EngineConfig> {
104
    match out_opt {
105
106
107
108
109
110
111
112
        Output::Auto => {
            // Auto-discover backends
            Ok(EngineConfig::Dynamic(Box::new(local_model)))
        }
        Output::Static(_) => {
            // A single static backend, no etcd
            Ok(EngineConfig::StaticRemote(Box::new(local_model)))
        }
113
114
115
        Output::EchoFull => Ok(EngineConfig::StaticFull {
            model: Box::new(local_model),
            engine: dynamo_llm::engines::make_engine_full(),
116
            is_static: flags.static_worker,
117
118
119
120
        }),
        Output::EchoCore => Ok(EngineConfig::StaticCore {
            engine: dynamo_llm::engines::make_engine_core(),
            model: Box::new(local_model),
121
            is_static: flags.static_worker,
122
        }),
123
        #[cfg(feature = "mistralrs")]
124
125
126
        Output::MistralRs => Ok(EngineConfig::StaticFull {
            engine: dynamo_engine_mistralrs::make_engine(&local_model).await?,
            model: Box::new(local_model),
127
            is_static: flags.static_worker,
128
        }),
129
        #[cfg(feature = "llamacpp")]
130
131
132
        Output::LlamaCpp => Ok(EngineConfig::StaticCore {
            engine: dynamo_engine_llamacpp::make_engine(cancel_token, &local_model).await?,
            model: Box::new(local_model),
133
            is_static: flags.static_worker,
134
        }),
135
136
137
138
139
140
141
142
143
144
145
        Output::Mocker => {
            let Either::Right(drt) = rt else {
                panic!("Mocker requires a distributed runtime to run.");
            };

            let args = flags.mocker_config();
            let endpoint = local_model.endpoint_id().clone();

            let engine =
                dynamo_llm::mocker::engine::make_mocker_engine(drt, endpoint, args).await?;

146
147
148
            Ok(EngineConfig::StaticCore {
                engine,
                model: Box::new(local_model),
149
                is_static: flags.static_worker,
150
            })
151
152
153
        }
    }
}
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189

/// Log which GPU backend (CUDA, Metal or Vulkan) this binary was built with, or
/// suggest a rebuild when none was enabled.
///
/// Nothing is logged unless `output` is one of the compiled-in native engines.
// Only mistralrs and llamacpp need to be built with CUDA.
// The Python engines only need it at runtime.
#[cfg(any(feature = "mistralrs", feature = "llamacpp"))]
fn print_cuda(output: &Output) {
    // These engines may be compiled in, but are they the chosen one?
    let is_native_engine = match output {
        #[cfg(feature = "mistralrs")]
        Output::MistralRs => true,
        #[cfg(feature = "llamacpp")]
        Output::LlamaCpp => true,
        _ => false,
    };
    if !is_native_engine {
        return;
    }

    #[cfg(feature = "cuda")]
    tracing::info!("CUDA on");
    #[cfg(feature = "metal")]
    tracing::info!("Metal on");
    #[cfg(feature = "vulkan")]
    tracing::info!("Vulkan on");
    #[cfg(not(any(feature = "cuda", feature = "metal", feature = "vulkan")))]
    tracing::info!("CPU mode. Rebuild with `--features cuda|metal|vulkan` for better performance");
}

/// No-op variant used when neither the mistralrs nor the llamacpp engine is compiled in.
#[cfg(not(any(feature = "mistralrs", feature = "llamacpp")))]
fn print_cuda(_output: &Output) {}
190

191
192
193
194
195
196
197
198
199
200
201
202
203
fn default_engine_for(local_model: &LocalModel) -> Output {
    let default_engine = if local_model.card().is_gguf() {
        gguf_default()
    } else {
        safetensors_default()
    };
    tracing::info!(
        "Using default engine: {default_engine}. Use out=<engine> to specify one of {}",
        Output::available_engines().join(", ")
    );
    default_engine
}

204
205
/// Default engine for GGUF models.
///
/// Preference order, decided at compile time by the enabled features:
/// llamacpp, then mistralrs, then the echo engine.
fn gguf_default() -> Output {
    // The three cfg conditions are mutually exclusive and cover every feature
    // combination, so exactly one binding survives compilation.
    #[cfg(feature = "llamacpp")]
    let engine = Output::LlamaCpp;

    #[cfg(all(feature = "mistralrs", not(feature = "llamacpp")))]
    let engine = Output::MistralRs;

    #[cfg(not(any(feature = "mistralrs", feature = "llamacpp")))]
    let engine = Output::EchoFull;

    engine
}

/// Default engine for non-GGUF (safetensors) models.
///
/// Uses mistralrs when that feature is compiled in, otherwise the echo engine.
fn safetensors_default() -> Output {
    // Exactly one of the two bindings survives compilation.
    #[cfg(feature = "mistralrs")]
    let engine = Output::MistralRs;

    #[cfg(not(feature = "mistralrs"))]
    let engine = Output::EchoFull;

    engine
}