Unverified Commit aeb79e62 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

feat: Support multiple models on single ingress node (#1127)

We can now do this:

- Node 1:

```
dynamo-run in=http out=dyn
```

- Node 2 and 3, two instances of component 'backend' in the nemotron_ultra pipeline:

```
dynamo-run in=dyn://nemotron_ultra.backend.generate out=vllm /data/models/NemotronUltra
```

- Node 4 and 5, two instances of the 'backend' component in nemotron_super pipeline:

```
dynamo-run in=dyn://nemotron_super.backend.generate out=vllm /data/models/NemotronSuper
```

The ingress node will discover all four instances and route correctly. We have been planning for this for a long time now.

As part of this, auto-discovery is now always `out=dyn`, with no extra URL parts. Previously it could only route to a single pipeline.

Also:
- Refactor endpoint / instance naming now that I understand them
- Fix removing models when their instance stops.
parent 74221fd7
...@@ -4,13 +4,10 @@ ...@@ -4,13 +4,10 @@
use clap::Parser; use clap::Parser;
use std::sync::Arc; use std::sync::Arc;
use dynamo_llm::{ use dynamo_llm::http::service::{discovery::ModelWatcher, service_v2::HttpService};
http::service::{discovery::ModelWatcher, service_v2::HttpService},
model_type::ModelType,
};
use dynamo_runtime::{ use dynamo_runtime::{
logging, pipeline::RouterMode, transports::etcd::PrefixWatcher, DistributedRuntime, Result, component, logging, pipeline::RouterMode, transports::etcd::PrefixWatcher, DistributedRuntime,
Runtime, Worker, Result, Runtime, Worker,
}; };
#[derive(Parser)] #[derive(Parser)]
...@@ -58,36 +55,19 @@ async fn app(runtime: Runtime) -> Result<()> { ...@@ -58,36 +55,19 @@ async fn app(runtime: Runtime) -> Result<()> {
// written to etcd // written to etcd
// the cli when operating on an `http` component will validate the namespace.component is // the cli when operating on an `http` component will validate the namespace.component is
// registered with HttpServiceComponentDefinition // registered with HttpServiceComponentDefinition
let component = distributed
.namespace(&args.namespace)?
.component(&args.component)?;
let etcd_root = component.etcd_path();
// TODO: A single watcher already watches all model types and does the right thing.
// The paths need change here and in llmctl to not include the model_type
// Create watchers for `Chat`, `Completion`, and `Embedding` model types
for model_type in [ModelType::Chat, ModelType::Completion, ModelType::Embedding] {
let etcd_path = format!("{}/models/{}/", etcd_root, model_type.as_str());
let watch_obj = Arc::new( let watch_obj = Arc::new(
ModelWatcher::new( ModelWatcher::new(distributed.clone(), manager.clone(), RouterMode::Random).await?,
component.clone(),
manager.clone(),
&etcd_path,
RouterMode::Random,
)
.await?,
); );
if let Some(etcd_client) = distributed.etcd_client() { if let Some(etcd_client) = distributed.etcd_client() {
let models_watcher: PrefixWatcher = let models_watcher: PrefixWatcher = etcd_client
etcd_client.kv_get_and_watch_prefix(etcd_path).await?; .kv_get_and_watch_prefix(component::MODEL_ROOT_PATH)
.await?;
let (_prefix, _watcher, receiver) = models_watcher.dissolve(); let (_prefix, _watcher, receiver) = models_watcher.dissolve();
tokio::spawn(watch_obj.watch(receiver)); tokio::spawn(watch_obj.watch(receiver));
} }
}
// Run the service // Run the service
http_service.run(runtime.child_token()).await http_service.run(runtime.child_token()).await
......
...@@ -123,7 +123,7 @@ async fn app(runtime: Runtime) -> Result<()> { ...@@ -123,7 +123,7 @@ async fn app(runtime: Runtime) -> Result<()> {
let component = namespace.component("count")?; let component = namespace.component("count")?;
// Create unique instance of Count // Create unique instance of Count
let key = format!("{}/instance", component.etcd_path()); let key = format!("{}/instance", component.etcd_root());
tracing::debug!("Creating unique instance of Count at {key}"); tracing::debug!("Creating unique instance of Count at {key}");
drt.etcd_client() drt.etcd_client()
.expect("Unreachable because of DistributedRuntime::from_settings above") .expect("Unreachable because of DistributedRuntime::from_settings above")
......
...@@ -158,7 +158,7 @@ class LocalConnector(PlannerConnector): ...@@ -158,7 +158,7 @@ class LocalConnector(PlannerConnector):
# We add a custom component name to ensure that the lease is attached to this specific watcher # We add a custom component name to ensure that the lease is attached to this specific watcher
full_cmd = f"{base_cmd} --worker-env '{worker_env_arg}' --custom-component-name '{watcher_name}'" full_cmd = f"{base_cmd} --worker-env '{worker_env_arg}' --custom-component-name '{watcher_name}'"
pre_add_endpoint_ids = await self._get_endpoint_ids(component_name) pre_add_endpoint_ids = await self._count_instance_ids(component_name)
logger.info(f"Pre-add endpoint IDs: {pre_add_endpoint_ids}") logger.info(f"Pre-add endpoint IDs: {pre_add_endpoint_ids}")
logger.info(f"Adding watcher {watcher_name}") logger.info(f"Adding watcher {watcher_name}")
...@@ -184,7 +184,7 @@ class LocalConnector(PlannerConnector): ...@@ -184,7 +184,7 @@ class LocalConnector(PlannerConnector):
if blocking: if blocking:
required_endpoint_ids = pre_add_endpoint_ids + 1 required_endpoint_ids = pre_add_endpoint_ids + 1
while True: while True:
current_endpoint_ids = await self._get_endpoint_ids(component_name) current_endpoint_ids = await self._count_instance_ids(component_name)
if current_endpoint_ids == required_endpoint_ids: if current_endpoint_ids == required_endpoint_ids:
break break
logger.info( logger.info(
...@@ -248,9 +248,9 @@ class LocalConnector(PlannerConnector): ...@@ -248,9 +248,9 @@ class LocalConnector(PlannerConnector):
return success return success
async def _get_endpoint_ids(self, component_name: str) -> int: async def _count_instance_ids(self, component_name: str) -> int:
""" """
Get the endpoint IDs for a component. Count the instance IDs for the 'generate' endpoint of given component.
Args: Args:
component_name: Name of the component component_name: Name of the component
...@@ -266,7 +266,7 @@ class LocalConnector(PlannerConnector): ...@@ -266,7 +266,7 @@ class LocalConnector(PlannerConnector):
.endpoint("generate") .endpoint("generate")
.client() .client()
) )
worker_ids = self.worker_client.endpoint_ids() worker_ids = self.worker_client.instance_ids()
return len(worker_ids) return len(worker_ids)
elif component_name == "PrefillWorker": elif component_name == "PrefillWorker":
if self.prefill_client is None: if self.prefill_client is None:
...@@ -276,7 +276,7 @@ class LocalConnector(PlannerConnector): ...@@ -276,7 +276,7 @@ class LocalConnector(PlannerConnector):
.endpoint("mock") .endpoint("mock")
.client() .client()
) )
prefill_ids = self.prefill_client.endpoint_ids() prefill_ids = self.prefill_client.instance_ids()
return len(prefill_ids) return len(prefill_ids)
else: else:
raise ValueError(f"Component {component_name} not supported") raise ValueError(f"Component {component_name} not supported")
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
* [Automatically download a model from Hugging Face](#use-model-from-hugging-face) * [Automatically download a model from Hugging Face](#use-model-from-hugging-face)
* [Run a model from local file](#run-a-model-from-local-file) * [Run a model from local file](#run-a-model-from-local-file)
* [Distributed system](#distributed-system) * [Distributed system](#distributed-system)
* [Network names](#network-names)
* [KV-aware routing](#kv-aware-routing) * [KV-aware routing](#kv-aware-routing)
* [Full usage details](#full-usage-details) * [Full usage details](#full-usage-details)
* [Setup](#setup) * [Setup](#setup)
...@@ -24,7 +25,7 @@ It supports the following engines: mistralrs, llamacpp, sglang, vllm and tensorr ...@@ -24,7 +25,7 @@ It supports the following engines: mistralrs, llamacpp, sglang, vllm and tensorr
Usage: Usage:
``` ```
dynamo-run in=[http|text|dyn://<path>|batch:<folder>] out=echo_core|echo_full|mistralrs|llamacpp|sglang|vllm|dyn://<path> [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--model-config <hf-repo>] [--tensor-parallel-size=1] [--base-gpu-id=0] [--extra-engine-args=args.json] [--router-mode random|round-robin|kv] dynamo-run in=[http|text|dyn://<path>|batch:<folder>] out=echo_core|echo_full|mistralrs|llamacpp|sglang|vllm|dyn [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--model-config <hf-repo>] [--tensor-parallel-size=1] [--base-gpu-id=0] [--extra-engine-args=args.json] [--router-mode random|round-robin|kv]
``` ```
Example: `dynamo run Qwen/Qwen3-0.6B` Example: `dynamo run Qwen/Qwen3-0.6B`
...@@ -95,7 +96,7 @@ You will need [etcd](https://etcd.io/) and [nats](https://nats.io) with jetstrea ...@@ -95,7 +96,7 @@ You will need [etcd](https://etcd.io/) and [nats](https://nats.io) with jetstrea
OpenAI compliant HTTP server, optional pre-processing, worker discovery. OpenAI compliant HTTP server, optional pre-processing, worker discovery.
``` ```
dynamo-run in=http out=dyn://llama3B_pool dynamo-run in=http out=dyn
``` ```
**Node 2:** **Node 2:**
...@@ -103,17 +104,50 @@ dynamo-run in=http out=dyn://llama3B_pool ...@@ -103,17 +104,50 @@ dynamo-run in=http out=dyn://llama3B_pool
Vllm engine. Receives and returns requests over the network. Vllm engine. Receives and returns requests over the network.
``` ```
dynamo-run in=dyn://llama3B_pool out=vllm ~/llms/Llama-3.2-3B-Instruct dynamo-run in=dyn://llama3B.backend.generate out=vllm ~/llms/Llama-3.2-3B-Instruct
``` ```
This will use etcd to auto-discover the model and NATS to talk to it. You can This will use etcd to auto-discover the model and NATS to talk to it. You can
run multiple workers on the same endpoint and it will pick one based on the run multiple instances on the same endpoint and it will pick one based on the
`--router-mode` (round-robin by default if left unspecified). `--router-mode` (round-robin by default if left unspecified).
The `llama3B_pool` name is purely symbolic, pick anything as long as it matches the other node.
Run `dynamo-run --help` for more options. Run `dynamo-run --help` for more options.
### Network names
The `in=dyn://` URLs have the format `dyn://namespace.component.endpoint`. For a quick start, just use any string, e.g. `dyn://test`; `dynamo-run` will default any missing parts for you. The pieces matter for a larger system.
* *Namespace*: A pipeline. Usually a model, e.g. "llama_8b". Just a name.
* *Component*: A load balanced service needed to run that pipeline. "backend", "prefill", "decode", "preprocessor", "draft", etc. This typically has some configuration (which model to use, for example).
* *Endpoint*: Like a URL. "generate", "load_metrics".
* *Instance*: A process. Unique. Dynamo assigns each one a unique instance_id. The thing that is running is always an instance. Namespace/component/endpoint can refer to multiple instances.
If you run two models, that is two pipelines. An exception is speculative decoding: the draft model is part of the pipeline of the bigger model.
If you run two instances of the same model ("data parallel") they are the same namespace+component+endpoint but different instances. The router will spread traffic over all the instances of a namespace+component+endpoint. If you have four prefill workers in a pipeline, they all have the same namespace+component+endpoint and are automatically assigned unique instance_ids.
Example 1: Data parallel load balanced, one model one pipeline two instances.
```
Node 1: dynamo-run in=dyn://qwen3-32b.backend.generate out=sglang /data/Qwen3-32B --tensor-parallel-size 2 --base-gpu-id 0
Node 2: dynamo-run in=dyn://qwen3-32b.backend.generate out=sglang /data/Qwen3-32B --tensor-parallel-size 2 --base-gpu-id 2
```
Example 2: Two models, two pipelines.
```
Node 1: dynamo-run in=dyn://qwen3-32b.backend.generate out=vllm /data/Qwen3-32B
Node 2: dynamo-run in=dyn://llama3-1-8b.backend.generate out=vllm /data/Llama-3.1-8B-Instruct/
```
Example 3: Different endpoints.
The KV metrics publisher in VLLM adds a `load_metrics` endpoint to the current component. If the `llama3-1-8b.backend` component above is using patched vllm it will also expose `llama3-1-8b.backend.load_metrics`.
Example 4: Multiple components in a pipeline.
In the P/D disaggregated setup you would have `deepseek-distill-llama8b.prefill.generate` (possibly multiple instances of this) and `deepseek-distill-llama8b.decode.generate`.
For output it is always only `out=dyn`. This tells Dynamo to auto-discover the instances, group them by model, and load balance appropriately (depending on `--router-mode` flag). The old syntax of `dyn://...` is still accepted for backwards compatibility.
### KV-aware routing ### KV-aware routing
**Setup** **Setup**
...@@ -144,7 +178,7 @@ dynamo-run in=dyn://dynamo.endpoint.generate out=vllm /data/llms/Qwen/Qwen3-4B ...@@ -144,7 +178,7 @@ dynamo-run in=dyn://dynamo.endpoint.generate out=vllm /data/llms/Qwen/Qwen3-4B
**Start the ingress node** **Start the ingress node**
``` ```
dynamo-run in=http out=dyn://dynamo.endpoint.generate --router-mode kv dynamo-run in=http out=dyn --router-mode kv
``` ```
The only difference from the distributed system above is `--router-mode kv`. The patched vllm will announce when a KV block is created or removed. The Dynamo router run will find the worker with the best match for those KV blocks and direct the traffic to that node. The only difference from the distributed system above is `--router-mode kv`. The patched vllm will announce when a KV block is created or removed. The Dynamo router run will find the worker with the best match for those KV blocks and direct the traffic to that node.
......
...@@ -63,7 +63,7 @@ class Router: ...@@ -63,7 +63,7 @@ class Router:
print("KV Router initialized") print("KV Router initialized")
def _cost_function(self, request_prompt): def _cost_function(self, request_prompt):
worker_ids = self.workers_client.endpoint_ids() worker_ids = self.workers_client.instance_ids()
num_workers = len(worker_ids) num_workers = len(worker_ids)
max_hit_rate = -1.0 max_hit_rate = -1.0
for curr_id in self.kv_cache.keys(): for curr_id in self.kv_cache.keys():
......
...@@ -275,7 +275,7 @@ async def check_required_workers( ...@@ -275,7 +275,7 @@ async def check_required_workers(
tag="", tag="",
): ):
"""Wait until the minimum number of workers are ready.""" """Wait until the minimum number of workers are ready."""
worker_ids = workers_client.endpoint_ids() worker_ids = workers_client.instance_ids()
num_workers = len(worker_ids) num_workers = len(worker_ids)
new_count = -1 # Force to print "waiting for worker" once new_count = -1 # Force to print "waiting for worker" once
while num_workers < required_workers: while num_workers < required_workers:
...@@ -287,7 +287,7 @@ async def check_required_workers( ...@@ -287,7 +287,7 @@ async def check_required_workers(
f" Required: {required_workers}" f" Required: {required_workers}"
) )
await asyncio.sleep(poll_interval) await asyncio.sleep(poll_interval)
worker_ids = workers_client.endpoint_ids() worker_ids = workers_client.instance_ids()
new_count = len(worker_ids) new_count = len(worker_ids)
print(f"Workers ready: {worker_ids}") print(f"Workers ready: {worker_ids}")
......
...@@ -170,7 +170,7 @@ class Router: ...@@ -170,7 +170,7 @@ class Router:
# Get all worker IDs from the client. This is needed because scores / metrics may not have values for all workers # Get all worker IDs from the client. This is needed because scores / metrics may not have values for all workers
# and we want all workers to be considered in the logit calculation # and we want all workers to be considered in the logit calculation
worker_ids = self.workers_client.endpoint_ids() worker_ids = self.workers_client.instance_ids()
worker_logits = {} worker_logits = {}
for worker_id in worker_ids: for worker_id in worker_ids:
......
...@@ -99,8 +99,8 @@ class Planner: ...@@ -99,8 +99,8 @@ class Planner:
) )
# TODO: remove this sleep after rust client() is blocking until watching state # TODO: remove this sleep after rust client() is blocking until watching state
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
# TODO: use etcd events instead of pulling endpoints_ids # TODO: use etcd events instead of pulling instance_ids
p_endpoints = self.prefill_client.endpoint_ids() p_endpoints = self.prefill_client.instance_ids()
except Exception: except Exception:
p_endpoints = [] p_endpoints = []
self._repeating_log_func( self._repeating_log_func(
...@@ -116,8 +116,8 @@ class Planner: ...@@ -116,8 +116,8 @@ class Planner:
) )
# TODO: remove this sleep after rust client() is blocking until watching state # TODO: remove this sleep after rust client() is blocking until watching state
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
# TODO: use etcd events instead of pulling endpoints_ids # TODO: use etcd events instead of pulling instance_ids
d_endpoints = self.workers_client.endpoint_ids() d_endpoints = self.workers_client.instance_ids()
except Exception as e: except Exception as e:
raise RuntimeError(f"Failed to get decode worker endpoints: {e}") raise RuntimeError(f"Failed to get decode worker endpoints: {e}")
return p_endpoints, d_endpoints return p_endpoints, d_endpoints
......
...@@ -25,12 +25,12 @@ async def check_required_workers( ...@@ -25,12 +25,12 @@ async def check_required_workers(
workers_client: Client, required_workers: int, on_change=True, poll_interval=0.5 workers_client: Client, required_workers: int, on_change=True, poll_interval=0.5
): ):
"""Wait until the minimum number of workers are ready.""" """Wait until the minimum number of workers are ready."""
worker_ids = workers_client.endpoint_ids() worker_ids = workers_client.instance_ids()
num_workers = len(worker_ids) num_workers = len(worker_ids)
while num_workers < required_workers: while num_workers < required_workers:
await asyncio.sleep(poll_interval) await asyncio.sleep(poll_interval)
worker_ids = workers_client.endpoint_ids() worker_ids = workers_client.instance_ids()
new_count = len(worker_ids) new_count = len(worker_ids)
if (not on_change) or new_count != num_workers: if (not on_change) or new_count != num_workers:
......
...@@ -25,12 +25,12 @@ async def check_required_workers( ...@@ -25,12 +25,12 @@ async def check_required_workers(
workers_client: Client, required_workers: int, on_change=True, poll_interval=0.5 workers_client: Client, required_workers: int, on_change=True, poll_interval=0.5
): ):
"""Wait until the minimum number of workers are ready.""" """Wait until the minimum number of workers are ready."""
worker_ids = workers_client.endpoint_ids() worker_ids = workers_client.instance_ids()
num_workers = len(worker_ids) num_workers = len(worker_ids)
while num_workers < required_workers: while num_workers < required_workers:
await asyncio.sleep(poll_interval) await asyncio.sleep(poll_interval)
worker_ids = workers_client.endpoint_ids() worker_ids = workers_client.instance_ids()
new_count = len(worker_ids) new_count = len(worker_ids)
if (not on_change) or new_count != num_workers: if (not on_change) or new_count != num_workers:
......
...@@ -90,10 +90,10 @@ class Router: ...@@ -90,10 +90,10 @@ class Router:
.endpoint("generate") .endpoint("generate")
.client() .client()
) )
while len(self.workers_client.endpoint_ids()) < self.args.min_workers: while len(self.workers_client.instance_ids()) < self.args.min_workers:
logger.info( logger.info(
f"Waiting for more workers to be ready.\n" f"Waiting for more workers to be ready.\n"
f" Current: {len(self.workers_client.endpoint_ids())}," f" Current: {len(self.workers_client.instance_ids())},"
f" Required: {self.args.min_workers}" f" Required: {self.args.min_workers}"
) )
await asyncio.sleep(30) await asyncio.sleep(30)
...@@ -144,7 +144,7 @@ class Router: ...@@ -144,7 +144,7 @@ class Router:
# Get all worker IDs from the client. This is needed because scores / metrics may not have values for all workers # Get all worker IDs from the client. This is needed because scores / metrics may not have values for all workers
# and we want all workers to be considered in the logit calculation # and we want all workers to be considered in the logit calculation
worker_ids = self.workers_client.endpoint_ids() worker_ids = self.workers_client.instance_ids()
worker_logits = {} worker_logits = {}
for worker_id in worker_ids: for worker_id in worker_ids:
......
...@@ -78,10 +78,10 @@ class Processor(ChatProcessorMixin): ...@@ -78,10 +78,10 @@ class Processor(ChatProcessorMixin):
.client() .client()
) )
while len(self.worker_client.endpoint_ids()) < self.min_workers: while len(self.worker_client.instance_ids()) < self.min_workers:
logger.info( logger.info(
f"Waiting for workers to be ready.\n" f"Waiting for workers to be ready.\n"
f" Current: {len(self.worker_client.endpoint_ids())}," f" Current: {len(self.worker_client.instance_ids())},"
f" Required: {self.min_workers}" f" Required: {self.min_workers}"
) )
await asyncio.sleep(30) await asyncio.sleep(30)
......
...@@ -71,10 +71,10 @@ class TensorRTLLMWorker(BaseTensorrtLLMEngine): ...@@ -71,10 +71,10 @@ class TensorRTLLMWorker(BaseTensorrtLLMEngine):
.endpoint("generate") .endpoint("generate")
.client() .client()
) )
while len(self._prefill_client.endpoint_ids()) < self._min_prefill_workers: while len(self._prefill_client.instance_ids()) < self._min_prefill_workers:
logger.info( logger.info(
f"Waiting for prefill workers to be ready.\n" f"Waiting for prefill workers to be ready.\n"
f" Current: {len(self._prefill_client.endpoint_ids())}," f" Current: {len(self._prefill_client.instance_ids())},"
f" Required: {self._min_prefill_workers}" f" Required: {self._min_prefill_workers}"
) )
await asyncio.sleep(30) await asyncio.sleep(30)
......
...@@ -98,7 +98,7 @@ pub struct Flags { ...@@ -98,7 +98,7 @@ pub struct Flags {
#[arg(long)] #[arg(long)]
pub leader_addr: Option<String>, pub leader_addr: Option<String>,
/// If using `out=dyn://..` with multiple backends, this says how to route the requests. /// If using `out=dyn` with multiple instances, this says how to route the requests.
/// ///
/// Mostly interesting for KV-aware routing. /// Mostly interesting for KV-aware routing.
/// Defaults to RouterMode::RoundRobin /// Defaults to RouterMode::RoundRobin
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use anyhow::Context as _; use anyhow::Context as _;
use async_openai::types::FinishReason; use async_openai::types::FinishReason;
...@@ -65,7 +53,7 @@ struct Entry { ...@@ -65,7 +53,7 @@ struct Entry {
pub async fn run( pub async fn run(
runtime: Runtime, runtime: Runtime,
flags: Flags, _flags: Flags,
card: ModelDeploymentCard, card: ModelDeploymentCard,
input_jsonl: PathBuf, input_jsonl: PathBuf,
engine_config: EngineConfig, engine_config: EngineConfig,
...@@ -80,7 +68,7 @@ pub async fn run( ...@@ -80,7 +68,7 @@ pub async fn run(
); );
} }
let prepared_engine = common::prepare_engine(runtime, flags, engine_config).await?; let prepared_engine = common::prepare_engine(runtime, engine_config).await?;
let service_name_ref = Arc::new(prepared_engine.service_name); let service_name_ref = Arc::new(prepared_engine.service_name);
let pre_processor = if card.has_tokenizer() { let pre_processor = if card.has_tokenizer() {
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::pin::Pin; use std::pin::Pin;
use dynamo_llm::{ use dynamo_llm::{
backend::{Backend, ExecutionContext}, backend::{Backend, ExecutionContext},
engines::StreamingEngineAdapter, engines::StreamingEngineAdapter,
http::service::discovery::ModelNetworkName, http::service::{discovery::ModelWatcher, ModelManager},
kv_router::{scheduler::DefaultWorkerSelector, KvPushRouter, KvRouter},
model_card::ModelDeploymentCard, model_card::ModelDeploymentCard,
model_type::ModelType,
preprocessor::OpenAIPreprocessor, preprocessor::OpenAIPreprocessor,
protocols::common::llm_backend::{BackendInput, BackendOutput, LLMEngineOutput}, protocols::common::llm_backend::{BackendInput, BackendOutput},
types::{ types::{
openai::chat_completions::{ openai::chat_completions::{
NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse, NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
...@@ -33,139 +19,63 @@ use dynamo_llm::{ ...@@ -33,139 +19,63 @@ use dynamo_llm::{
}, },
}; };
use dynamo_runtime::{ use dynamo_runtime::{
component,
engine::{AsyncEngineStream, Data}, engine::{AsyncEngineStream, Data},
pipeline::{ pipeline::{Context, ManyOut, Operator, ServiceBackend, ServiceFrontend, SingleIn, Source},
Context, ManyOut, Operator, PushRouter, SegmentSource, ServiceBackend, ServiceFrontend,
SingleIn, Source,
},
DistributedRuntime, Runtime, DistributedRuntime, Runtime,
}; };
use std::sync::Arc; use std::sync::Arc;
use crate::{flags::RouterMode, EngineConfig, Flags}; use crate::EngineConfig;
pub struct PreparedEngine { pub struct PreparedEngine {
pub service_name: String, pub service_name: String,
pub engine: OpenAIChatCompletionsStreamingEngine, pub engine: OpenAIChatCompletionsStreamingEngine,
pub inspect_template: bool, pub inspect_template: bool,
pub _cache_dir: Option<tempfile::TempDir>,
} }
/// Turns an EngineConfig into an OpenAI chat-completions and completions supported StreamingEngine. /// Turns an EngineConfig into an OpenAI chat-completions and completions supported StreamingEngine.
pub async fn prepare_engine( pub async fn prepare_engine(
runtime: Runtime, runtime: Runtime,
flags: Flags,
engine_config: EngineConfig, engine_config: EngineConfig,
) -> anyhow::Result<PreparedEngine> { ) -> anyhow::Result<PreparedEngine> {
match engine_config { match engine_config {
EngineConfig::Dynamic(endpoint_id) => { EngineConfig::Dynamic => {
let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?; let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;
let endpoint = distributed_runtime
.namespace(endpoint_id.namespace.clone())?
.component(endpoint_id.component.clone())?
.endpoint(endpoint_id.name.clone());
let client = endpoint.client().await?;
let mut cache_dir = None;
tracing::info!("Waiting for remote model..");
let remote_endpoints = client.wait_for_endpoints().await?;
debug_assert!(!remote_endpoints.is_empty());
tracing::info!(count = remote_endpoints.len(), "Model(s) discovered");
let network_name: ModelNetworkName = (&remote_endpoints[0]).into();
let Some(etcd_client) = distributed_runtime.etcd_client() else { let Some(etcd_client) = distributed_runtime.etcd_client() else {
anyhow::bail!("Cannot run distributed components without etcd"); anyhow::bail!("Cannot be both static mode and run with dynamic discovery.");
}; };
let network_entry = network_name.load_entry(&etcd_client).await?; let model_manager = ModelManager::new();
let mut card = network_entry.load_mdc(endpoint_id, &etcd_client).await?; let watch_obj = Arc::new(
ModelWatcher::new(
let engine: OpenAIChatCompletionsStreamingEngine = match network_entry.model_type { distributed_runtime,
ModelType::Backend => { model_manager.clone(),
// Download tokenizer.json etc to local disk dynamo_runtime::pipeline::RouterMode::RoundRobin,
cache_dir = Some(
card.move_from_nats(distributed_runtime.nats_client())
.await?,
);
// The backend doesn't mind what we expose to the user (chat or
// completions), and this function is only used by text and batch input so
// the user doesn't see the HTTP request. So use Chat.
let frontend = SegmentSource::<
SingleIn<NvCreateChatCompletionRequest>,
ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
>::new();
let preprocessor = OpenAIPreprocessor::new(card.clone()).await?.into_operator();
let backend = Backend::from_mdc(card.clone()).await?.into_operator();
let router =
PushRouter::<BackendInput, Annotated<LLMEngineOutput>>::from_client(
client,
flags.router_mode.into(),
) )
.await?;
let service_backend = match &flags.router_mode {
RouterMode::Random | RouterMode::RoundRobin => {
ServiceBackend::from_engine(Arc::new(router))
}
RouterMode::KV => {
let selector = Box::new(DefaultWorkerSelector {});
let chooser = KvRouter::new(
endpoint.component().clone(),
dynamo_llm::DEFAULT_KV_BLOCK_SIZE,
Some(selector),
)
.await?;
let kv_push_router = KvPushRouter::new(router, Arc::new(chooser));
ServiceBackend::from_engine(Arc::new(kv_push_router))
}
};
frontend
.link(preprocessor.forward_edge())?
.link(backend.forward_edge())?
.link(service_backend)?
.link(backend.backward_edge())?
.link(preprocessor.backward_edge())?
.link(frontend)?
}
ModelType::Chat => Arc::new(
PushRouter::<
NvCreateChatCompletionRequest,
Annotated<NvCreateChatCompletionStreamResponse>,
>::from_client(client, flags.router_mode.into())
.await?, .await?,
),
ModelType::Completion => {
anyhow::bail!(
"text and batch input only accept remote Chat models, not Completion"
);
/*
Arc::new(
PushRouter::<
CompletionRequest,
Annotated<CompletionResponse>,
>::from_client(
client, flags.router_mode.into()
)
.await?,
)
*/
}
ModelType::Embedding => {
anyhow::bail!(
"text and batch input only accept remote Chat models, not Embedding"
); );
} let models_watcher = etcd_client
}; .kv_get_and_watch_prefix(component::MODEL_ROOT_PATH)
// The service_name isn't used for text chat outside of logs, .await?;
// so use the path. That avoids having to listen on etcd for model registration. let (_prefix, _watcher, receiver) = models_watcher.dissolve();
let service_name = endpoint.subject();
let inner_watch_obj = watch_obj.clone();
let _watcher_task = tokio::spawn(inner_watch_obj.watch(receiver));
tracing::info!("Waiting for remote model..");
// TODO: We use the first model to appear, usually we have only one
// We should add slash commands to text input `/model <name>` to choose,
// '/models` to list, and notifications when models are added / removed.
let model_service_name = watch_obj.wait_for_chat_model().await;
let engine = model_manager
.state()
.get_chat_completions_engine(&model_service_name)?;
Ok(PreparedEngine { Ok(PreparedEngine {
service_name, service_name: model_service_name,
engine, engine,
inspect_template: false, inspect_template: false,
_cache_dir: cache_dir,
}) })
} }
EngineConfig::StaticFull { engine, model } => { EngineConfig::StaticFull { engine, model } => {
...@@ -176,7 +86,6 @@ pub async fn prepare_engine( ...@@ -176,7 +86,6 @@ pub async fn prepare_engine(
service_name, service_name,
engine, engine,
inspect_template: false, inspect_template: false,
_cache_dir: None,
}) })
} }
EngineConfig::StaticCore { EngineConfig::StaticCore {
...@@ -195,7 +104,6 @@ pub async fn prepare_engine( ...@@ -195,7 +104,6 @@ pub async fn prepare_engine(
service_name, service_name,
engine: pipeline, engine: pipeline,
inspect_template: true, inspect_template: true,
_cache_dir: None,
}) })
} }
} }
......
...@@ -89,7 +89,7 @@ pub async fn run( ...@@ -89,7 +89,7 @@ pub async fn run(
(Box::pin(fut), Some(model.card().clone())) (Box::pin(fut), Some(model.card().clone()))
} }
EngineConfig::Dynamic(_) => { EngineConfig::Dynamic => {
// We can only get here for `in=dyn out=vllm|sglang`, because vllm and sglang are a // We can only get here for `in=dyn out=vllm|sglang`, because vllm and sglang are a
// subprocess that we talk to like a remote endpoint. // subprocess that we talk to like a remote endpoint.
// That means the vllm/sglang subprocess is doing all the work, we are idle. // That means the vllm/sglang subprocess is doing all the work, we are idle.
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc; use std::sync::Arc;
...@@ -29,7 +17,7 @@ use dynamo_llm::{ ...@@ -29,7 +17,7 @@ use dynamo_llm::{
openai::completions::{CompletionRequest, CompletionResponse}, openai::completions::{CompletionRequest, CompletionResponse},
}, },
}; };
use dynamo_runtime::component::Component; use dynamo_runtime::component;
use dynamo_runtime::pipeline::RouterMode; use dynamo_runtime::pipeline::RouterMode;
use dynamo_runtime::transports::etcd; use dynamo_runtime::transports::etcd;
use dynamo_runtime::{DistributedRuntime, Runtime}; use dynamo_runtime::{DistributedRuntime, Runtime};
...@@ -48,23 +36,16 @@ pub async fn run( ...@@ -48,23 +36,16 @@ pub async fn run(
.with_request_template(template) .with_request_template(template)
.build()?; .build()?;
match engine_config { match engine_config {
EngineConfig::Dynamic(endpoint) => { EngineConfig::Dynamic => {
let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?; let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;
match distributed_runtime.etcd_client() { match distributed_runtime.etcd_client() {
Some(etcd_client) => { Some(etcd_client) => {
// This will attempt to connect to NATS and etcd
let component = distributed_runtime
.namespace(endpoint.namespace)?
.component(endpoint.component)?;
let network_prefix = component.service_name();
// Listen for models registering themselves in etcd, add them to HTTP service // Listen for models registering themselves in etcd, add them to HTTP service
run_watcher( run_watcher(
component.clone(), distributed_runtime,
http_service.model_manager().clone(), http_service.model_manager().clone(),
etcd_client.clone(), etcd_client.clone(),
&network_prefix, component::MODEL_ROOT_PATH,
flags.router_mode.into(), flags.router_mode.into(),
) )
.await?; .await?;
...@@ -117,15 +98,14 @@ pub async fn run( ...@@ -117,15 +98,14 @@ pub async fn run(
/// Spawns a task that watches for new models in etcd at network_prefix, /// Spawns a task that watches for new models in etcd at network_prefix,
/// and registers them with the ModelManager so that the HTTP service can use them. /// and registers them with the ModelManager so that the HTTP service can use them.
async fn run_watcher( async fn run_watcher(
component: Component, runtime: DistributedRuntime,
model_manager: ModelManager, model_manager: ModelManager,
etcd_client: etcd::Client, etcd_client: etcd::Client,
network_prefix: &str, network_prefix: &str,
router_mode: RouterMode, router_mode: RouterMode,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let watch_obj = Arc::new( let watch_obj =
discovery::ModelWatcher::new(component, model_manager, network_prefix, router_mode).await?, Arc::new(discovery::ModelWatcher::new(runtime, model_manager, router_mode).await?);
);
tracing::info!("Watching for remote model at {network_prefix}"); tracing::info!("Watching for remote model at {network_prefix}");
let models_watcher = etcd_client.kv_get_and_watch_prefix(network_prefix).await?; let models_watcher = etcd_client.kv_get_and_watch_prefix(network_prefix).await?;
let (_prefix, _watcher, receiver) = models_watcher.dissolve(); let (_prefix, _watcher, receiver) = models_watcher.dissolve();
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use dynamo_llm::protocols::openai::nvext::NvExt; use dynamo_llm::protocols::openai::nvext::NvExt;
use dynamo_llm::types::openai::chat_completions::{ use dynamo_llm::types::openai::chat_completions::{
...@@ -30,13 +18,13 @@ const MAX_TOKENS: u32 = 8192; ...@@ -30,13 +18,13 @@ const MAX_TOKENS: u32 = 8192;
pub async fn run( pub async fn run(
runtime: Runtime, runtime: Runtime,
flags: Flags, _flags: Flags,
single_prompt: Option<String>, single_prompt: Option<String>,
engine_config: EngineConfig, engine_config: EngineConfig,
template: Option<RequestTemplate>, template: Option<RequestTemplate>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let cancel_token = runtime.primary_token(); let cancel_token = runtime.primary_token();
let prepared_engine = common::prepare_engine(runtime, flags, engine_config).await?; let prepared_engine = common::prepare_engine(runtime, engine_config).await?;
main_loop( main_loop(
cancel_token, cancel_token,
&prepared_engine.service_name, &prepared_engine.service_name,
......
...@@ -6,7 +6,7 @@ use std::{io::Read, sync::Arc, time::Duration}; ...@@ -6,7 +6,7 @@ use std::{io::Read, sync::Arc, time::Duration};
use anyhow::Context; use anyhow::Context;
use dynamo_llm::{backend::ExecutionContext, engines::StreamingEngine, LocalModel}; use dynamo_llm::{backend::ExecutionContext, engines::StreamingEngine, LocalModel};
use dynamo_runtime::{protocols::Endpoint, CancellationToken, DistributedRuntime}; use dynamo_runtime::{CancellationToken, DistributedRuntime};
mod flags; mod flags;
pub use flags::Flags; pub use flags::Flags;
...@@ -26,8 +26,8 @@ const PYTHON_STR_SCHEME: &str = "pystr:"; ...@@ -26,8 +26,8 @@ const PYTHON_STR_SCHEME: &str = "pystr:";
pub const INTERNAL_ENDPOINT: &str = "dyn://dynamo.internal.worker"; pub const INTERNAL_ENDPOINT: &str = "dyn://dynamo.internal.worker";
pub enum EngineConfig { pub enum EngineConfig {
/// An remote networked engine we don't know about yet /// Remote networked engines
Dynamic(Endpoint), Dynamic,
/// A Full service engine does its own tokenization and prompt formatting. /// A Full service engine does its own tokenization and prompt formatting.
StaticFull { StaticFull {
...@@ -48,7 +48,7 @@ pub async fn run( ...@@ -48,7 +48,7 @@ pub async fn run(
out_opt: Output, out_opt: Output,
flags: Flags, flags: Flags,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
if matches!(&in_opt, Input::Endpoint(_)) && matches!(&out_opt, Output::Endpoint(_)) { if matches!(&in_opt, Input::Endpoint(_)) && matches!(&out_opt, Output::Dynamic) {
anyhow::bail!("Cannot use endpoint for both in and out"); anyhow::bail!("Cannot use endpoint for both in and out");
} }
...@@ -59,9 +59,9 @@ pub async fn run( ...@@ -59,9 +59,9 @@ pub async fn run(
.or(flags.model_path_flag.clone()); .or(flags.model_path_flag.clone());
let local_model: LocalModel = match out_opt { let local_model: LocalModel = match out_opt {
// If output is an endpoint we are ingress and don't have a local model, but making an // If output is dynamic we are ingress and don't have a local model, but making an
// empty one cleans up the code. // empty one cleans up the code.
Output::Endpoint(_) => Default::default(), Output::Dynamic => Default::default(),
// All other output types have a local model // All other output types have a local model
_ => { _ => {
...@@ -100,10 +100,7 @@ pub async fn run( ...@@ -100,10 +100,7 @@ pub async fn run(
// Create the engine matching `out` // Create the engine matching `out`
let engine_config = match out_opt { let engine_config = match out_opt {
Output::Endpoint(path) => { Output::Dynamic => EngineConfig::Dynamic,
let endpoint: Endpoint = path.parse()?;
EngineConfig::Dynamic(endpoint)
}
Output::EchoFull => EngineConfig::StaticFull { Output::EchoFull => EngineConfig::StaticFull {
model: Box::new(local_model), model: Box::new(local_model),
engine: dynamo_llm::engines::make_engine_full(), engine: dynamo_llm::engines::make_engine_full(),
...@@ -173,7 +170,7 @@ pub async fn run( ...@@ -173,7 +170,7 @@ pub async fn run(
extra = Some(Box::pin(async move { extra = Some(Box::pin(async move {
stopper(cancel_token, child, py_script).await; stopper(cancel_token, child, py_script).await;
})); }));
EngineConfig::Dynamic(endpoint) EngineConfig::Dynamic
} }
Output::Vllm => { Output::Vllm => {
if flags.base_gpu_id != 0 { if flags.base_gpu_id != 0 {
...@@ -209,7 +206,7 @@ pub async fn run( ...@@ -209,7 +206,7 @@ pub async fn run(
extra = Some(Box::pin(async move { extra = Some(Box::pin(async move {
stopper(cancel_token, child, py_script).await; stopper(cancel_token, child, py_script).await;
})); }));
EngineConfig::Dynamic(endpoint) EngineConfig::Dynamic
} }
#[cfg(feature = "llamacpp")] #[cfg(feature = "llamacpp")]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment