main.rs 10.4 KB
Newer Older
Ryan Olson's avatar
Ryan Olson committed
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
5
use std::sync::Arc;

Ryan Olson's avatar
Ryan Olson committed
6
7
use clap::{Parser, Subcommand};

8
9
use dynamo_llm::discovery::{ModelManager, ModelWatcher};
use dynamo_llm::local_model::{LocalModel, ModelNetworkName};
10
use dynamo_llm::model_type::ModelType;
11
12
use dynamo_runtime::component::Endpoint;
use dynamo_runtime::pipeline::RouterMode;
Neelay Shah's avatar
Neelay Shah committed
13
use dynamo_runtime::{
14
    distributed::DistributedConfig, logging, DistributedRuntime, Result, Runtime, Worker,
Ryan Olson's avatar
Ryan Olson committed
15
16
};

17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
// Generates the per-model-type `add` / `list` / `remove` subcommand enums from one table, plus
// helpers that map a parsed subcommand back to its `ModelType`.
//
// Each row is `(Variant, "primary-name", ["alias", ...], "help text for add")`. `Variant` must
// name a `ModelType` variant; it is also used as the variant name of every generated enum.
macro_rules! define_type_subcommands {
    ($(($kind:ident, $name:expr, [$($alt:expr),*], $add_help:expr)),* $(,)?) => {
        #[derive(Subcommand)]
        enum AddCommands {
            $(
                #[doc = $add_help]
                #[command(name = $name, aliases = [$($alt),*])]
                $kind(AddModelArgs),
            )*
        }

        impl AddCommands {
            /// Splits the parsed subcommand into (model type, model name, endpoint name).
            fn into_parts(self) -> (ModelType, String, String) {
                match self {
                    $(
                        Self::$kind(AddModelArgs { model_name, endpoint_name }) => {
                            (ModelType::$kind, model_name, endpoint_name)
                        }
                    )*
                }
            }
        }

        #[derive(Subcommand)]
        enum ListCommands {
            $(
                #[doc = concat!("List ", $name, " models")]
                #[command(name = $name, aliases = [$($alt),*])]
                $kind,
            )*
        }

        impl ListCommands {
            /// The `ModelType` this list subcommand filters on.
            fn model_type(&self) -> ModelType {
                match self {
                    $(Self::$kind => ModelType::$kind,)*
                }
            }
        }

        #[derive(Subcommand)]
        enum RemoveCommands {
            $(
                #[doc = concat!("Remove ", $name, " model")]
                #[command(name = $name, aliases = [$($alt),*])]
                $kind(RemoveModelArgs),
            )*
        }

        impl RemoveCommands {
            /// Splits the parsed subcommand into (model type, model name).
            fn into_parts(self) -> (ModelType, String) {
                match self {
                    $(Self::$kind(removal) => (ModelType::$kind, removal.model_name),)*
                }
            }
        }
    }
}

// One row per supported model type:
// (ModelType variant, subcommand name, subcommand aliases, help text for `add`).
// To support a new model type, append a row at the end of this list.
define_type_subcommands!(
    (Chat, "chat", ["chat-model", "chat-models"], "Add a chat model"),
    (Completion, "completion", ["completions", "completion-model"], "Add a completion model"),
    (Embedding, "embedding", ["embeddings", "embedding-model"], "Add an embedding model"),
);

Ryan Olson's avatar
Ryan Olson committed
95
#[derive(Parser)]
96
97
98
#[command(
    author="NVIDIA",
    version="0.2.1",
99
    about="LLMCTL - Deprecated. Do not use.",
100
101
102
    long_about = None,
    disable_help_subcommand = true,
)]
Ryan Olson's avatar
Ryan Olson committed
103
struct Cli {
104
    /// Public Namespace to operate in
105
    /// Do not use this. In fact don't use anything about this file.
Ryan Olson's avatar
Ryan Olson committed
106
    #[arg(short = 'n', long)]
107
    public_namespace: Option<String>,
Ryan Olson's avatar
Ryan Olson committed
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123

    #[command(subcommand)]
    command: Commands,
}

// Top-level command groups. Only `http` exists today; its operations are in `HttpCommands`.
// The `///` comment on the variant is clap help text.
#[derive(Subcommand)]
enum Commands {
    /// HTTP service related commands
    Http {
        #[command(subcommand)]
        command: HttpCommands,
    },
}

#[derive(Subcommand)]
enum HttpCommands {
124
    /// Add models
Ryan Olson's avatar
Ryan Olson committed
125
    Add {
126
127
        #[command(subcommand)]
        model_type: AddCommands,
Ryan Olson's avatar
Ryan Olson committed
128
129
    },

130
    /// List models (all types if no specific type provided)
Ryan Olson's avatar
Ryan Olson committed
131
    List {
132
133
        #[command(subcommand)]
        model_type: Option<ListCommands>,
Ryan Olson's avatar
Ryan Olson committed
134
135
    },

136
    /// Remove models
Ryan Olson's avatar
Ryan Olson committed
137
    Remove {
138
139
        #[command(subcommand)]
        model_type: RemoveCommands,
Ryan Olson's avatar
Ryan Olson committed
140
141
142
    },
}

