lib.rs 5.56 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
5
use anyhow::Context as _;
use dynamo_llm::entrypoint::EngineConfig;
6
use dynamo_llm::entrypoint::input::Input;
7
use dynamo_llm::local_model::{LocalModel, LocalModelBuilder};
8
use dynamo_runtime::distributed::DistributedConfig;
9
use dynamo_runtime::{DistributedRuntime, Runtime};
10

11
mod flags;
12
use either::Either;
13
pub use flags::Flags;
14
mod opt;
15
pub use dynamo_llm::request_template::RequestTemplate;
16
pub use opt::Output;
17

18
pub async fn run(
19
    runtime: Runtime,
20
    in_opt: Input,
21
    out_opt: Option<Output>,
Graham King's avatar
Graham King committed
22
    mut flags: Flags,
23
) -> anyhow::Result<()> {
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
    //
    // Configure
    //

    let mut builder = LocalModelBuilder::default();
    builder
        .model_path(
            flags
                .model_path_pos
                .clone()
                .or(flags.model_path_flag.clone()),
        )
        .model_name(flags.model_name.clone())
        .kv_cache_block_size(flags.kv_cache_block_size)
        // Only set if user provides. Usually loaded from tokenizer_config.json
        .context_length(flags.context_length)
Graham King's avatar
Graham King committed
40
41
42
        .http_port(flags.http_port)
        .tls_cert_path(flags.tls_cert_path.take())
        .tls_key_path(flags.tls_key_path.take())
43
        .router_config(Some(flags.router_config()))
44
        .request_template(flags.request_template.clone())
45
46
        .migration_limit(flags.migration_limit)
        .is_mocker(matches!(out_opt, Some(Output::Mocker)));
47

48
    // TODO: old, address this later:
49
50
    // If `in=dyn` we want the trtllm/sglang/vllm subprocess to listen on that endpoint.
    // If not, then the endpoint isn't exposed so we let LocalModel invent one.
51
    let mut rt = Either::Left(runtime.clone());
52
    if let Input::Endpoint(path) = &in_opt {
53
54
        builder.endpoint_id(Some(path.parse().with_context(|| path.clone())?));

55
56
        let dst_config = DistributedConfig::from_settings(flags.static_worker);
        let distributed_runtime = DistributedRuntime::new(runtime.clone(), dst_config).await?;
57
        rt = Either::Right(distributed_runtime);
58
    };
59
60
61
    if let Some(Output::Static(path)) = &out_opt {
        builder.endpoint_id(Some(path.parse().with_context(|| path.clone())?));
    }
62

63
    let local_model = builder.build().await?;
64

65
66
67
    //
    // Create an engine
    //
68

69
    let out_opt = out_opt.unwrap_or_else(|| default_engine_for(&local_model));
70
71
    print_cuda(&out_opt);

72
    // Now that we know the output we're targeting, check if we expect it to work
73
    flags.validate(&in_opt, &out_opt)?;
74

75
    // Make an engine from the local_model, flags and output.
76
77
78
    let engine_config = engine_for(out_opt, flags.clone(), local_model, rt.clone()).await?;

    // Run it from an input
79
    dynamo_llm::entrypoint::input::run_input(rt, in_opt, engine_config).await?;
80

81
82
    Ok(())
}
83

84
85
86
87
88
89
/// Create the engine matching `out_opt`
/// Note validation happens in Flags::validate. In here assume everything is going to work.
async fn engine_for(
    out_opt: Output,
    flags: Flags,
    local_model: LocalModel,
90
    rt: Either<Runtime, DistributedRuntime>,
91
) -> anyhow::Result<EngineConfig> {
92
    match out_opt {
93
94
95
96
97
98
99
100
        Output::Auto => {
            // Auto-discover backends
            Ok(EngineConfig::Dynamic(Box::new(local_model)))
        }
        Output::Static(_) => {
            // A single static backend, no etcd
            Ok(EngineConfig::StaticRemote(Box::new(local_model)))
        }
101
        Output::Echo => Ok(EngineConfig::StaticFull {
102
            model: Box::new(local_model),
103
            engine: dynamo_llm::engines::make_echo_engine(),
104
            is_static: flags.static_worker,
105
        }),
106
        #[cfg(feature = "mistralrs")]
107
108
109
        Output::MistralRs => Ok(EngineConfig::StaticFull {
            engine: dynamo_engine_mistralrs::make_engine(&local_model).await?,
            model: Box::new(local_model),
110
            is_static: flags.static_worker,
111
        }),
112
113
114
115
116
117
118
119
120
121
122
        Output::Mocker => {
            let Either::Right(drt) = rt else {
                panic!("Mocker requires a distributed runtime to run.");
            };

            let args = flags.mocker_config();
            let endpoint = local_model.endpoint_id().clone();

            let engine =
                dynamo_llm::mocker::engine::make_mocker_engine(drt, endpoint, args).await?;

123
124
125
            Ok(EngineConfig::StaticCore {
                engine,
                model: Box::new(local_model),
126
                is_static: flags.static_worker,
127
            })
128
129
130
        }
    }
}
131

132
/// Log which accelerator support this binary was built with.
///
/// Nothing is logged unless the selected output is the mistralrs engine. For that engine,
/// report CUDA and/or Metal when the matching feature is enabled; when neither is, suggest
/// rebuilding with one of them.
// Only mistralrs has to be compiled with CUDA support.
// The Python engines only need it at runtime.
#[cfg(feature = "mistralrs")]
fn print_cuda(output: &Output) {
    // The engine may be compiled in without being the one chosen for this run.
    if !matches!(output, Output::MistralRs) {
        return;
    }

    #[cfg(feature = "cuda")]
    tracing::info!("CUDA on");

    #[cfg(feature = "metal")]
    tracing::info!("Metal on");

    #[cfg(not(any(feature = "cuda", feature = "metal")))]
    tracing::info!("CPU mode. Rebuild with `--features cuda|metal` for better performance");
}

159
#[cfg(not(feature = "mistralrs"))]
160
fn print_cuda(_output: &Output) {}
161

162
163
fn default_engine_for(_local_model: &LocalModel) -> Output {
    safetensors_default()
164
165
166
167
}

/// Default output, selected at compile time: `Output::MistralRs` when the `mistralrs`
/// feature is enabled, `Output::Echo` otherwise.
fn safetensors_default() -> Output {
    // Exactly one of these bindings survives conditional compilation.
    #[cfg(feature = "mistralrs")]
    let engine = Output::MistralRs;

    #[cfg(not(feature = "mistralrs"))]
    let engine = Output::Echo;

    engine
}