# Running Dynamo CLI (`dynamo-run`)

With the Dynamo CLI, you can chat with models quickly using `dynamo-run`. It is a CLI tool for exploring the Dynamo components and an example of how to use those components from Rust. If you installed the Python wheel, it is available as `dynamo-run`.

It supports these engines: mistralrs, llamacpp, sglang, vllm, and tensorrt-llm. `mistralrs` is the default.

Usage:
```
dynamo-run in=[http|text|dyn://<path>|batch:<folder>] out=echo_core|echo_full|mistralrs|llamacpp|sglang|vllm|dyn [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--model-config <hf-repo>] [--tensor-parallel-size=1] [--context-length=N] [--num-nodes=1] [--node-rank=0] [--leader-addr=127.0.0.1:9876] [--base-gpu-id=0] [--extra-engine-args=args.json] [--router-mode random|round-robin|kv] [--kv-overlap-score-weight=1.0] [--router-temperature=0.0] [--use-kv-events=true] [--verbosity (-v|-vv)]
```

Example: `dynamo-run Qwen/Qwen3-0.6B`

Set the environment variable `DYN_LOG` to adjust the logging level; for example, `export DYN_LOG=debug`. It has the same syntax as `RUST_LOG`.

To adjust verbosity, use `-v` to enable debug logging or `-vv` to enable full trace logging. For example:

```bash
dynamo-run in=http out=mistralrs -v  # enables debug logging
dynamo-run in=text out=llamacpp -vv  # enables full trace logging
```

## Quickstart with pip and vllm

If you used `pip` to install `dynamo`, you have the `dynamo-run` binary pre-installed with the `vllm` engine. You must be in a virtual environment with vllm installed to use this engine. To compile from source, see [Full usage details](#full-usage-details) below.

The vllm and sglang engines require [etcd](https://etcd.io/) and [nats](https://nats.io/) with jetstream (`nats-server -js`). Mistralrs and llamacpp do not.

### Use model from Hugging Face

To automatically download Qwen3 4B from Hugging Face (16 GiB download) and start it in interactive text mode:
```
dynamo-run out=vllm Qwen/Qwen3-4B
```

The general format for HF download follows this pattern:
```
dynamo-run out=<engine> <HUGGING_FACE_ORGANIZATION/MODEL_NAME>
```

For gated models (such as meta-llama/Llama-3.2-3B-Instruct), you must set an `HF_TOKEN` environment variable.

The model parameter can be the ID of a Hugging Face repository (which will be downloaded), a GPT-Generated Unified Format (GGUF) file, or a folder containing safetensors, `config.json`, or similar (for example, a locally checked out Hugging Face repository).

### Run a model from local file

To run a model from a local file:
- Download the model from Hugging Face
- Run the model from the local file

See the following sections for details.

#### Download model from Hugging Face
One of the models available from Hugging Face should be high quality and fast on almost any machine: https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF
For example, try https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF/blob/main/Llama-3.2-3B-Instruct-Q4_K_M.gguf

To download the model file:
```
curl -L -o Llama-3.2-3B-Instruct-Q4_K_M.gguf "https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF/resolve/main/Llama-3.2-3B-Instruct-Q4_K_M.gguf?download=true"
```
#### Run model from local file
To run the model:

*Text interface*
```
dynamo-run Llama-3.2-3B-Instruct-Q4_K_M.gguf # or path to a Hugging Face repo checkout instead of the GGUF file
```

*HTTP interface*
```
dynamo-run in=http out=mistralrs Llama-3.2-3B-Instruct-Q4_K_M.gguf
```
You can also list models or send a request:

*List the models*
```
curl localhost:8080/v1/models
```

*Send a request*
```
curl -d '{"model": "Llama-3.2-3B-Instruct-Q4_K_M", "max_completion_tokens": 2049, "messages":[{"role":"user", "content": "What is the capital of South Africa?" }]}' -H 'Content-Type: application/json' http://localhost:8080/v1/chat/completions
```
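
The same request can also be built from Python. The following is a minimal sketch that uses only the standard library and assumes the HTTP frontend started above is listening on `localhost:8080`. The helper names `build_chat_request` and `send` are illustrative and not part of Dynamo.

```python
import json
import urllib.request


def build_chat_request(model, prompt, base_url="http://localhost:8080"):
    """Build (but do not send) an OpenAI-style chat completion request.

    Hypothetical helper for this sketch, not a Dynamo API.
    """
    payload = {
        "model": model,
        "max_completion_tokens": 2049,
        "messages": [{"role": "user", "content": prompt}],
    }
    return urllib.request.Request(
        f"{base_url}/v1/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def send(request):
    """Send the request and decode the JSON response (needs a running server)."""
    with urllib.request.urlopen(request) as response:
        return json.load(response)
```

Call `send(build_chat_request("Llama-3.2-3B-Instruct-Q4_K_M", "What is the capital of South Africa?"))` once the server is running. Building the request by itself does not need a server.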

### Distributed System

You can run the ingress side (HTTP server and pre-processing) on one machine, for example a CPU node, and the worker on a different machine (a GPU node).

You will need [etcd](https://etcd.io/) and [nats](https://nats.io) with jetstream installed and accessible from both nodes.

**Node 1:** OpenAI-compliant HTTP server, optional pre-processing, worker discovery:

```
dynamo-run in=http out=auto
```

**Node 2:** vllm engine. Receives and returns requests over the network:

```
dynamo-run in=dyn://llama3B.backend.generate out=vllm ~/llms/Llama-3.2-3B-Instruct
```

This uses etcd to auto-discover the model and NATS to talk to it. You can run multiple instances on the same endpoint; Dynamo picks one based on the `--router-mode` (round-robin by default if left unspecified).

Run `dynamo-run --help` for more options.

### Network names

The `in=dyn://` URLs have the format `dyn://namespace.component.endpoint`. For a quickstart, use any string, such as `dyn://test`; `dynamo-run` fills in defaults for any missing parts. The pieces matter in a larger system.

* *Namespace*: A pipeline, usually a model, for example "llama_8b". It is just a name.
* *Component*: A load balanced service needed to run that pipeline. "backend", "prefill", "decode", "preprocessor", "draft", etc. This typically has some configuration (which model to use, for example).
* *Endpoint*: Like a URL. "generate", "load_metrics".
* *Instance*: A process. Unique. Dynamo assigns each one a unique instance_id. The thing that is running is always an instance. Namespace/component/endpoint can refer to multiple instances.

If you run two models, that is two pipelines. An exception is speculative decoding, where the draft model is part of the pipeline of the bigger model.

If you run two instances of the same model ("data parallel") they are the same namespace+component+endpoint but different instances. The router will spread traffic over all the instances of a namespace+component+endpoint. If you have four prefill workers in a pipeline, they all have the same namespace+component+endpoint and are automatically assigned unique instance_ids.

Example 1: Data parallel, load balanced. One model, one pipeline, two instances.
```
Node 1: dynamo-run in=dyn://qwen3-32b.backend.generate out=sglang /data/Qwen3-32B --tensor-parallel-size 2 --base-gpu-id 0
Node 2: dynamo-run in=dyn://qwen3-32b.backend.generate out=sglang /data/Qwen3-32B --tensor-parallel-size 2 --base-gpu-id 2
```

Example 2: Two models, two pipelines.
```
Node 1: dynamo-run in=dyn://qwen3-32b.backend.generate out=vllm /data/Qwen3-32B
Node 2: dynamo-run in=dyn://llama3-1-8b.backend.generate out=vllm /data/Llama-3.1-8B-Instruct/
```

Example 3: Different endpoints.

The KV metrics publisher in VLLM adds a `load_metrics` endpoint to the current component. If the `llama3-1-8b.backend` component above is using patched vllm it will also expose `llama3-1-8b.backend.load_metrics`.

Example 4: Multiple components in a pipeline.

In the P/D disaggregated setup you would have `deepseek-distill-llama8b.prefill.generate` (possibly multiple instances of this) and `deepseek-distill-llama8b.decode.generate`.

For output, always use `out=auto`. This tells Dynamo to auto-discover the instances, group them by model, and load balance appropriately (depending on the `--router-mode` flag). The exception is static workers; see [Static workers without etcd](#static-workers-without-etcd).
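
To illustrate the `dyn://namespace.component.endpoint` format, here is a small parsing sketch. It is not how `dynamo-run` parses these URLs: `DynAddress` and `parse_dyn_url` are hypothetical names, the sketch assumes parts are given left to right, and it reports missing parts as `None` instead of reproducing the defaults that `dynamo-run` fills in.

```python
from typing import NamedTuple, Optional


class DynAddress(NamedTuple):
    namespace: Optional[str]
    component: Optional[str]
    endpoint: Optional[str]


def parse_dyn_url(url: str) -> DynAddress:
    """Split `dyn://namespace.component.endpoint` into its parts.

    Missing trailing parts are returned as None (an assumption of this sketch).
    """
    prefix = "dyn://"
    if not url.startswith(prefix):
        raise ValueError(f"expected a dyn:// URL, got {url!r}")
    parts = url[len(prefix):].split(".")
    if len(parts) > 3 or any(not part for part in parts):
        raise ValueError(f"expected dyn://namespace.component.endpoint, got {url!r}")
    padded = parts + [None] * (3 - len(parts))
    return DynAddress(*padded)
```

For example, `parse_dyn_url("dyn://qwen3-32b.backend.generate")` yields the namespace `qwen3-32b`, the component `backend`, and the endpoint `generate`.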

### Static workers without etcd

Normally in the distributed system the frontend uses etcd to discover workers. The option exists to have a static endpoint without etcd.

```
Node 1: dynamo-run in=http out=dyn://dynamo.backend.generate --model-name Qwen3-0.6B-Q8_0.gguf --model-path ~/llms/Qwen3-0.6B
Node 2: dynamo-run in=dyn://dynamo.backend.generate out=llamacpp ~/llms/Qwen3-0.6B-Q8_0.gguf --static-worker --context-length 4096
```

Note how `out=` points to a single endpoint, which must match the worker. The frontend usually discovers the model's name and config (needed for pre-processing) via etcd. Without etcd, you must pass them in with `--model-name` and `--model-path`.

### KV-aware routing

**Setup**

Currently, only patched vllm supports KV-aware routing.

To set up KV-aware routing on patched vllm:

1. Ensure that `etcd` and `nats` (see [Quickstart with pip and vllm](#quickstart-with-pip-and-vllm)) are running and accessible from all nodes.
1. Create a virtualenv: `uv venv kvtest` and source its `activate`.
1. Use `pip` to **either**:
   1. Install Dynamo's vllm branch:
      ```
      uv pip install ai-dynamo-vllm
      ```
       **or**
   1. Install upstream vllm 0.8.4:
      ```
      uv pip install vllm==0.8.4
      ```
      And then patch it:
      ```
      cd kvtest/lib/python3.12/site-packages
      patch -p1 < $REPO_ROOT/container/deps/vllm/vllm_v0.8.4-dynamo-kv-disagg-patch.patch
      ```
1. Build the C bindings:
   ```
   cd $REPO_ROOT/lib/bindings/c
   cargo build
   ```
1. Put the library you just built on library path:
   ```
   export LD_LIBRARY_PATH=$REPO_ROOT/target/debug/
   ```
If you patched locally (instead of installing `ai-dynamo-vllm`), edit vllm's `platforms/__init__.py` to undo a patch change:
```
    #vllm_version = version("ai_dynamo_vllm")
    vllm_version = version("vllm")
```

**Start the workers**

The workers are started normally:

```
dynamo-run in=dyn://dynamo.endpoint.generate out=vllm /data/llms/Qwen/Qwen3-4B
```

**Start the ingress node**

```
dynamo-run in=http out=auto --router-mode kv
```

The only difference from the distributed system above is `--router-mode kv`. The patched vllm announces when a KV block is created or removed. The Dynamo router finds the worker with the best match for those KV blocks and directs the traffic to that node.

For performance testing, compare a typical workload with `--router-mode random|round-robin` to see if it can benefit from KV-aware routing.

The KV-aware routing arguments:

- `--kv-overlap-score-weight`: Sets the amount of weighting on overlaps with prefix caches, which directly contributes to the prefill cost. A large weight is expected to yield a better TTFT (at the expense of worse ITL). When set to 0, prefix caches are not considered at all (falling back to pure load balancing behavior on the active blocks).

- `--router-temperature`: Sets the temperature when randomly selecting workers to route to via softmax sampling on the router cost logits. Setting it to 0 recovers the deterministic behavior where the min logit is picked.

- `--use-kv-events`: Sets whether to listen to KV events for maintaining the global view of cached blocks. If true, then we use the `KvIndexer` to listen to the block creation and deletion events. If false, `ApproxKvIndexer`, which assumes the kv cache of historical prompts exists for fixed time durations (hard-coded to 120s), is used to predict the kv cache hit ratio in each engine. Set false if your backend engine does not emit KV events.
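
To make the roles of these arguments concrete, here is a rough sketch of cost-based worker selection. The cost formula, the function names, and the block counts are assumptions made for illustration; the actual Dynamo router may compute costs differently. The sketch only mirrors the behavior described above: a weight of 0 ignores prefix-cache overlap, and a temperature of 0 deterministically picks the lowest-cost worker.

```python
import math
import random


def worker_cost(overlap_blocks, request_blocks, active_blocks, overlap_weight=1.0):
    """Hypothetical cost: weighted blocks still to prefill, plus current load."""
    prefill_blocks = request_blocks - overlap_blocks
    return overlap_weight * prefill_blocks + active_blocks


def pick_worker(costs, temperature=0.0, rng=random):
    """Pick a worker id from {worker_id: cost}; lower cost is better."""
    if temperature == 0.0:
        # Deterministic: always the worker with the minimum cost.
        return min(costs, key=costs.get)
    # Softmax over negated costs, shifted by the minimum for numerical stability,
    # so cheaper workers are sampled with higher probability.
    lowest = min(costs.values())
    weights = [math.exp(-(cost - lowest) / temperature) for cost in costs.values()]
    return rng.choices(list(costs), weights=weights, k=1)[0]
```

With `overlap_weight=0.0` the cost reduces to `active_blocks`, which corresponds to pure load balancing; raising the temperature spreads requests over workers with similar costs.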

### Request Migration

In a [Distributed System](#distributed-system), you can enable [request migration](../architecture/request_migration.md) to handle worker failures gracefully. Use the `--migration-limit` flag to specify how many times a request can be migrated to another worker:

```bash
dynamo-run in=dyn://... out=<engine> ... --migration-limit=3
```

This allows a request to be migrated up to 3 times before failing. See the [Request Migration Architecture](../architecture/request_migration.md) documentation for details on how this works.
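
The effect of the limit can be pictured as a bounded retry loop. This is only a conceptual sketch: `WorkerFailure`, `run_with_migration`, and the callable workers are hypothetical, and it does not model how Dynamo carries a partially generated response over to the next worker.

```python
class WorkerFailure(Exception):
    """Raised by a (hypothetical) worker call that fails mid-request."""


def run_with_migration(request, workers, migration_limit):
    """Try workers in order, migrating at most `migration_limit` times.

    Returns (result, number_of_migrations) or re-raises the last failure.
    """
    migrations = 0
    index = 0
    while True:
        worker = workers[index % len(workers)]
        try:
            return worker(request), migrations
        except WorkerFailure:
            if migrations >= migration_limit:
                raise  # limit reached: the request fails
            migrations += 1
            index += 1  # move on to the next worker
```

With `migration_limit=3`, a request is attempted at most four times in this sketch: once on the original worker and once after each of the three migrations.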

## Full usage details

`dynamo-run` is also an example of what can be built in Rust with the `dynamo-llm` and `dynamo-runtime` crates. The following guide shows how to build from source with all the features.

### Getting Started

#### Setup

##### Step 1: Install libraries
**Ubuntu:**
```
sudo apt install -y build-essential libhwloc-dev libudev-dev pkg-config libssl-dev libclang-dev protobuf-compiler python3-dev cmake
```

**macOS:**
- [Homebrew](https://brew.sh/)
```
# if brew is not installed on your system, install it
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
```
- [Xcode](https://developer.apple.com/xcode/)

```
brew install cmake protobuf

## Check that Metal is accessible
xcrun -sdk macosx metal
```
If Metal is accessible, you should see an error like `metal: error: no input files`, which confirms it is installed correctly.

##### Step 2: Install Rust
```
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
source $HOME/.cargo/env
```

##### Step 3: Build

- Linux with GPU and CUDA (tested on Ubuntu):
```
cargo build --features cuda
```

- macOS with Metal:
```
cargo build --features metal
```

- CPU only:
```
cargo build
```

Optionally you can run `cargo build` from any location with arguments:

```
--target-dir /path/to/target_directory # specify target_directory with write privileges
--manifest-path /path/to/project/Cargo.toml # if cargo build is run outside of `launch/` directory
```

The binary is called `dynamo-run` and is placed in `target/debug`:
```
cd target/debug
```

Build with `--release` for a smaller binary and better performance, but longer build times. The binary will be in `target/release`.

#### Defaults

The input defaults to `in=text`. The output defaults to the `mistralrs` engine (`out=mistralrs`), unless it is disabled with `--no-default-features`, in which case vllm is used.

### Running Inference with Pre-built Engines

#### mistralrs

[mistral.rs](https://github.com/EricLBuehler/mistral.rs) is a pure Rust engine that is fast to run, fast to load, supports GGUF as well as safetensors, and runs well on CPU as well as GPU. For those reasons it is the default engine.

```
dynamo-run Qwen/Qwen3-4B
```

is equivalent to

```
dynamo-run in=text out=mistralrs Qwen/Qwen3-4B
```

If you have multiple GPUs, mistral.rs does automatic tensor parallelism. You do not need to pass any extra flags to dynamo-run to enable it.

#### llamacpp

[llama.cpp](https://github.com/ggml-org/llama.cpp) is built for CPU by default. For an optimized build pass the appropriate feature flag (highly recommended):

```
cargo build --features cuda|metal|vulkan -p dynamo-run
```

For GNU OpenMP support add the `openmp` feature. On Ubuntu this requires `libgomp1` (part of `build-essential`) at build and runtime.

```
cargo build --features cuda,openmp -p dynamo-run
```

```
dynamo-run out=llamacpp ~/llms/gemma-3-1b-it-q4_0.gguf
dynamo-run out=llamacpp ~/llms/Qwen3-0.6B-Q8_0.gguf # From https://huggingface.co/ggml-org
```

Note that in some cases we are unable to extract the tokenizer from the GGUF, and so a Hugging Face checkout of a matching model must also be passed. Dynamo uses the weights from the GGUF and the pre-processor (`tokenizer.json`, etc) from the `--model-config`:
```
dynamo-run out=llamacpp ~/llms/Llama-4-Scout-17B-16E-Instruct-UD-IQ1_S.gguf --context-length 32768 --model-config ~/llms/Llama-4-Scout-17B-16E-Instruct
```

If you have multiple GPUs, llama.cpp does automatic tensor parallelism. You do not need to pass any extra flags to dynamo-run to enable it.

#### sglang

The [SGLang](https://docs.sglang.ai/index.html) engine requires [etcd](https://etcd.io/) and [nats](https://nats.io/) with jetstream (`nats-server -js`) to be running.

1. Set up the Python virtual environment:

```
uv venv
source .venv/bin/activate
uv pip install pip
uv pip install sgl-kernel --force-reinstall --no-deps
uv pip install "sglang[all]==0.4.2" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/
```

2. Run

Any example above using `out=sglang` works. The sglang backend also supports multiple GPUs:

```
cd target/debug
./dynamo-run in=http out=sglang --model-path ~/llms/DeepSeek-R1-Distill-Llama-70B/ --tensor-parallel-size 8
```

To pass extra arguments to the sglang engine see [Extra engine arguments](#extra-engine-arguments).

**Multi-GPU**

Pass `--tensor-parallel-size <NUM-GPUS>` to `dynamo-run`.

```
dynamo-run out=sglang ~/llms/Llama-4-Scout-17B-16E-Instruct/ --tensor-parallel-size 8
```

To specify the GPU to start from, pass `--base-gpu-id <num>`; for example, on a shared eight GPU machine where GPUs 0–3 are already in use:
```
dynamo-run out=sglang <model> --tensor-parallel-size 4 --base-gpu-id 4
```

**Multinode:**

Dynamo only manages the leader node (node rank 0). The follower nodes are started in the [normal sglang way](https://docs.sglang.ai/references/deepseek.html#running-examples-on-multi-node).

Leader node:
```
dynamo-run out=sglang /data/models/DeepSeek-R1-Distill-Llama-70B/ --tensor-parallel-size 16 --node-rank 0 --num-nodes 2 --leader-addr 10.217.98.122:5000
```

All follower nodes. Increment `node-rank` each time:
```
python3 -m sglang.launch_server --model-path /data/models/DeepSeek-R1-Distill-Llama-70B --tp 16 --dist-init-addr 10.217.98.122:5000 --nnodes 2 --node-rank 1 --trust-remote-code
```

- Parameters `--leader-addr` and `--dist-init-addr` must match and be the IP address of the leader node. All followers must be able to connect to it. SGLang uses [PyTorch Distributed](https://docs.pytorch.org/tutorials/beginner/dist_overview.html) for networking.
- Parameters `--tensor-parallel-size` and `--tp` must match and be the total number of GPUs across the cluster.
- `--node-rank` must be unique consecutive integers starting at 1. The leader, managed by Dynamo, is 0.
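
The rules above can be captured in a small helper that generates the follower commands for a given leader configuration. This is a sketch: `follower_commands` is a hypothetical helper, and it only uses the `sglang.launch_server` flags shown in the example above.

```python
def follower_commands(model_path, tensor_parallel_size, num_nodes, leader_addr):
    """Return one sglang launch command per follower node (ranks 1..num_nodes-1).

    The leader (rank 0) is started with dynamo-run and is not included.
    """
    return [
        "python3 -m sglang.launch_server"
        f" --model-path {model_path}"
        f" --tp {tensor_parallel_size}"        # must equal --tensor-parallel-size
        f" --dist-init-addr {leader_addr}"     # must equal --leader-addr
        f" --nnodes {num_nodes}"
        f" --node-rank {rank}"
        " --trust-remote-code"
        for rank in range(1, num_nodes)
    ]
```

For the two-node example above, `follower_commands("/data/models/DeepSeek-R1-Distill-Llama-70B", 16, 2, "10.217.98.122:5000")` returns a single command with `--node-rank 1`.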

#### vllm

Using the [vllm](https://github.com/vllm-project/vllm) Python library. Slow startup, fast inference. Supports both safetensors from HF and GGUF files, but is very slow for GGUF - prefer llamacpp.

The vllm engine requires [etcd](https://etcd.io/) and [nats](https://nats.io/) with jetstream (`nats-server -js`) to be running.

We use [uv](https://docs.astral.sh/uv/) but any virtualenv manager should work.

1. Setup:
```
uv venv
source .venv/bin/activate
uv pip install pip
uv pip install vllm==0.8.4 setuptools
```

```{note}
If you're on Ubuntu 22.04 or earlier, you must add `--python=python3.10` to your `uv venv` command.
```

2. Build:
```
cargo build
cd target/debug
```

3. Run inside that virtualenv:

**HF repo:**
```
./dynamo-run in=http out=vllm ~/llms/Llama-3.2-3B-Instruct/
```

To pass extra arguments to the vllm engine see [Extra engine arguments](#extra-engine-arguments).

vllm attempts to allocate enough KV cache for the full context length at startup. If that does not fit in your available memory pass `--context-length <value>`.

If you see an error similar to the following:
```text
2025-06-28T00:32:32.507Z  WARN dynamo_run::subprocess: Traceback (most recent call last):
2025-06-28T00:32:32.507Z  WARN dynamo_run::subprocess:   File "/tmp/.tmpYeq5qA", line 29, in <module>
2025-06-28T00:32:32.507Z  WARN dynamo_run::subprocess:     from dynamo.llm import ModelType, WorkerMetricsPublisher, register_llm
2025-06-28T00:32:32.507Z  WARN dynamo_run::subprocess: ModuleNotFoundError: No module named 'dynamo'
```
Then run
```
uv pip install maturin
pip install patchelf
cd lib/bindings/python
maturin develop
```
This builds the Python->Rust bindings into the missing `dynamo` module. Rerun `dynamo-run` and the problem should be resolved.

**Multi-GPU**

Pass `--tensor-parallel-size <NUM-GPUS>` to `dynamo-run`.

To specify which GPUs to use, set the environment variable `CUDA_VISIBLE_DEVICES`.

**Multinode:**

vllm uses [ray](https://docs.vllm.ai/en/latest/serving/distributed_serving.html#running-vllm-on-multiple-nodes) for pipeline parallel inference. Dynamo does not change or manage that.

Here is an example on two 8-GPU nodes:
- Leader node: `ray start --head --port=6379`
- Each follower node: `ray start --address=<HEAD_NODE_IP>:6379`
- Leader node: `dynamo-run out=vllm ~/llms/DeepSeek-R1-Distill-Llama-70B/ --tensor-parallel-size 16`

The `--tensor-parallel-size` parameter is the total number of GPUs in the cluster. This is often constrained by a model dimension such as being a divisor of the number of attention heads.
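
For example, candidate values can be listed by checking divisibility. This sketch assumes the attention-head constraint is the only one that applies, which may not hold for every model or engine; the function name is illustrative.

```python
def valid_tensor_parallel_sizes(num_attention_heads, total_gpus):
    """List tensor-parallel sizes up to `total_gpus` that divide the head count."""
    return [
        size
        for size in range(1, total_gpus + 1)
        if num_attention_heads % size == 0
    ]
```

A model with 64 attention heads on a 16-GPU cluster could use 1, 2, 4, 8, or 16 under this assumption, while a model with 40 heads could not use 16.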

Startup can be slow so you may want to `export DYN_LOG=debug` to see progress.

Shutdown: `ray stop`

#### trtllm

Using [TensorRT-LLM's LLM API](https://nvidia.github.io/TensorRT-LLM/llm-api/), a high-level Python API.

You can use `--extra-engine-args` to pass extra arguments to LLM API engine.

The trtllm engine requires [etcd](https://etcd.io/) and [nats](https://nats.io/) with jetstream (`nats-server -js`) to be running.

##### Step 1: Build the environment

See instructions [here](https://github.com/ai-dynamo/dynamo/blob/main/examples/tensorrt_llm/README.md#build-docker) to build the dynamo container with TensorRT-LLM.

##### Step 2: Run the environment

See instructions [here](https://github.com/ai-dynamo/dynamo/blob/main/examples/tensorrt_llm/README.md#run-container) to run the built environment.

##### Step 3: Execute `dynamo-run` command

Execute the following to load the TensorRT-LLM model specified in the configuration.
```
dynamo-run in=http out=trtllm TinyLlama/TinyLlama-1.1B-Chat-v1.0
```

#### Echo Engines

Dynamo includes two echo engines for testing and debugging purposes:

##### echo_core

The `echo_core` engine accepts pre-processed requests and echoes the tokens back as the response. This is useful for testing pre-processing functionality as the response includes the full prompt template.

```
dynamo-run in=http out=echo_core --model-path <hf-repo-checkout>
```

Note that to use it with `in=http` you need to tell the post processor to ignore stop tokens from the template by adding `nvext.ignore_eos` like this:
```
curl -N -d '{"nvext": {"ignore_eos": true}, "stream": true, "model": "Qwen2.5-3B-Instruct", "max_completion_tokens": 4096, "messages":[{"role":"user", "content": "Tell me a story" }]}' ...
```

The default `in=text` sets that for you.

##### echo_full

The `echo_full` engine accepts un-processed requests and echoes the prompt back as the response.

```
dynamo-run in=http out=echo_full --model-name my_model
```

##### Configuration

Both echo engines use a configurable delay between tokens to simulate generation speed. You can adjust this using the `DYN_TOKEN_ECHO_DELAY_MS` environment variable:

```
# Set token echo delay to 1ms (1000 tokens per second)
DYN_TOKEN_ECHO_DELAY_MS=1 dynamo-run in=http out=echo_full
```

The default delay is 10ms, which produces approximately 100 tokens per second.
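
As a quick check of the arithmetic, the following sketch converts the delay into an approximate token rate. The function name is illustrative, and the calculation ignores any per-token overhead beyond the configured delay.

```python
import os


def echo_tokens_per_second(env=None, default_delay_ms=10):
    """Approximate echo-engine token rate from DYN_TOKEN_ECHO_DELAY_MS."""
    env = os.environ if env is None else env
    delay_ms = int(env.get("DYN_TOKEN_ECHO_DELAY_MS", default_delay_ms))
    if delay_ms <= 0:
        # No delay configured: the rate is bounded only by overhead.
        return float("inf")
    return 1000 / delay_ms
```

Passing a plain dict instead of the process environment, such as `echo_tokens_per_second({"DYN_TOKEN_ECHO_DELAY_MS": "1"})`, makes it easy to try different settings.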

#### Batch mode

`dynamo-run` can take a jsonl file full of prompts and evaluate them all:

```
dynamo-run in=batch:prompts.jsonl out=llamacpp <model>
```

The input file should look like this:
```
{"text": "What is the capital of France?"}
{"text": "What is the capital of Spain?"}
```

Each one is passed as a prompt to the model. The output is written back to the same folder in `output.jsonl`. At the end of the run some statistics are printed.
The output looks like this:
```
{"text":"What is the capital of France?","response":"The capital of France is Paris.","tokens_in":7,"tokens_out":7,"elapsed_ms":1566}
{"text":"What is the capital of Spain?","response":".The capital of Spain is Madrid.","tokens_in":7,"tokens_out":7,"elapsed_ms":855}
```
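
If you generate prompts or post-process results programmatically, a small script helps. The sketch below assumes only the field names shown in the samples above (`text`, `tokens_out`, `elapsed_ms`); the helper names are hypothetical.

```python
import json


def prompts_to_jsonl(prompts):
    """Render prompts in the input format: one {"text": ...} object per line."""
    return "".join(json.dumps({"text": prompt}) + "\n" for prompt in prompts)


def summarize_output(lines):
    """Aggregate a few statistics from the lines of an output.jsonl file."""
    records = [json.loads(line) for line in lines if line.strip()]
    total_ms = sum(record["elapsed_ms"] for record in records)
    return {
        "requests": len(records),
        "tokens_out": sum(record["tokens_out"] for record in records),
        "mean_elapsed_ms": total_ms / len(records) if records else 0.0,
    }
```

Write the string returned by `prompts_to_jsonl` to `prompts.jsonl` before the run, and pass the lines of `output.jsonl` to `summarize_output` afterwards.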

#### Mocker engine

The mocker engine is a mock vLLM implementation designed for testing and development purposes. It simulates realistic token generation timing without requiring actual model inference, making it useful for:

- Testing distributed system components without GPU resources
- Benchmarking infrastructure and networking overhead
- Developing and debugging Dynamo components
- Load testing and performance analysis

**Basic usage:**

The `--model-path` is required but can point to any valid model path - the mocker doesn't actually load the model weights (but the pre-processor needs the tokenizer). The arguments `block_size`, `num_gpu_blocks`, `max_num_seqs`, `max_num_batched_tokens`, `enable_prefix_caching`, and `enable_chunked_prefill` are common arguments shared with the real VLLM engine.

And below are arguments that are mocker-specific:
- `speedup_ratio`: Speed multiplier for token generation (default: 1.0). Higher values make the simulation engines run faster.
- `dp_size`: Number of data parallel workers to simulate (default: 1)
- `watermark`: KV cache watermark threshold as a fraction (default: 0.01). This argument also exists for the real VLLM engine but cannot be passed as an engine arg.

```bash
echo '{"speedup_ratio": 10.0}' > mocker_args.json
dynamo-run in=dyn://dynamo.mocker.generate out=mocker --model-path TinyLlama/TinyLlama-1.1B-Chat-v1.0 --extra-engine-args mocker_args.json
dynamo-run in=http out=auto --router-mode kv
```

### Extra engine arguments
The vllm and sglang backends support passing any argument the engine accepts.
Put the arguments in a JSON file:
```
{
    "dtype": "half",
    "trust_remote_code": true
}
```
Pass it like this:
```
dynamo-run out=sglang ~/llms/Llama-3.2-3B-Instruct --extra-engine-args sglang_extra.json
```
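
Because a malformed file is only noticed once the engine starts, it can be worth validating it first. A minimal sketch, assuming the file must contain a single JSON object as in the example above (`load_engine_args` is a hypothetical helper, not a Dynamo function):

```python
import json


def load_engine_args(text):
    """Parse extra engine args and check that they form a JSON object."""
    args = json.loads(text)  # raises json.JSONDecodeError on invalid JSON
    if not isinstance(args, dict):
        raise ValueError("extra engine args must be a JSON object")
    return args
```

Read the file contents and pass them to `load_engine_args` before launching `dynamo-run`; any exception points to a problem in the file.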

The tensorrtllm backend also supports passing any argument the engine accepts. In this case, however, the config should be a YAML file:

```
backend: pytorch
kv_cache_config:
  event_buffer_max_size: 1024
```

Pass it like this:
```
dynamo-run in=http out=trtllm TinyLlama/TinyLlama-1.1B-Chat-v1.0 --extra-engine-args trtllm_extra.yaml
```

### Writing your own engine in Python

The [dynamo](https://pypi.org/project/ai-dynamo/) Python library allows you to build your own engine and attach it to Dynamo.

The Python file must do three things:
1. Decorate a function to get the runtime
2. Register on the network
3. Attach a request handler

```
import asyncio

import uvloop
from dynamo.llm import ModelType, register_llm
from dynamo.runtime import DistributedRuntime, dynamo_worker


# 1. Decorate a function to get the runtime
#
@dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime):

    # 2. Register ourselves on the network
    #
    component = runtime.namespace("namespace").component("component")
    await component.create_service()
    model_path = "Qwen/Qwen3-0.6B"  # or "/data/models/Qwen3-0.6B"
    model_type = ModelType.Backend
    endpoint = component.endpoint("endpoint")
    # Optional last param to register_llm is model_name. If not present derives it from model_path
    await register_llm(model_type, endpoint, model_path)

    # Initialize your engine here
    # engine = ...

    # 3. Attach request handler
    #
    await endpoint.serve_endpoint(RequestHandler(engine).generate)


class RequestHandler:

    def __init__(self, engine):
        ...

    async def generate(self, request):
        # Call the engine
        # yield result dict
        ...


if __name__ == "__main__":
    uvloop.install()
    asyncio.run(worker())
```


The `model_path` can be:
- A HuggingFace repo ID, optionally prefixed with `hf://`. It is downloaded and cached locally.
- The path to a checkout of a HuggingFace repo - any folder containing safetensor files as well as `config.json`, `tokenizer.json` and `tokenizer_config.json`.
- The path to a GGUF file, if your engine supports that.

The `model_type` can be:
- ModelType.Backend. Dynamo handles pre-processing. Your `generate` method receives a `request` dict containing a `token_ids` array of int. It must return a dict also containing a `token_ids` array and an optional `finish_reason` string.
- ModelType.Chat. Your `generate` method receives a `request` and must return a response dict of type [OpenAI Chat Completion](https://platform.openai.com/docs/api-reference/chat). Your engine handles pre-processing.
- ModelType.Completion. Your `generate` method receives a `request` and must return a response dict of the older [Completions](https://platform.openai.com/docs/api-reference/completions). Your engine handles pre-processing.
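
As an illustration of the `ModelType.Backend` contract, here is a toy handler that echoes the prompt tokens back. It is a sketch under assumptions: emitting one token per chunk and using `"stop"` as the `finish_reason` value are choices made for this example, and `collect` is a hypothetical helper that drives the generator without a Dynamo runtime.

```python
import asyncio


class EchoHandler:
    """Toy Backend-style handler: echoes the prompt tokens back one at a time."""

    async def generate(self, request):
        token_ids = request["token_ids"]
        for position, token_id in enumerate(token_ids):
            chunk = {"token_ids": [token_id]}
            if position == len(token_ids) - 1:
                # Mark the final chunk; the value is an assumption of this sketch.
                chunk["finish_reason"] = "stop"
            yield chunk


async def collect(handler, request):
    """Gather every chunk the handler yields (for local experiments only)."""
    return [chunk async for chunk in handler.generate(request)]
```

Running `asyncio.run(collect(EchoHandler(), {"token_ids": [1, 2, 3]}))` returns three chunks, the last of which carries the `finish_reason`. In a real engine, the body of `generate` would call the engine instead of echoing the input.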

`register_llm` can also take the following kwargs:
- `model_name`: The name to call the model. The model name in incoming HTTP requests must match this. Defaults to the Hugging Face repo name, the folder name, or the GGUF file name.
- `context_length`: Max model length in tokens. Defaults to the model's set max. Only set this if you need to reduce KV cache allocation to fit into VRAM.
- `kv_cache_block_size`: Size of a KV block for the engine, in tokens. Defaults to 16.
- `user_data`: Optional dictionary containing custom metadata for worker behavior (e.g., LoRA configuration). Defaults to None.

Here are some example engines:

- Backend:
    * [vllm](https://github.com/ai-dynamo/dynamo/blob/main/lib/bindings/python/examples/hello_world/server_vllm.py)
    * [sglang](https://github.com/ai-dynamo/dynamo/blob/main/lib/bindings/python/examples/hello_world/server_sglang.py)
- Chat:
    * [sglang](https://github.com/ai-dynamo/dynamo/blob/main/lib/bindings/python/examples/hello_world/server_sglang_tok.py)

More fully-featured Backend engines (used by `dynamo-run`):
- [vllm](https://github.com/ai-dynamo/dynamo/blob/main/launch/dynamo-run/src/subprocess/vllm_inc.py)
- [sglang](https://github.com/ai-dynamo/dynamo/blob/main/launch/dynamo-run/src/subprocess/sglang_inc.py)

### Debugging

`dynamo-run` and `dynamo-runtime` support [tokio-console](https://github.com/tokio-rs/console). Build with the feature to enable:
```
cargo build --features cuda,tokio-console -p dynamo-run
```

The listener uses the default tokio console port, and all interfaces (0.0.0.0).