143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
// Positional arguments shared by every `add <type>` subcommand (see `AddCommands::into_parts`).
// The `///` field comments are shown by clap as argument help.
#[derive(Parser)]
struct AddModelArgs {
    /// Model name (e.g. foo/v1)
    #[arg(name = "model-name")]
    model_name: String,
    /// Endpoint name (format: component.endpoint or namespace.component.endpoint)
    #[arg(name = "endpoint-name")]
    endpoint_name: String,
}

/// Common fields for removing any model type
// Positional argument shared by every `remove <type>` subcommand; the `///` texts are clap help.
#[derive(Parser)]
struct RemoveModelArgs {
    /// Name of the model to remove
    #[arg(name = "model-name")]
    model_name: String,
}

fn main() -> Result<()> {
    logging::init();
    let cli = Cli::parse();

165
166
    // Default namespace to "dynamo" if not specified
    let namespace = cli.public_namespace.unwrap_or_else(|| "dynamo".to_string());
Ryan Olson's avatar
Ryan Olson committed
167
168
169
170
171
172
173
174
175
176
177
178

    let worker = Worker::from_settings()?;
    worker.execute(|runtime| async move { handle_command(runtime, namespace, cli.command).await })
}

async fn handle_command(runtime: Runtime, namespace: String, command: Commands) -> Result<()> {
    let settings = DistributedConfig::for_cli();
    let distributed = DistributedRuntime::new(runtime, settings).await?;

    match command {
        Commands::Http { command } => {
            match command {
179
180
181
182
183
184
                HttpCommands::Add { model_type } => {
                    let (model_type, model_name, endpoint_name) = model_type.into_parts();
                    add_model(
                        &distributed,
                        namespace.to_string(),
                        model_type,
Ryan Olson's avatar
Ryan Olson committed
185
                        model_name,
186
187
188
                        &endpoint_name,
                    )
                    .await?;
Ryan Olson's avatar
Ryan Olson committed
189
                }
190
191
192
193
194
195
196
197
198
                HttpCommands::List { model_type } => {
                    match model_type {
                        Some(model_type) => {
                            list_models(
                                &distributed,
                                namespace.clone(),
                                Some(model_type.model_type()),
                            )
                            .await?;
Ryan Olson's avatar
Ryan Olson committed
199
                        }
200
201
202
                        None => {
                            // List all model types
                            list_models(&distributed, namespace.clone(), None).await?;
Ryan Olson's avatar
Ryan Olson committed
203
204
205
                        }
                    }
                }
206
                HttpCommands::Remove { model_type } => {
207
208
                    let (model_type, name) = model_type.into_parts();
                    remove_model(&distributed, model_type, &name).await?;
209
210
211
212
213
214
215
216
217
218
219
220
221
222
                }
            }
        }
    }
    Ok(())
}

/// Registers `model_name` as a model of type `model_type`, served at `endpoint_name`.
///
/// `endpoint_name` is `component.endpoint` or `namespace.component.endpoint`; it is resolved
/// against `namespace` by `endpoint_from_name`.
///
/// # Errors
/// Fails if the model name is blank or starts with `/`, if the endpoint name is malformed, or
/// if attaching the model to the endpoint fails.
async fn add_model(
    distributed: &DistributedRuntime,
    namespace: String,
    model_type: ModelType,
    model_name: String,
    endpoint_name: &str,
) -> Result<()> {
    tracing::debug!("Adding model {model_name} with endpoint {endpoint_name}");
    // clap accepts `""` as a positional value, so reject blank names before registering them.
    if model_name.trim().is_empty() {
        anyhow::bail!("Model name cannot be empty");
    }
    if model_name.starts_with('/') {
        anyhow::bail!("Model name '{model_name}' cannot start with a slash");
    }

    let endpoint = endpoint_from_name(distributed, &namespace, endpoint_name)?;

    let mut model = LocalModel::with_name_only(&model_name);
    model.attach(&endpoint, model_type).await?;

    Ok(())
}

// One row of the model listing table printed by `list_models`. Field order is column order,
// and each `rename` sets that column's header text.
#[derive(tabled::Tabled)]
struct ModelRow {
    #[tabled(rename = "MODEL TYPE")]
    model_type: String,
    #[tabled(rename = "MODEL NAME")]
    name: String,
    #[tabled(rename = "NAMESPACE")]
    namespace: String,
    #[tabled(rename = "COMPONENT")]
    component: String,
    #[tabled(rename = "ENDPOINT")]
    endpoint: String,
}

