---
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
title: vLLM Multimodal
---

This document provides a comprehensive guide to multimodal inference using the vLLM backend in Dynamo.

<Warning>
**Security Requirement**: All multimodal workers require the `--enable-multimodal` flag to be set explicitly at startup. This safeguard prevents unintended processing of multimodal data from untrusted sources. Workers fail at startup if multimodal flags (e.g., `--multimodal-worker`, `--multimodal-processor`) are used without `--enable-multimodal`.
This flag is analogous to `--enable-mm-embeds` in `vllm serve`, but it covers all multimodal content (URLs, embeddings, and base64 data).
</Warning>

## Support Matrix

| Modality | Input Format | Aggregated | Disaggregated | Notes |
|----------|--------------|------------|---------------|-------|
| **Image** | HTTP/HTTPS URL | Yes | Yes | Full support for all image models |
| **Image** | Data URL (Base64) | Yes | Yes | Inline base64-encoded images |
| **Video** | HTTP/HTTPS URL | Yes | Yes | Frame extraction and processing |
| **Audio** | HTTP/HTTPS URL | Yes | Yes | Experimental - requires audio dependencies |

### Supported URL Formats

| Format | Example | Description |
|--------|---------|-------------|
| **HTTP/HTTPS** | `http://example.com/image.jpg` | Remote media files |
| **Data URL** | `data:image/jpeg;base64,/9j/4AAQ...` | Base64-encoded inline data |

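For images, a data URL can be built from a local file with the Python standard library. The following is a minimal sketch. `to_data_url` and `file_to_data_url` are hypothetical helper names, not part of Dynamo, and the MIME type is guessed from the file extension.

```python
import base64
import mimetypes


def to_data_url(data: bytes, filename: str) -> str:
    """Encode raw media bytes as a data URL: data:<mime>;base64,<payload>."""
    # Guess the MIME type from the file extension; fall back to a generic type.
    mime, _ = mimetypes.guess_type(filename)
    if mime is None:
        mime = "application/octet-stream"
    payload = base64.b64encode(data).decode("ascii")
    return f"data:{mime};base64,{payload}"


def file_to_data_url(path: str) -> str:
    """Read a local file and return its contents as a data URL."""
    with open(path, "rb") as f:
        return to_data_url(f.read(), path)
```

The result goes in the `url` field of an `image_url` content part, exactly where an HTTP URL would go.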
## Deployment Patterns

vLLM supports all multimodal deployment patterns. See [Architecture Patterns](README.md#architecture-patterns) for detailed explanations.

| Pattern | Supported | Launch Script | Notes |
|---------|-----------|---------------|-------|
| EPD (Simple Aggregated) | ✅ | `agg_multimodal.sh` | Easiest setup |
| E/PD (Encode Separate) | ✅ | `agg_multimodal_epd.sh` | Separate encode worker |
| E/P/D (Full Disaggregation) | ✅ | `disagg_multimodal_epd.sh` | All stages separate |
| EP/D (Traditional Disaggregated) | ✅ | `disagg_multimodal_llama.sh` | For Llama 4 models |

### Component Flags

| Component | Flag | Purpose |
|-----------|------|---------|
| Processor | `--multimodal-processor` | HTTP entry, tokenization |
| Encode Worker | `--multimodal-encode-worker` | Media encoding |
| PD Worker | `--multimodal-worker` | Prefill + Decode |
| Prefill Worker | `--multimodal-worker --disaggregation-mode prefill` | Prefill only |
| Decode Worker | `--multimodal-decode-worker` | Decode only |

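To illustrate how the security warning and this table fit together, the sketch below assembles the flags for each component role. It is a hypothetical helper, not a Dynamo API. The role names are made up, and the launch scripts remain the source of truth for complete command lines.

```python
# Hypothetical mapping from a component role to its flags (from the table above).
ROLE_FLAGS = {
    "processor": ["--multimodal-processor"],
    "encode": ["--multimodal-encode-worker"],
    "pd": ["--multimodal-worker"],
    "prefill": ["--multimodal-worker", "--disaggregation-mode", "prefill"],
    "decode": ["--multimodal-decode-worker"],
}


def multimodal_flags(role: str) -> list:
    """Return the CLI flags for a multimodal component, including --enable-multimodal."""
    if role not in ROLE_FLAGS:
        raise ValueError(f"unknown role {role!r}; expected one of {sorted(ROLE_FLAGS)}")
    # Every multimodal component must opt in explicitly, or it fails at startup.
    return ["--enable-multimodal", *ROLE_FLAGS[role]]
```

For example, `multimodal_flags("prefill")` yields `--enable-multimodal --multimodal-worker --disaggregation-mode prefill`. Prepending `--enable-multimodal` for every role reflects the startup check described in the warning.
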
## Use the Latest Release

We recommend using the latest stable release of Dynamo to avoid breaking changes:

[![GitHub Release](https://img.shields.io/github/v/release/ai-dynamo/dynamo)](https://github.com/ai-dynamo/dynamo/releases/latest)

You can find the [latest release](https://github.com/ai-dynamo/dynamo/releases/latest) and check out the corresponding tag with:

```bash
git checkout $(git describe --tags $(git rev-list --tags --max-count=1))
```

## Image Serving

### E/PD Serving (Encode Separate)

**Components:**

- workers: [EncodeWorkerHandler](https://github.com/ai-dynamo/dynamo/blob/main/components/src/dynamo/vllm/multimodal_handlers/encode_worker_handler.py) for encoding and [DecodeWorkerHandler](https://github.com/ai-dynamo/dynamo/blob/main/components/src/dynamo/vllm/handlers.py) for prefilling and decoding.
- processor: Tokenizes the prompt and passes it to the EncodeWorkerHandler.
- frontend: HTTP endpoint to handle incoming requests.

**Workflow:**

The EncodeWorkerHandler encodes the image and passes the embeddings to the DecodeWorkerHandler via NATS and RDMA: the work-completion event is sent over NATS, while the embeddings tensor is transferred over RDMA through the NIXL interface.

```mermaid
flowchart LR
  HTTP --> processor
  processor --> HTTP
  processor --image_url--> encode_worker
  encode_worker --> processor
  encode_worker --embeddings--> pd_worker
  pd_worker --> encode_worker
```

> **Note:** Aggregated serving supports LLaVA 1.5 7B and Qwen2.5-VL-7B-Instruct. Disaggregated serving is currently only confirmed for LLaVA.

**Launch:**

```bash
cd $DYNAMO_HOME/examples/backends/vllm
# Serve a LLaVA 1.5 7B model:
bash launch/agg_multimodal_epd.sh --model llava-hf/llava-1.5-7b-hf
# Serve a Qwen2.5-VL model:
bash launch/agg_multimodal_epd.sh --model Qwen/Qwen2.5-VL-7B-Instruct
```

**Client:**

```bash
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
      "model": "llava-hf/llava-1.5-7b-hf",
      "messages": [
        {
          "role": "user",
          "content": [
            {
              "type": "text",
              "text": "What is in this image?"
            },
            {
              "type": "image_url",
              "image_url": {
                "url": "http://images.cocodataset.org/test2017/000000155781.jpg"
              }
            }
          ]
        }
      ],
      "max_tokens": 300,
      "temperature": 0.0,
      "stream": false
    }'
```
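
The same request can be sent from Python without extra dependencies. This sketch uses only `urllib` from the standard library. The helper names are hypothetical, and the response parsing in the usage note assumes the frontend's OpenAI-compatible chat completions format.

```python
import json
import urllib.request

FRONTEND_URL = "http://localhost:8000/v1/chat/completions"


def build_image_request(model: str, prompt: str, image_url: str,
                        max_tokens: int = 300) -> urllib.request.Request:
    """Build the same chat-completions request as the curl example."""
    body = {
        "model": model,
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {"type": "image_url", "image_url": {"url": image_url}},
                ],
            }
        ],
        "max_tokens": max_tokens,
        "temperature": 0.0,
        "stream": False,
    }
    return urllib.request.Request(
        FRONTEND_URL,
        data=json.dumps(body).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def send(req: urllib.request.Request) -> dict:
    """Send the request to the frontend and return the parsed JSON response."""
    with urllib.request.urlopen(req, timeout=300) as resp:
        return json.loads(resp.read())
```

For example, `send(build_image_request("llava-hf/llava-1.5-7b-hf", "What is in this image?", "http://images.cocodataset.org/test2017/000000155781.jpg"))["choices"][0]["message"]["content"]` returns the generated description. Use `Qwen/Qwen2.5-VL-7B-Instruct` as the model name when serving Qwen.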

### E/P/D Serving (Full Disaggregation)

**Components:**

- workers: [EncodeWorkerHandler](https://github.com/ai-dynamo/dynamo/blob/main/components/src/dynamo/vllm/multimodal_handlers/encode_worker_handler.py) for encoding, [DecodeWorkerHandler](https://github.com/ai-dynamo/dynamo/blob/main/components/src/dynamo/vllm/handlers.py) for decoding, and [PrefillWorkerHandler](https://github.com/ai-dynamo/dynamo/blob/main/components/src/dynamo/vllm/handlers.py) for prefilling.
- processor: Tokenizes the prompt and passes it to the EncodeWorkerHandler.
- frontend: HTTP endpoint to handle incoming requests.

**Workflow:**

For the LLaVA model, embeddings are only required during the prefill stage. The EncodeWorkerHandler is connected directly to the prefill worker, encoding the image and passing embeddings via NATS and RDMA. The prefill worker performs the prefilling step and forwards the KV cache to the decode worker.

```mermaid
flowchart LR
  HTTP --> processor
  processor --> HTTP
  processor --image_url--> encode_worker
  encode_worker --> processor
  encode_worker --embeddings--> prefill_worker
  prefill_worker --> encode_worker
  prefill_worker --> decode_worker
  decode_worker --> prefill_worker
```

**Launch:**

```bash
cd $DYNAMO_HOME/examples/backends/vllm
bash launch/disagg_multimodal_epd.sh --model llava-hf/llava-1.5-7b-hf
```

<Note>
Disaggregation is currently only confirmed to work with LLaVA. Qwen2.5-VL is not confirmed to be supported.
</Note>

## Llama 4 Serving

Llama 4 models are natively multimodal. Unlike LLaVA, they do not directly consume image embeddings as input (see the [vLLM support matrix](https://docs.vllm.ai/en/latest/models/supported_models.html#text-generation_1)). Therefore, the encode worker is not used, and encoding is done alongside prefill.

Example model: `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8` on 8x H100 GPUs.

### Llama 4 Aggregated Serving

**Workflow:**

```mermaid
flowchart LR
  HTTP --> processor
  processor --> HTTP
  processor --image_url--> pd_worker
  pd_worker --> processor
```

**Launch:**

```bash
cd $DYNAMO_HOME/examples/backends/vllm
bash launch/agg_multimodal.sh --model meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8
```

**Client:**

```bash
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
      "model": "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
      "messages": [
        {
          "role": "user",
          "content": [
            {
              "type": "text",
              "text": "What is in this image?"
            },
            {
              "type": "image_url",
              "image_url": {
                "url": "http://images.cocodataset.org/test2017/000000155781.jpg"
              }
            }
          ]
        }
      ],
      "max_tokens": 300,
      "temperature": 0.0,
      "stream": false
    }'
```

### Llama 4 Disaggregated Serving

**Workflow:**

```mermaid
flowchart LR
  HTTP --> processor
  processor --> HTTP
  processor --image_url--> prefill_worker
  prefill_worker --> processor
  prefill_worker --> decode_worker
  decode_worker --> prefill_worker
```

**Launch:**

```bash
cd $DYNAMO_HOME/examples/backends/vllm
bash launch/disagg_multimodal_llama.sh --head-node

# On a separate node, with NATS_SERVER and ETCD_ENDPOINTS pointing to the head node:
cd $DYNAMO_HOME/examples/backends/vllm
bash launch/disagg_multimodal_llama.sh
```

## Video Serving

### Video Aggregated Serving

**Components:**

- workers: [VideoEncodeWorker](https://github.com/ai-dynamo/dynamo/tree/main/examples/multimodal/components/video_encode_worker.py) for decoding video into frames, and [VllmPDWorker](https://github.com/ai-dynamo/dynamo/tree/main/examples/multimodal/components/worker.py) for prefilling and decoding.
- processor: Tokenizes the prompt and passes it to the VideoEncodeWorker.
- frontend: HTTP endpoint to handle incoming requests.

**Workflow:**

The VideoEncodeWorker decodes the video into frames. Unlike the image pipeline, which generates embeddings, this pipeline passes raw frames directly to the VllmPDWorker via NATS and RDMA.

```mermaid
flowchart LR
  HTTP --> processor
  processor --> HTTP
  processor --video_url--> video_encode_worker
  video_encode_worker --> processor
  video_encode_worker --frames--> pd_worker
  pd_worker --> video_encode_worker
```

**Launch:**

```bash
cd $DYNAMO_HOME/examples/multimodal
bash launch/video_agg.sh
```

**Client:**

```bash
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
      "model": "llava-hf/LLaVA-NeXT-Video-7B-hf",
      "messages": [
        {
          "role": "user",
          "content": [
            {
              "type": "text",
              "text": "Describe the video in detail"
            },
            {
              "type": "video_url",
              "video_url": {
                "url": "https://storage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4"
              }
            }
          ]
        }
      ],
      "max_tokens": 300,
      "stream": false
    }' | jq
```

### Video Disaggregated Serving

**Workflow:**

For the LLaVA-NeXT-Video-7B model, frames are only required during the prefill stage. The VideoEncodeWorker is connected directly to the prefill worker, decoding the video into frames and passing them via RDMA.

```mermaid
flowchart LR
  HTTP --> processor
  processor --> HTTP
  processor --video_url--> video_encode_worker
  video_encode_worker --> processor
  video_encode_worker --frames--> prefill_worker
  prefill_worker --> video_encode_worker
  prefill_worker --> decode_worker
  decode_worker --> prefill_worker
```

**Launch:**

```bash
cd $DYNAMO_HOME/examples/multimodal
bash launch/video_disagg.sh
```

## Audio Serving

### Audio Aggregated Serving

**Components:**

- workers: [AudioEncodeWorker](https://github.com/ai-dynamo/dynamo/tree/main/examples/multimodal/components/audio_encode_worker.py) for encoding audio into embeddings, and [VllmPDWorker](https://github.com/ai-dynamo/dynamo/tree/main/examples/multimodal/components/worker.py) for prefilling and decoding.
- processor: Tokenizes the prompt and passes it to the AudioEncodeWorker.
- frontend: HTTP endpoint to handle incoming requests.

**Workflow:**

```mermaid
flowchart LR
  HTTP --> processor
  processor --> HTTP
  processor --audio_url--> audio_encode_worker
  audio_encode_worker --> processor
  audio_encode_worker --embeddings--> pd_worker
  pd_worker --> audio_encode_worker
```

**Launch:**

```bash
pip install 'vllm[audio]' accelerate # multimodal audio models dependency
cd $DYNAMO_HOME/examples/multimodal
bash launch/audio_agg.sh
```

**Client:**

```bash
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
      "model": "Qwen/Qwen2-Audio-7B-Instruct",
      "messages": [
        {
          "role": "user",
          "content": [
            {
              "type": "text",
              "text": "What is recited in the audio?"
            },
            {
              "type": "audio_url",
              "audio_url": {
                "url": "https://raw.githubusercontent.com/yuekaizhang/Triton-ASR-Client/main/datasets/mini_en/wav/1221-135766-0002.wav"
              }
            }
          ]
        }
      ],
      "max_tokens": 6000,
      "temperature": 0.8,
      "stream": false
    }' | jq
```
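
The image, video, and audio requests differ only in the media content part. Its `type` is `image_url`, `video_url`, or `audio_url`, and the URL sits under a key with the same name. The following sketch shows a builder that captures this pattern. The function names are hypothetical.

```python
def media_part(kind: str, url: str) -> dict:
    """Build a chat content part for image, video, or audio input."""
    if kind not in ("image", "video", "audio"):
        raise ValueError(f"unsupported modality: {kind!r}")
    key = f"{kind}_url"
    # The part's "type" and the nested object's key share the same name.
    return {"type": key, key: {"url": url}}


def user_message(prompt: str, kind: str, url: str) -> dict:
    """Combine a text prompt and one media item into a single user message."""
    return {
        "role": "user",
        "content": [{"type": "text", "text": prompt}, media_part(kind, url)],
    }
```

For the audio example above, `user_message("What is recited in the audio?", "audio", url)` produces the same `messages[0]` entry as the curl payload.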

### Audio Disaggregated Serving

**Workflow:**

For the Qwen2-Audio model, audio embeddings are only required during the prefill stage. The AudioEncodeWorker is connected directly to the prefill worker.

```mermaid
flowchart LR
  HTTP --> processor
  processor --> HTTP
  processor --audio_url--> audio_encode_worker
  audio_encode_worker --> processor
  audio_encode_worker --embeddings--> prefill_worker
  prefill_worker --> audio_encode_worker
  prefill_worker --> decode_worker
  decode_worker --> prefill_worker
```

**Launch:**

```bash
pip install 'vllm[audio]' accelerate # multimodal audio models dependency
cd $DYNAMO_HOME/examples/multimodal
bash launch/audio_disagg.sh
```

## Embedding Cache

Dynamo supports an embedding cache in both aggregated and disaggregated settings:

| Setting | Implementation | Launch Script |
|---------|---------------|---------------|
| **Disaggregated encoder** | Dynamo-managed cache in the worker layer on top of vLLM engine | `disagg_multimodal_e_pd.sh` |
| **Aggregated** | Experimental via vLLM git patches | N/A |

### Aggregated Worker

A single vLLM instance caches encoded embeddings in CPU memory, so repeated images skip encoding entirely. This mode is experimental and requires patching vLLM (see the launch steps below).

```mermaid
---
title: Embedding Cache — Aggregated Encoder (e.g. aggregated EP or EPD node)
---
flowchart LR
  req[Multimodal Request] --> gpu{GPU Encoder Cache<br/>hit?}
  gpu -- yes --> skip[Use cached GPU embedding<br/>no encoder, no connector]
  gpu -- no --> cpu{CPU Embedding Cache<br/>hit?}
  cpu -- yes --> load[Load: CPU → GPU<br/>skip encoder]
  cpu -- no --> encode[Run Encoder]
  encode -- save: GPU → CPU --> store[(CPU Embedding Cache<br/>LRU)]
```
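
The CPU tier in the diagram behaves like a byte-bounded LRU cache. The sketch below illustrates that eviction policy only. It is not the `DynamoMultimodalEmbeddingCacheConnector`, and the class name is hypothetical. Embeddings are modeled as `bytes` so that an entry's size is simply `len()`; a real cache would store tensors and account for their byte size.

```python
from collections import OrderedDict
from typing import Optional

GB = 1024**3


class EmbeddingLRUCache:
    """Hypothetical byte-bounded LRU cache for encoder outputs (illustration only)."""

    def __init__(self, capacity_gb: float):
        self.capacity_bytes = int(capacity_gb * GB)
        self.used_bytes = 0
        self._entries: OrderedDict = OrderedDict()  # key -> embedding, oldest first

    def get(self, key: str) -> Optional[bytes]:
        if key not in self._entries:
            return None  # miss: the caller runs the encoder
        self._entries.move_to_end(key)  # mark as most recently used
        return self._entries[key]

    def put(self, key: str, embedding: bytes) -> None:
        size = len(embedding)
        if self.capacity_bytes == 0 or size > self.capacity_bytes:
            return  # cache disabled, or the entry can never fit
        if key in self._entries:
            self.used_bytes -= len(self._entries.pop(key))
        # Evict least recently used entries until the new one fits.
        while self.used_bytes + size > self.capacity_bytes:
            _, evicted = self._entries.popitem(last=False)
            self.used_bytes -= len(evicted)
        self._entries[key] = embedding
        self.used_bytes += size
```

A capacity of 0 turns `put` into a no-op, which mirrors the behavior of setting the cache capacity to 0.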

**Launch:**

```bash
# Run from the site-packages directory of the Python environment where vLLM is
# installed (the path shown assumes a Dynamo container; adjust it for your setup).
cd /opt/dynamo/venv/lib/python3.12/site-packages

# Apply vLLM PR #34182.
curl -sL https://github.com/vllm-project/vllm/pull/34182.diff | patch -p1

# Apply vLLM PR #34783, keeping only the per-file diffs under vllm/
# (the installed package has no tests/ or docs/ to patch).
curl -sL https://github.com/vllm-project/vllm/pull/34783.diff | python3 -c "
import sys
chunks = sys.stdin.read().split('diff --git ')
filtered = [c for c in chunks if c.startswith('a/vllm/')]
print(''.join('diff --git ' + c for c in filtered))
" | patch -p1

# $model is the model to serve (for example, a Hugging Face model ID).
vllm serve "$model" \
    --ec-transfer-config "{
        \"ec_role\": \"ec_both\",
        \"ec_connector\": \"DynamoMultimodalEmbeddingCacheConnector\",
        \"ec_connector_module_path\": \"dynamo.vllm.multimodal_utils.multimodal_embedding_cache_connector\",
        \"ec_connector_extra_config\": {\"multimodal_embedding_cache_capacity_gb\": 10}
    }"
```

This runs `vllm serve` with `ec_role=ec_both` and the `DynamoMultimodalEmbeddingCacheConnector`. The `multimodal_embedding_cache_capacity_gb` parameter sets the size of the CPU-side LRU cache in GB (0 disables the cache).
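
Hand-escaping JSON inside a shell string is error-prone. One alternative is to generate the `--ec-transfer-config` value in Python with `json.dumps`, as sketched below. The function name is hypothetical, and the keys are copied from the command above.

```python
import json


def ec_transfer_config(capacity_gb: float) -> str:
    """Build the --ec-transfer-config JSON string for the embedding cache connector."""
    config = {
        "ec_role": "ec_both",
        "ec_connector": "DynamoMultimodalEmbeddingCacheConnector",
        "ec_connector_module_path": "dynamo.vllm.multimodal_utils.multimodal_embedding_cache_connector",
        # CPU-side LRU capacity in GB; 0 disables the cache.
        "ec_connector_extra_config": {"multimodal_embedding_cache_capacity_gb": capacity_gb},
    }
    return json.dumps(config)
```

Passing the value as a list argument, e.g. `subprocess.run(["vllm", "serve", model, "--ec-transfer-config", ec_transfer_config(10)])`, avoids shell quoting entirely.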

### Disaggregated Encoder (Embedding Cache in Prefill Worker)

In the disaggregated setting, the Prefill Worker (P) owns a CPU-side LRU embedding cache (`EmbeddingCacheManager`). On each request, P checks the cache first. On a hit, the Encode Worker is skipped entirely. On a miss, P routes the request to the Encode Worker (E), receives the embeddings via NIXL, saves them to the cache, and then passes the embeddings along with the request to the vLLM instance for prefill.

```mermaid
---
title: Embedding Cache — Disaggregated Encoder
---
flowchart LR
    req[Request] --> cpu_check{"CPU cache hit?<br/>(EmbeddingCacheManager)"}

    subgraph P ["Prefill Worker (P)"]
        cpu_check -. hit .-> use[Use cached embedding]
        use --> vllm[vLLM Instance]
    end

    cpu_check -- miss --> E["Encode Worker (E)"]
    E -- "embeddings via NIXL" --> save["Save to cache"]
    save --> vllm
```
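
The prefill worker's hit/miss logic can be modeled in a few lines. This is a hypothetical sketch, not Dynamo's handler. An unbounded dict stands in for `EmbeddingCacheManager`, a callback stands in for the round trip to the Encode Worker and the NIXL transfer, and the cache key (`media_key`, e.g. an image URL or content hash) is an assumption.

```python
from typing import Callable, Dict, List

Embedding = List[float]  # stand-in for an embeddings tensor


class PrefillEmbeddingLookup:
    """Hypothetical model of the prefill worker's cache check (not Dynamo's code)."""

    def __init__(self, remote_encode: Callable[[str], Embedding]):
        # Unbounded dict standing in for the CPU-side LRU (EmbeddingCacheManager).
        self.cache: Dict[str, Embedding] = {}
        # Stands in for routing to the Encode Worker and receiving data via NIXL.
        self.remote_encode = remote_encode

    def get(self, media_key: str) -> Embedding:
        cached = self.cache.get(media_key)
        if cached is not None:
            return cached  # hit: the Encode Worker is skipped entirely
        embedding = self.remote_encode(media_key)  # miss: ask E to encode
        self.cache[media_key] = embedding  # save to the cache before prefill
        return embedding
```

On a miss, the returned embeddings would then be passed to the vLLM instance together with the request. On a hit, the Encode Worker never sees the request.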

**Launch:**

```bash
cd $DYNAMO_HOME/examples/backends/vllm
bash launch/disagg_multimodal_e_pd.sh --multimodal-embedding-cache-capacity-gb 10
```

**Client:** Same as [E/PD Serving](#epd-serving-encode-separate)

## NIXL Usage

| Use Case | Script | NIXL Used? | Data Transfer |
|----------|--------|------------|---------------|
| EPD (Simple Aggregated) | `agg_multimodal.sh` | No | All in one worker |
| E/PD (Encode Separate) | `agg_multimodal_epd.sh` | Yes | Encoder → PD (embeddings) |
| E/P/D (Full Disaggregation) | `disagg_multimodal_epd.sh` | Yes | Encoder → Prefill (embeddings), Prefill → Decode (KV cache) |
| EP/D (Llama 4) | `disagg_multimodal_llama.sh` | Yes | Prefill → Decode (KV cache) |
| EC Both (Local Node) | `vllm_serve_embedding_cache.sh` | No | ECConnector via CPU Embedding Cache |

## ModelInput Types and Registration

Dynamo's Rust SDK supports two input types that determine how the HTTP frontend preprocesses requests:

| ModelInput Type | Preprocessing | Use Case |
|-----------------|---------------|----------|
| `ModelInput.Text` | None (raw text passed through) | Components that tokenize themselves |
| `ModelInput.Tokens` | Tokenized by the Rust SDK (bypassed in multimodal flows) | Components expecting pre-tokenized input |

**Registration Pattern:**

```python
# Processor - Entry point from HTTP frontend
await register_model(
    ModelInput.Text,        # Frontend sends raw text
    ModelType.Chat,
    generate_endpoint,
    model_name,
    ...
)

# Workers - Internal components
await register_model(
    ModelInput.Tokens,      # Expect pre-tokenized input
    ModelType.Chat,         # or ModelType.Prefill for prefill workers
    generate_endpoint,
    model_name,
    ...
)
```

## LoRA Adapters on Multimodal Workers

Multimodal workers support dynamic loading and unloading of LoRA adapters at runtime via the management API. This enables serving fine-tuned multimodal models alongside the base model.

### Loading a LoRA Adapter

Load an adapter on a running multimodal worker via the `load_lora` endpoint:

```bash
# For the components/ workers (URI-based; requires DYN_LORA_ENABLED=true)
curl -X POST http://<worker-host>:<port>/load_lora \
  -H "Content-Type: application/json" \
  -d '{
    "lora_name": "my-vlm-adapter",
    "source": {"uri": "s3://my-bucket/adapters/my-vlm-adapter"}
  }'

# For the examples/ workers (path-based)
curl -X POST http://<worker-host>:<port>/load_lora \
  -H "Content-Type: application/json" \
  -d '{
    "lora_name": "my-vlm-adapter",
    "lora_path": "/path/to/adapter"
  }'
```

### Sending Requests with a LoRA

Set the `model` field in the request to the LoRA adapter name:

```bash
curl -X POST http://<frontend-host>:<port>/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "my-vlm-adapter",
    "messages": [
      {"role": "user", "content": [
        {"type": "text", "text": "Describe this image"},
        {"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}}
      ]}
    ]
  }'
```

Requests without a LoRA name (or with the base model name) will use the base model.

### Unloading a LoRA Adapter

```bash
curl -X POST http://<worker-host>:<port>/unload_lora \
  -H "Content-Type: application/json" \
  -d '{"lora_name": "my-vlm-adapter"}'
```

### Listing Loaded Adapters

```bash
curl -X POST http://<worker-host>:<port>/list_loras
```

### Disaggregated Mode

In disaggregated (prefill/decode) deployments, the **same LoRA adapter must be loaded on both the prefill and decode workers**. The LoRA identity (`model` field) is automatically propagated from the prefill worker to the decode worker in the forwarded request.

```bash
# Load on prefill worker
curl -X POST http://<prefill-worker>/load_lora \
  -H "Content-Type: application/json" \
  -d '{"lora_name": "my-adapter", "source": {"uri": "s3://bucket/adapter"}}'

# Load on decode worker (same adapter)
curl -X POST http://<decode-worker>/load_lora \
  -H "Content-Type: application/json" \
  -d '{"lora_name": "my-adapter", "source": {"uri": "s3://bucket/adapter"}}'
```

If a LoRA is loaded on the prefill worker but not on the decode worker, the decode worker will fall back to the base model for that request.
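
Because the adapter has to be present on both sides, it can help to script the two `load_lora` calls together. The following is a minimal standard-library sketch. The helper name is hypothetical, and worker URLs must include the scheme, host, and port.

```python
import json
import urllib.request
from typing import List


def load_lora_requests(worker_urls: List[str], lora_name: str,
                       uri: str) -> List[urllib.request.Request]:
    """Build one load_lora request per worker so prefill and decode hold the same adapter."""
    body = json.dumps({"lora_name": lora_name, "source": {"uri": uri}}).encode("utf-8")
    return [
        urllib.request.Request(
            f"{base.rstrip('/')}/load_lora",  # tolerate a trailing slash on the base URL
            data=body,
            headers={"Content-Type": "application/json"},
            method="POST",
        )
        for base in worker_urls
    ]
```

Send each request with `urllib.request.urlopen(req)` and check that both succeed before routing traffic to the adapter. Otherwise, the decode worker falls back to the base model as described above.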

## Profiling

Dynamo's multimodal workers include NVTX markers for `nsys` profiling. They are disabled by default (zero overhead) and enabled by setting `DYN_NVTX=1`.

```bash
cd $DYNAMO_HOME/examples/backends/vllm
DYN_NVTX=1 nsys profile --trace=cuda,nvtx -o profile.nsys-rep \
    bash launch/agg_multimodal.sh ...
```

| ENV Variable | Default | Description |
|---|---|---|
| `DYN_NVTX` | `0` | Set to `1` to enable NVTX range/mark annotations in encode, prefill, and decode workers for `nsys` profiling |

Key NVTX ranges emitted:

| Range | Worker | Description |
|-------|--------|-------------|
| `mm:encode_worker_generate` | Encode | Full encode request lifetime |
| `mm:enc:cache_check` | Encode | Embedding cache lookup |
| `mm:enc:image_load` | Encode | Image download/load |
| `mm:enc:image_preprocess` | Encode | Image processor (CPU) |
| `mm:enc:vision_encode` | Encode | ViT + projector GPU forward |
| `mm:enc:embedding_transfer` | Encode | RDMA embedding staging |
| `mm:pd_worker_generate` | PD | Full PD request lifetime |
| `mm:pd:ttft` | PD | Worker-side TTFT: from request arrival at the PD worker to first output token (excludes client→frontend→worker network transit) |
| `mm:pd:load_multimodal` | PD | Fetch embeddings from encode worker |
| `mm:pd:disagg_prefill` | PD (disagg) | Prefill-only engine call |
| `mm:pd:disagg_remote_decode` | PD (disagg) | Remote decode round-trip |
| `mm:decode_worker_generate` | Decode | Full decode request lifetime |
| `mm:decode:first_token` | Decode | Time to first output token |

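The "disabled by default" behavior can be illustrated with a small wrapper that returns a no-op context unless `DYN_NVTX=1`. This is a hypothetical sketch of the idea, not Dynamo's instrumentation code. It assumes NVIDIA's `nvtx` Python package, which is imported only when profiling is enabled.

```python
import contextlib
import os

# Read the switch once at import time so the disabled path costs only a branch.
NVTX_ENABLED = os.environ.get("DYN_NVTX", "0") == "1"


def nvtx_range(name: str, enabled: bool = NVTX_ENABLED):
    """Return an NVTX range context manager when enabled, otherwise a no-op."""
    if not enabled:
        return contextlib.nullcontext()
    # Imported lazily so the `nvtx` package is only required when profiling.
    import nvtx

    return nvtx.annotate(name)
```

When profiling is on, wrapping a stage as `with nvtx_range("mm:enc:vision_encode"): ...` shows up as a named range in the `nsys` timeline.
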
## Known Limitations

- **Disaggregated flows require the Python Processor** - All multimodal disaggregation requires the Python Processor component (`ModelInput.Text`).

## Supported Models

The following models have been tested with Dynamo's vLLM multimodal backend:

- **Qwen2.5-VL** - `Qwen/Qwen2.5-VL-7B-Instruct`
- **Qwen3-VL** - `Qwen/Qwen3-VL-30B-A3B-Instruct-FP8`
- **LLaVA 1.5** - `llava-hf/llava-1.5-7b-hf`
- **Llama 4 Maverick** - `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`
- **LLaVA Next Video** - `llava-hf/LLaVA-NeXT-Video-7B-hf`
- **Qwen2-Audio** - `Qwen/Qwen2-Audio-7B-Instruct`

For a complete list of multimodal models supported by vLLM, see [vLLM Supported Multimodal Models](https://docs.vllm.ai/en/latest/models/supported_models/#list-of-multimodal-language-models). Models listed there should work with the Simple Aggregated (EPD) pattern but may not have been explicitly tested.

## Key Files

| File | Description |
|------|-------------|
| `components/src/dynamo/vllm/main.py` | Worker initialization and setup |
| `components/src/dynamo/vllm/args.py` | Command-line argument parsing |
| `components/src/dynamo/vllm/multimodal_handlers/processor_handler.py` | Processor implementation |
| `components/src/dynamo/vllm/multimodal_handlers/encode_worker_handler.py` | Encode worker implementations (custom and vLLM-native) |
| `components/src/dynamo/vllm/multimodal_handlers/worker_handler.py` | PD/Prefill/Decode worker implementation |