async fn list_models(
    distributed: &DistributedRuntime,
    namespace: String,
    model_type: Option<ModelType>,
) -> Result<()> {
255
256
257
258
259
260
    // We only need a ModelWatcher to call it's all_entries. llmctl is going away so no need to
    // refactor for this.
    let watcher = ModelWatcher::new(
        distributed.clone(),
        Arc::new(ModelManager::new()),
        RouterMode::Random,
261
        None,
262
    );
263
264

    let mut models = Vec::new();
265
266
267
268
269
270
271
272
273
274
275
    for entry in watcher.all_entries().await? {
        match (model_type, entry.model_type) {
            (None, _) => {
                // list all
            }
            (Some(want), got) if want == got => {
                // match
            }
            _ => {
                // no match
                continue;
Ryan Olson's avatar
Ryan Olson committed
276
277
            }
        }
278
279
280
281
282
283
284
        models.push(ModelRow {
            model_type: entry.model_type.as_str().to_string(),
            name: entry.name,
            namespace: entry.endpoint.namespace,
            component: entry.endpoint.component,
            endpoint: entry.endpoint.name,
        });
Ryan Olson's avatar
Ryan Olson committed
285
286
    }

287
288
289
    if models.is_empty() {
        match &model_type {
            Some(mt) => println!(
290
                "No {} models found in namespace: {}",
291
292
293
                mt.as_str(),
                namespace
            ),
294
            None => println!("No models found in namespace: {}", namespace),
295
296
297
298
        }
    } else {
        let table = tabled::Table::new(models);
        match &model_type {
299
300
            Some(mt) => println!("Listing {} models in namespace: {}", mt.as_str(), namespace),
            None => println!("Listing all models in namespace: {}", namespace),
301
302
303
304
305
306
        }
        println!("{}", table);
    }
    Ok(())
}

307
308
309
310
311
/// Deletes the etcd keys of every registered instance of `model_name` whose type is
/// `model_type`.
///
/// Normally these keys are cleaned up by the etcd lease system; this removes them by hand.
///
/// # Errors
/// Fails if the runtime has no etcd client (i.e. not running with dynamic workers), or if the
/// lookup or any delete fails.
async fn remove_model(
    distributed: &DistributedRuntime,
    model_type: ModelType,
    model_name: &str,
) -> Result<()> {
    // We have to do this manually because normally the etcd lease system does it for us
    let watcher = ModelWatcher::new(
        distributed.clone(),
        Arc::new(ModelManager::new()),
        RouterMode::Random,
        None,
    );
    let Some(etcd_client) = distributed.etcd_client() else {
        anyhow::bail!("llmctl is only useful with dynamic workers");
    };

    // NOTE(review): entries are looked up by model name only; presumably this spans every
    // namespace (the CLI namespace is not passed in) — confirm that is intended.
    let matching = watcher
        .entries_for_model(model_name)
        .await?
        .into_iter()
        .filter(|entry| entry.model_type == model_type);

    for entry in matching {
        // NOTE(review): lease id 0 is used to rebuild the key. If the stored key embeds the
        // registering worker's lease id, this may not match it — confirm against
        // `ModelNetworkName::from_entry`.
        let network_name = ModelNetworkName::from_entry(&entry, 0);
        tracing::debug!("deleting key: {network_name}");
        etcd_client
            .kv_delete(network_name.to_string(), None)
            .await?;
    }

    Ok(())
}
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359

/// Resolves a user-supplied dotted endpoint path to an [`Endpoint`].
///
/// Accepted forms:
/// - `component.endpoint`: resolved inside `namespace`.
/// - `namespace.component.endpoint`: the explicit namespace segment is used instead of the
///   `namespace` argument. Previously this segment was silently ignored.
///
/// # Errors
/// Fails if the name has fewer than two or more than three dot-separated segments, if any
/// segment is empty, or if the runtime rejects the namespace or component.
fn endpoint_from_name(
    distributed: &DistributedRuntime,
    namespace: &str,
    endpoint_name: &str,
) -> anyhow::Result<Endpoint> {
    let parts: Vec<&str> = endpoint_name.split('.').collect();

    // `split` always yields at least one item, so `[_]` covers every too-short input.
    let (namespace, component_name, leaf_name) = match parts[..] {
        [component, endpoint] => (namespace, component, endpoint),
        [ns, component, endpoint] => (ns, component, endpoint),
        [_] => anyhow::bail!("Endpoint name '{}' is too short. Format should be 'component.endpoint' or 'namespace.component.endpoint'", endpoint_name),
        _ => anyhow::bail!("Endpoint name '{}' is too long. Format should be 'component.endpoint' or 'namespace.component.endpoint'", endpoint_name),
    };

    // Reject inputs like "comp." or "ns..ep", which would otherwise yield empty names.
    if parts.iter().any(|segment| segment.is_empty()) {
        anyhow::bail!("Endpoint name '{}' contains an empty segment. Format should be 'component.endpoint' or 'namespace.component.endpoint'", endpoint_name);
    }

    // TODO previous version sometime hardcoded this to "http", so maybe adjust
    let component = distributed
        .namespace(namespace)?
        .component(component_name.to_string())?;

    Ok(component.endpoint(leaf_name.to_string()))
}