(supports-multimodal)=

# Multi-Modal Support

This document walks you through the steps to extend a basic model so that it accepts [multi-modal inputs](#multimodal-inputs).

## 1. Update the base vLLM model

It is assumed that you have already implemented the model in vLLM according to [these steps](#new-model-basic).
Further update the model as follows:

- Reserve a keyword parameter in {meth}`~torch.nn.Module.forward` for each input tensor that corresponds to a multi-modal input, as shown in the following example:

  ```diff
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
  +     pixel_values: torch.Tensor,
    ) -> SamplerOutput:
  ```
  
  More conveniently, you can accept `**kwargs` in the {meth}`~torch.nn.Module.forward` method and retrieve the keyword arguments for the multi-modal inputs from it.

- Implement {meth}`~vllm.model_executor.models.interfaces.SupportsMultiModal.get_multimodal_embeddings` that returns the embeddings from running the multimodal inputs through the model's multimodal encoder (e.g., a vision encoder followed by a projector). Below we provide a boilerplate of a typical implementation pattern, but feel free to adjust it to your own needs.

    ```python
    class YourModelForImage2Seq(nn.Module):
        ...

        def _process_image_input(self, image_input: YourModelImageInputs) -> torch.Tensor:

            assert self.vision_encoder is not None
            image_features = self.vision_encoder(image_input)
            return self.multi_modal_projector(image_features)

        def get_multimodal_embeddings(
                self, **kwargs: object) -> Optional[MultiModalEmbeddings]:

            # Validate the multimodal input keyword arguments
            image_input = self._parse_and_validate_image_input(**kwargs)
            if image_input is None:
                return None

            # Run multimodal inputs through encoder and projector
            vision_embeddings = self._process_image_input(image_input)
            return vision_embeddings
    ```

    :::{important}
    The returned `multimodal_embeddings` must be either a **3D {class}`torch.Tensor`** of shape `(num_items, feature_size, hidden_size)`, or a **list / tuple of 2D {class}`torch.Tensor`s** of shape `(feature_size, hidden_size)`, so that `multimodal_embeddings[i]` retrieves the embeddings generated from the `i`-th multimodal data item (e.g., image) of the request.
    :::

- Implement {meth}`~vllm.model_executor.models.interfaces.SupportsMultiModal.get_input_embeddings` to merge `multimodal_embeddings` with text embeddings from the `input_ids`. If input processing for the model is implemented correctly (see sections below), then you can leverage the utility function we provide to easily merge the embeddings.

    ```python
    from .utils import merge_multimodal_embeddings

    class YourModelForImage2Seq(nn.Module):
        ...

        def get_input_embeddings(
            self,
            input_ids: torch.Tensor,
            multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
        ) -> torch.Tensor:

            # `get_input_embeddings` should already be implemented for the language 
            # model as one of the requirements of basic vLLM model implementation.
            inputs_embeds = self.language_model.get_input_embeddings(input_ids)

            if multimodal_embeddings is not None:
                inputs_embeds = merge_multimodal_embeddings(
                    input_ids=input_ids, 
                    inputs_embeds=inputs_embeds, 
                    multimodal_embeddings=multimodal_embeddings,
                    placeholder_token_id=self.config.image_token_index)

            return inputs_embeds
    ```

- Implement {meth}`~vllm.model_executor.models.interfaces.SupportsMultiModal.get_language_model` getter to provide stable access to the underlying language model.

    ```python
    class YourModelForImage2Seq(nn.Module):
        ...

        def get_language_model(self) -> torch.nn.Module:
            # Change `language_model` according to your implementation.
            return self.language_model
    ```

- Once the above steps are done, update the model class with the {class}`~vllm.model_executor.models.interfaces.SupportsMultiModal` interface.

  ```diff
  + from vllm.model_executor.models.interfaces import SupportsMultiModal

  - class YourModelForImage2Seq(nn.Module):
  + class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
  ```

  :::{note}
  The model class does not have to be named {code}`*ForCausalLM`.
  Check out [the HuggingFace Transformers documentation](https://huggingface.co/docs/transformers/model_doc/auto#multimodal) for some examples.
  :::
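
The merge step can be pictured with plain Python lists. The sketch below is a hypothetical stand-in for `merge_multimodal_embeddings`, not its actual implementation: it assumes one placeholder token per feature row, overwrites each placeholder position with the next row of the flattened per-item embeddings (which is why `multimodal_embeddings[i]` must index the `i`-th item), and fails loudly if the counts disagree. The name `merge_placeholder_embeddings` and the token ID `99` are made up for illustration.

```python
from collections.abc import Sequence

Embedding = list[float]


def merge_placeholder_embeddings(
    input_ids: Sequence[int],
    inputs_embeds: Sequence[Embedding],
    multimodal_embeddings: Sequence[Sequence[Embedding]],
    placeholder_token_id: int,
) -> list[Embedding]:
    """Hypothetical sketch: scatter per-item features into placeholder slots.

    `multimodal_embeddings[i]` holds the `(feature_size, hidden_size)` rows
    of the i-th multi-modal item, in the order the items appear in the prompt.
    """
    # Flatten the items so that feature rows line up with placeholder positions.
    flat_features = [row for item in multimodal_embeddings for row in item]
    positions = [i for i, tok in enumerate(input_ids) if tok == placeholder_token_id]

    if len(positions) != len(flat_features):
        raise ValueError(
            f"{len(positions)} placeholder tokens but {len(flat_features)} feature rows"
        )

    merged = [list(row) for row in inputs_embeds]  # copy the text embeddings
    for pos, feature in zip(positions, flat_features):
        merged[pos] = list(feature)
    return merged


IMG = 99  # made-up placeholder token ID
text_embeds = [[0.0, 0.0]] * 4
image_embeds = [[[1.0, 1.0], [2.0, 2.0]]]  # one image with two feature rows
print(merge_placeholder_embeddings([1, IMG, IMG, 2], text_embeds, image_embeds, IMG))
# [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [0.0, 0.0]]
```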

## 2. Specify processing information

Next, create a subclass of {class}`~vllm.multimodal.processing.BaseProcessingInfo`
to provide basic information related to HF processing.

### Maximum number of input items

You need to override the abstract method {meth}`~vllm.multimodal.processing.BaseProcessingInfo.get_supported_mm_limits`
to return the maximum number of input items for each modality supported by the model.

For example, if the model supports any number of images but only one video per prompt:

```python
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
    return {"image": None, "video": 1}
```
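
As a rough illustration of how such a mapping reads, the hypothetical helper below (not a vLLM API) checks per-modality item counts against it: a modality missing from the mapping is unsupported, `None` means unlimited, and an integer is an upper bound. It only sketches the semantics, not vLLM's actual validation code.

```python
from collections.abc import Mapping
from typing import Optional


def check_mm_counts(
    supported_limits: Mapping[str, Optional[int]],
    mm_counts: Mapping[str, int],
) -> None:
    """Hypothetical check: raise if a prompt exceeds the supported item limits."""
    for modality, count in mm_counts.items():
        if modality not in supported_limits:
            raise ValueError(f"Modality {modality!r} is not supported")
        limit = supported_limits[modality]
        # `None` means the model accepts any number of items of this modality.
        if limit is not None and count > limit:
            raise ValueError(
                f"At most {limit} {modality} item(s) allowed, got {count}"
            )


limits = {"image": None, "video": 1}
check_mm_counts(limits, {"image": 8, "video": 1})  # passes: images are unlimited
```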

## 3. Specify dummy inputs

Then, inherit {class}`~vllm.multimodal.profiling.BaseDummyInputsBuilder` to construct dummy inputs for
HF processing as well as memory profiling.

### For memory profiling

Override the abstract method {meth}`~vllm.multimodal.profiling.BaseDummyInputsBuilder.get_dummy_processor_inputs`
to construct dummy inputs for memory profiling. This dummy input should result in the worst-case memory usage of
the model so that vLLM can reserve the correct amount of memory for it.

Assuming that memory usage increases with the number of tokens, the dummy input can be constructed to maximize the number of output embeddings, which equals the number of placeholder feature tokens.
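
When the token count depends on the image size, one way to apply this idea is to score candidate sizes with the model's token-count formula and keep the largest. The helper names below are hypothetical, and `patch_tokens` assumes a simple patch-based encoder with 30x30 patches and no resizing; a real model must apply its own resizing rules first, as the Fuyu example shows.

```python
import math
from collections.abc import Callable, Iterable


def pick_size_with_most_tokens(
    candidates: Iterable[tuple[int, int]],
    count_tokens: Callable[[int, int], int],
) -> tuple[int, int]:
    """Hypothetical helper: return the (width, height) with the most tokens."""
    return max(candidates, key=lambda wh: count_tokens(wh[0], wh[1]))


def patch_tokens(width: int, height: int, patch: int = 30) -> int:
    """Illustrative count: one placeholder token per (possibly partial) patch."""
    return math.ceil(width / patch) * math.ceil(height / patch)


best = pick_size_with_most_tokens([(640, 480), (1920, 1080), (1024, 1024)], patch_tokens)
print(best, patch_tokens(*best))  # (1920, 1080) 2304
```

Models whose token count does not depend on the image size (such as LLaVA) can skip this search entirely.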

::::{tab-set}
:::{tab-item} Basic example: LLaVA
:sync: llava

Looking at the code of HF's `LlavaForConditionalGeneration`:

```python
# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/llava/modeling_llava.py#L530-L544
n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
n_image_features = image_features.shape[0] * image_features.shape[1]

if n_image_tokens != n_image_features:
    raise ValueError(
        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
    )
special_image_mask = (
    (input_ids == self.config.image_token_index)
    .unsqueeze(-1)
    .expand_as(inputs_embeds)
    .to(inputs_embeds.device)
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
```

The number of placeholder feature tokens per image is `image_features.shape[1]`.
`image_features` is calculated inside the `get_image_features` method:

```python
# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/llava/modeling_llava.py#L290-L300
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)

selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
if vision_feature_select_strategy == "default":
    selected_image_feature = selected_image_feature[:, 1:]
elif vision_feature_select_strategy == "full":
    selected_image_feature = selected_image_feature
else:
    raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
image_features = self.multi_modal_projector(selected_image_feature)
return image_features
```

We can infer that `image_features.shape[1]` is based on `image_outputs.hidden_states[vision_feature_layer].shape[1]` from the vision tower
(`CLIPVisionModel` for the [`llava-hf/llava-1.5-7b-hf`](https://huggingface.co/llava-hf/llava-1.5-7b-hf) model),
minus one position when the `"default"` strategy drops the first feature.
Moreover, we only need the sequence length (the second dimension of the tensor) to get `image_features.shape[1]`.
The sequence length is determined by the initial hidden states in `CLIPVisionTransformer`, since the attention
mechanism doesn't change the sequence length of the output hidden states.

```python
# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/clip/modeling_clip.py#L1094-L1102
hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
hidden_states = self.pre_layrnorm(hidden_states)

encoder_outputs = self.encoder(
    inputs_embeds=hidden_states,
    output_attentions=output_attentions,
    output_hidden_states=output_hidden_states,
    return_dict=return_dict,
)
```

To find the sequence length, we turn to the code of `CLIPVisionEmbeddings`:

```python
# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/clip/modeling_clip.py#L247-L257
target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

class_embeds = self.class_embedding.expand(batch_size, 1, -1)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
if interpolate_pos_encoding:
    embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
else:
    embeddings = embeddings + self.position_embedding(self.position_ids)
return embeddings
```

We can infer that `embeddings.shape[1] == self.num_positions`, where

```python
# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/clip/modeling_clip.py#L195-L196
self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches + 1
```

Overall, the number of placeholder feature tokens for an image can be calculated as:

```python
def get_num_image_tokens(
    self,
    *,
    image_width: int,
    image_height: int,
) -> int:
    hf_config = self.get_hf_config()
    hf_processor = self.get_hf_processor()

    image_size = hf_config.vision_config.image_size
    patch_size = hf_config.vision_config.patch_size

    num_image_tokens = (image_size // patch_size) ** 2 + 1
    if hf_processor.vision_feature_select_strategy == "default":
        num_image_tokens -= 1

    return num_image_tokens
```
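
Plugging in numbers makes the formula concrete. The sketch below restates `get_num_image_tokens` as a standalone function (the name `llava_num_image_tokens` is hypothetical). The values `image_size=336` and `patch_size=14` are assumed from the CLIP ViT-L/14 (336px) vision tower used by `llava-hf/llava-1.5-7b-hf`; check your model's `vision_config` before relying on them.

```python
def llava_num_image_tokens(
    image_size: int,
    patch_size: int,
    vision_feature_select_strategy: str = "default",
) -> int:
    """Standalone restatement of the LLaVA token count derived above."""
    num_patches = (image_size // patch_size) ** 2
    num_positions = num_patches + 1  # +1 for the prepended class embedding
    if vision_feature_select_strategy == "default":
        return num_positions - 1  # "default" drops the first position
    if vision_feature_select_strategy == "full":
        return num_positions
    raise ValueError(
        f"Unexpected select feature strategy: {vision_feature_select_strategy}"
    )


print(llava_num_image_tokens(336, 14))          # 576
print(llava_num_image_tokens(336, 14, "full"))  # 577
```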

Notice that the number of image tokens doesn't depend on the image width and height.
We can simply use a dummy `image_size`:

```python
# In your `BaseProcessingInfo` subclass:
def get_image_size_with_most_features(self) -> ImageSize:
    hf_config = self.get_hf_config()
    width = height = hf_config.vision_config.image_size
    return ImageSize(width=width, height=height)

# In your `BaseDummyInputsBuilder` subclass:
def get_dummy_processor_inputs(
    self,
    seq_len: int,
    mm_counts: Mapping[str, int],
) -> ProcessorInputs:
    num_images = mm_counts.get("image", 0)

    processor = self.info.get_hf_processor()
    image_token = processor.image_token

    target_width, target_height = self.info.get_image_size_with_most_features()

    mm_data = {
        "image":
        self._get_dummy_images(width=target_width,
                               height=target_height,
                               num_images=num_images)
    }

    return ProcessorInputs(
        prompt_text=image_token * num_images,
        mm_data=mm_data,
    )
```

:::

:::{tab-item} No input placeholders: Fuyu
:sync: fuyu

Looking at the code of HF's `FuyuForCausalLM`:

```python
# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/modeling_fuyu.py#L311-L322
if image_patches is not None and past_key_values is None:
    patch_embeddings = [
        self.vision_embed_tokens(patch.to(self.vision_embed_tokens.weight.dtype))
        .squeeze(0)
        .to(inputs_embeds.device)
        for patch in image_patches
    ]
    inputs_embeds = self.gather_continuous_embeddings(
        word_embeddings=inputs_embeds,
        continuous_embeddings=patch_embeddings,
        image_patch_input_indices=image_patches_indices,
    )
```

The number of placeholder feature tokens for the `i`-th item in the batch is `patch_embeddings[i].shape[0]`,
which is the same as `image_patches[i].shape[0]`, i.e., `num_total_patches`.

Unlike LLaVA, Fuyu does not define the number of patches inside the modeling file. Where can we get more information?
Considering that the model input comes from the output of `FuyuProcessor`, let's **look at the preprocessing files**.

The image outputs are obtained by calling `FuyuImageProcessor.preprocess` and then
`FuyuImageProcessor.preprocess_with_tokenizer_info` inside `FuyuProcessor`.

In `FuyuImageProcessor.preprocess`, the images are resized and padded to the target `FuyuImageProcessor.size`,
returning the dimensions after resizing (but before padding) as metadata.

```python
# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/processing_fuyu.py#L541-L544
image_encoding = self.image_processor.preprocess(images, **output_kwargs["images_kwargs"])
batch_images = image_encoding["images"]
image_unpadded_heights = image_encoding["image_unpadded_heights"]
image_unpadded_widths = image_encoding["image_unpadded_widths"]

# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/image_processing_fuyu.py#L480-L
if do_resize:
    batch_images = [
        [self.resize(image, size=size, input_data_format=input_data_format) for image in images]
        for images in batch_images
    ]

image_sizes = [get_image_size(images[0], channel_dim=input_data_format) for images in batch_images]
image_unpadded_heights = [[image_size[0]] for image_size in image_sizes]
image_unpadded_widths = [[image_size[1]] for image_size in image_sizes]

if do_pad:
    batch_images = [
        [
            self.pad_image(
                image,
                size=size,
                mode=padding_mode,
                constant_values=padding_value,
                input_data_format=input_data_format,
            )
            for image in images
        ]
        for images in batch_images
    ]
```

In `FuyuImageProcessor.preprocess_with_tokenizer_info`, the images are split into patches based on this metadata:

```python
# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/processing_fuyu.py#L417-L425
model_image_input = self.image_processor.preprocess_with_tokenizer_info(
    image_input=tensor_batch_images,
    image_present=image_present,
    image_unpadded_h=image_unpadded_heights,
    image_unpadded_w=image_unpadded_widths,
    image_placeholder_id=image_placeholder_id,
    image_newline_id=image_newline_id,
    variable_sized=True,
)

# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/image_processing_fuyu.py#L638-L658
image_height, image_width = image.shape[1], image.shape[2]
if variable_sized:  # variable_sized=True
    new_h = min(
        image_height,
        math.ceil(image_unpadded_h[batch_index, subseq_index] / patch_height) * patch_height,
    )
    new_w = min(
        image_width,
        math.ceil(image_unpadded_w[batch_index, subseq_index] / patch_width) * patch_width,
    )
    image = image[:, :new_h, :new_w]
    image_height, image_width = new_h, new_w

num_patches = self.get_num_patches(image_height=image_height, image_width=image_width)
tensor_of_image_ids = torch.full(
    [num_patches], image_placeholder_id, dtype=torch.int32, device=image_input.device
)
patches = self.patchify_image(image=image.unsqueeze(0)).squeeze(0)
assert num_patches == patches.shape[0]
```

The number of patches is in turn defined by `FuyuImageProcessor.get_num_patches`:

```python
# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/image_processing_fuyu.py#L552-L562
patch_size = patch_size if patch_size is not None else self.patch_size
patch_height, patch_width = self.patch_size["height"], self.patch_size["width"]

if image_height % patch_height != 0:
    raise ValueError(f"{image_height=} must be divisible by {patch_height}")
if image_width % patch_width != 0:
    raise ValueError(f"{image_width=} must be divisible by {patch_width}")

num_patches_per_dim_h = image_height // patch_height
num_patches_per_dim_w = image_width // patch_width
num_patches = num_patches_per_dim_h * num_patches_per_dim_w
```
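
Putting these pieces together, the number of `|SPEAKER|` tokens for one image can be sketched as follows. This is a simplified standalone restatement, not the HF or vLLM code, and `fuyu_num_patches` is a hypothetical name. It assumes the image has already been resized to fit the target size, rounds each side up to a whole number of patches (capped at the padded canvas) as in the `variable_sized=True` branch, and then counts patches. The defaults (a 1080x1920 canvas, 30x30 patches) are assumed `FuyuImageProcessor` settings; check them against your checkpoint.

```python
import math


def fuyu_num_patches(
    unpadded_height: int,
    unpadded_width: int,
    padded_height: int = 1080,
    padded_width: int = 1920,
    patch_height: int = 30,
    patch_width: int = 30,
) -> int:
    """Sketch of the patch count for one resized (but unpadded) image."""
    # Mirror the `variable_sized=True` crop: round up to whole patches,
    # but never beyond the padded canvas.
    new_h = min(padded_height, math.ceil(unpadded_height / patch_height) * patch_height)
    new_w = min(padded_width, math.ceil(unpadded_width / patch_width) * patch_width)
    # Same per-dimension division as `get_num_patches`
    # (which additionally requires exact divisibility).
    return (new_h // patch_height) * (new_w // patch_width)


print(fuyu_num_patches(1080, 1920))  # 2304: the full canvas
print(fuyu_num_patches(500, 700))    # 408 = 17 rows * 24 columns
```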

These image patches correspond to placeholder tokens (`|SPEAKER|`). So, we just need to maximize the number of image patches. Since input images are first resized
to fit within `image_processor.size`, we can maximize the number of image patches by inputting an image with size equal to `image_processor.size`.

```python
def get_image_size_with_most_features(self) -> ImageSize:
    image_processor = self.get_image_processor()
    return ImageSize(width=image_processor.size["width"],
                     height=image_processor.size["height"])
```

Fuyu does not expect image placeholders in the inputs to the HF processor, so
the dummy prompt text is empty regardless of the number of images.
Otherwise, the logic of this method is very similar to LLaVA:

```python
def get_dummy_processor_inputs(
    self,
    seq_len: int,
    mm_counts: Mapping[str, int],
) -> ProcessorInputs:
    target_width, target_height = \
        self.info.get_image_size_with_most_features()
    num_images = mm_counts.get("image", 0)

    mm_data = {
        "image":
        self._get_dummy_images(width=target_width,
                               height=target_height,
                               num_images=num_images)
    }

    return ProcessorInputs(
        prompt_text="",
        mm_data=mm_data,
    )
```

:::

::::

## 4. Specify processing details

Afterwards, create a subclass of {class}`~vllm.multimodal.processing.BaseMultiModalProcessor`
to fill in the missing details about HF processing.

:::{seealso}
[Multi-Modal Data Processing](#mm-processing)
:::

### Multi-modal fields

Override {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_mm_fields_config` to
return a schema of the tensors outputted by the HF processor that are related to the input multi-modal items.

:::::{tab-set}
::::{tab-item} Basic example: LLaVA
:sync: llava

The output of `CLIPImageProcessor` is a simple tensor with shape
`(num_images, num_channels, image_height, image_width)`:

```python
# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/clip/image_processing_clip.py#L339-L345
images = [
    to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
    for image in all_images
]

data = {"pixel_values": images}
return BatchFeature(data=data, tensor_type=return_tensors)
```

So, we override {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_mm_fields_config` as follows:

```python
def _get_mm_fields_config(
    self,
    hf_inputs: BatchFeature,
    hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
    return dict(
        pixel_values=MultiModalFieldConfig.batched("image"),
    )
```
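
Conceptually, `MultiModalFieldConfig.batched("image")` says that the first dimension of `pixel_values` indexes the image items, so item `i` owns `pixel_values[i]`. The pure-Python sketch below mimics that slicing on nested lists; `split_batched_field` is a hypothetical name, and vLLM itself operates on tensors. This is also why the Fuyu tab removes an extra leading batch dimension before using `batched`.

```python
from collections.abc import Sequence
from typing import Any


def split_batched_field(field: Sequence[Any], num_items: int) -> list[Any]:
    """Hypothetical sketch: one slice along the first dimension per item."""
    if len(field) != num_items:
        raise ValueError(f"Expected {num_items} items along dim 0, got {len(field)}")
    return [field[i] for i in range(num_items)]


# Two "images", each a (channels=1, height=2, width=2) nested list.
pixel_values = [
    [[[0, 1], [2, 3]]],
    [[[4, 5], [6, 7]]],
]
per_image = split_batched_field(pixel_values, num_items=2)
print(per_image[1])  # [[[4, 5], [6, 7]]]
```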

:::{note}
Our [actual code](gh-file:vllm/model_executor/models/llava.py) additionally supports
pre-computed image embeddings, which can be passed to the model via the `image_embeds` argument.
:::

::::

::::{tab-item} With postprocessing: Fuyu
:sync: fuyu

The `image_patches` output of `FuyuImageProcessor.preprocess_with_tokenizer_info` concatenates
the patches from each image belonging to an item in the batch:

```python
# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/image_processing_fuyu.py#L673-L679
        image_input_ids.append(tensor_of_image_ids)
        image_patches.append(patches)
    else:
        image_input_ids.append(torch.tensor([], dtype=torch.int32, device=image_input.device))

batch_image_input_ids.append(image_input_ids)
batch_image_patches.append(image_patches)
```

The shape of `image_patches` outputted by `FuyuImageProcessor` is therefore
`(1, num_images, num_patches, patch_width * patch_height * num_channels)`.

In order to support the use of {func}`MultiModalFieldConfig.batched` like in LLaVA,
we remove the extra batch dimension by overriding {meth}`BaseMultiModalProcessor._call_hf_processor`:

```python
def _call_hf_processor(
    self,
    prompt: str,
    mm_data: Mapping[str, object],
    mm_kwargs: Mapping[str, object],
) -> BatchFeature:
    processed_outputs = super()._call_hf_processor(
        prompt=prompt,
        mm_data=mm_data,
        mm_kwargs=mm_kwargs,
    )

    image_patches = processed_outputs.get("image_patches")
    if image_patches is not None:
        images = mm_data["images"]
        assert isinstance(images, list)

        # Original output: (1, num_images, Pn, Px * Py * C)
        # New output: (num_images, Pn, Px * Py * C)
        assert (isinstance(image_patches, list)
                and len(image_patches) == 1)
        assert (isinstance(image_patches[0], torch.Tensor)
                and len(image_patches[0]) == len(images))

        processed_outputs["image_patches"] = image_patches[0]

    return processed_outputs
```

:::{note}
Our [actual code](gh-file:vllm/model_executor/models/fuyu.py) has special handling
for text-only inputs to prevent unnecessary warnings from the HF processor.
:::

This lets us override {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_mm_fields_config` as follows:

```python
def _get_mm_fields_config(
    self,
    hf_inputs: BatchFeature,
    hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
    return dict(image_patches=MultiModalFieldConfig.batched("image"))
```

::::

:::::

### Prompt updates

Override {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates` to
return a list of {class}`~vllm.multimodal.processing.PromptUpdate` instances.

Each {class}`~vllm.multimodal.processing.PromptUpdate` instance specifies an update operation
(e.g., insertion or replacement) performed by the HF processor.

::::{tab-set}
:::{tab-item} Basic example: LLaVA
:sync: llava

Looking at HF's `LlavaProcessor`:

```python
# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/llava/processing_llava.py#L167-L170
prompt_strings = []
for sample in text:
    sample = sample.replace(self.image_token, self.image_token * num_image_tokens)
    prompt_strings.append(sample)
```

It simply repeats each input `image_token` a number of times equal to the number of placeholder feature tokens (`num_image_tokens`).
Based on this, we override {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates` as follows:

```python
def _get_prompt_updates(
    self,
    mm_items: MultiModalDataItems,
    hf_processor_mm_kwargs: Mapping[str, object],
    out_mm_kwargs: MultiModalKwargs,
) -> Sequence[PromptUpdate]:
    hf_config = self.info.get_hf_config()
    image_token_id = hf_config.image_token_index

    def get_replacement(item_idx: int):
        images = mm_items.get_items("image", ImageProcessorItems)

        image_size = images.get_image_size(item_idx)
        num_image_tokens = self.info.get_num_image_tokens(
            image_width=image_size.width,
            image_height=image_size.height,
        )

        return [image_token_id] * num_image_tokens

    return [
        PromptReplacement(
            modality="image",
            target=[image_token_id],
            replacement=get_replacement,
        ),
    ]
```
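
To see what this replacement does at the token level, here is a simplified pure-Python sketch; it is hypothetical and not vLLM's matching logic. Each occurrence of the single-token target is expanded into the replacement for the corresponding image, in order. It assumes every image is represented by exactly one `image_token_id` in the prompt, and the ID `32000` is only an illustrative value.

```python
from collections.abc import Callable, Sequence


def expand_placeholders(
    token_ids: Sequence[int],
    target_id: int,
    get_replacement: Callable[[int], list[int]],
) -> list[int]:
    """Replace the i-th occurrence of `target_id` with `get_replacement(i)`."""
    out: list[int] = []
    item_idx = 0
    for tok in token_ids:
        if tok == target_id:
            out.extend(get_replacement(item_idx))
            item_idx += 1
        else:
            out.append(tok)
    return out


IMAGE_TOKEN_ID = 32000  # assumed placeholder ID, for illustration only
prompt = [1, IMAGE_TOKEN_ID, 10, IMAGE_TOKEN_ID, 11]
expanded = expand_placeholders(prompt, IMAGE_TOKEN_ID, lambda i: [IMAGE_TOKEN_ID] * 3)
print(expanded)  # [1, 32000, 32000, 32000, 10, 32000, 32000, 32000, 11]
```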

:::

:::{tab-item} Handling additional tokens: Fuyu
:sync: fuyu

The feature tokens of each image are laid out as a grid, where each row consists of `|SPEAKER|` tokens followed by a `|NEWLINE|` token:

```
|SPEAKER||SPEAKER|...|SPEAKER||NEWLINE|
|SPEAKER||SPEAKER|...|SPEAKER||NEWLINE|
...
|SPEAKER||SPEAKER|...|SPEAKER||NEWLINE|
```

We define a helper function to return `ncols` and `nrows` directly:

```python
def get_image_feature_grid_size(
    self,
    *,
    image_width: int,
    image_height: int,
) -> tuple[int, int]:
    image_processor = self.get_image_processor()
    target_width = image_processor.size["width"]
    target_height = image_processor.size["height"]
    patch_width = image_processor.patch_size["width"]
    patch_height = image_processor.patch_size["height"]

    if not (image_width <= target_width and image_height <= target_height):
        height_scale_factor = target_height / image_height
        width_scale_factor = target_width / image_width
        optimal_scale_factor = min(height_scale_factor, width_scale_factor)

        image_height = int(image_height * optimal_scale_factor)
        image_width = int(image_width * optimal_scale_factor)

    ncols = math.ceil(image_width / patch_width)
    nrows = math.ceil(image_height / patch_height)
    return ncols, nrows
```
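
For a quick sanity check outside vLLM, the same computation can be written as a free function. The name `fuyu_grid_size` is hypothetical, and the defaults (a 1920x1080 target, 30x30 patches) are assumed `FuyuImageProcessor` settings, not values read from a checkpoint.

```python
import math


def fuyu_grid_size(
    image_width: int,
    image_height: int,
    target_width: int = 1920,
    target_height: int = 1080,
    patch_width: int = 30,
    patch_height: int = 30,
) -> tuple[int, int]:
    """Standalone restatement of `get_image_feature_grid_size` above."""
    if not (image_width <= target_width and image_height <= target_height):
        # Downscale (preserving the aspect ratio) so the image fits the target.
        scale = min(target_height / image_height, target_width / image_width)
        image_height = int(image_height * scale)
        image_width = int(image_width * scale)

    ncols = math.ceil(image_width / patch_width)
    nrows = math.ceil(image_height / patch_height)
    return ncols, nrows


for size in [(3840, 2160), (1920, 1080), (100, 50)]:
    ncols, nrows = fuyu_grid_size(*size)
    # Each row contributes `ncols` |SPEAKER| tokens plus one |NEWLINE| token.
    print(size, (ncols, nrows), (ncols + 1) * nrows)
```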

Based on this, we can initially define our replacement tokens as:

```python
def get_replacement(item_idx: int):
    images = mm_items.get_items("image", ImageProcessorItems)
    image_size = images.get_image_size(item_idx)

    ncols, nrows = self.info.get_image_feature_grid_size(
        image_width=image_size.width,
        image_height=image_size.height,
    )

    # `_IMAGE_TOKEN_ID` corresponds to `|SPEAKER|`
    # `_NEWLINE_TOKEN_ID` corresponds to `|NEWLINE|`
    return ([_IMAGE_TOKEN_ID] * ncols + [_NEWLINE_TOKEN_ID]) * nrows
```

However, this is not entirely correct. After `FuyuImageProcessor.preprocess_with_tokenizer_info` is called,
a BOS token (`<s>`) is also added to the prompt:

```python
# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/processing_fuyu.py#L417-L435
model_image_input = self.image_processor.preprocess_with_tokenizer_info(
    image_input=tensor_batch_images,
    image_present=image_present,
    image_unpadded_h=image_unpadded_heights,
    image_unpadded_w=image_unpadded_widths,
    image_placeholder_id=image_placeholder_id,
    image_newline_id=image_newline_id,
    variable_sized=True,
)
prompt_tokens, prompts_length = _tokenize_prompts_with_image_and_batch(
    tokenizer=self.tokenizer,
    prompts=prompts,
    scale_factors=scale_factors,
    max_tokens_to_generate=self.max_tokens_to_generate,
    max_position_embeddings=self.max_position_embeddings,
    add_BOS=True,
    add_beginning_of_answer_token=True,
)
```

To assign the vision embeddings only to the image tokens (and not to `|NEWLINE|` or BOS), you can return
an instance of {class}`~vllm.multimodal.processing.PromptUpdateDetails` instead of a plain list of token IDs:

```python
hf_config = self.info.get_hf_config()
bos_token_id = hf_config.bos_token_id  # `<s>`
assert isinstance(bos_token_id, int)

def get_replacement_fuyu(item_idx: int):
    images = mm_items.get_items("image", ImageProcessorItems)
    image_size = images.get_image_size(item_idx)

    ncols, nrows = self.info.get_image_feature_grid_size(
        image_width=image_size.width,
        image_height=image_size.height,
    )
    image_tokens = ([_IMAGE_TOKEN_ID] * ncols +
                    [_NEWLINE_TOKEN_ID]) * nrows

    return PromptUpdateDetails.select_token_id(
        image_tokens + [bos_token_id],
        embed_token_id=_IMAGE_TOKEN_ID,
    )
```
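
The effect of `PromptUpdateDetails.select_token_id` can be pictured as a boolean mask over the replacement tokens: only positions holding `embed_token_id` receive vision embeddings, while `|NEWLINE|` and the trailing BOS keep their ordinary text embeddings. The sketch below computes such a mask in plain Python; the helper name `embed_mask` is hypothetical and the token ID values are made up for illustration.

```python
def embed_mask(replacement: list[int], embed_token_id: int) -> list[bool]:
    """Hypothetical sketch: True where a position should take a vision embedding."""
    return [tok == embed_token_id for tok in replacement]


# Made-up IDs for illustration only.
IMAGE_ID, NEWLINE_ID, BOS_ID = 100, 101, 1
ncols, nrows = 3, 2
replacement = ([IMAGE_ID] * ncols + [NEWLINE_ID]) * nrows + [BOS_ID]

mask = embed_mask(replacement, IMAGE_ID)
print(sum(mask), len(replacement))  # 6 9
```

The number of `True` entries (`ncols * nrows`) is what has to match the number of patch embeddings produced for the image.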

Finally, since the HF processor removes the `|ENDOFTEXT|` token from the tokenized prompt,
we can target that token so that the replacement is conducted at the start of the prompt:

```python
def _get_prompt_updates(
    self,
    mm_items: MultiModalDataItems,
    hf_processor_mm_kwargs: Mapping[str, object],
    out_mm_kwargs: MultiModalKwargs,
) -> Sequence[PromptUpdate]:
    hf_config = self.info.get_hf_config()
    bos_token_id = hf_config.bos_token_id
    assert isinstance(bos_token_id, int)

    # For Fuyu, `tokenizer.bos_token_id` is the ID of `|ENDOFTEXT|`,
    # the token that the HF processor removes from the tokenized prompt.
    tokenizer = self.info.get_tokenizer()
    eot_token_id = tokenizer.bos_token_id
    assert isinstance(eot_token_id, int)

    def get_replacement_fuyu(item_idx: int):
        images = mm_items.get_items("image", ImageProcessorItems)
        image_size = images.get_image_size(item_idx)

        ncols, nrows = self.info.get_image_feature_grid_size(
            image_width=image_size.width,
            image_height=image_size.height,
        )
        image_tokens = ([_IMAGE_TOKEN_ID] * ncols +
                        [_NEWLINE_TOKEN_ID]) * nrows

        return PromptUpdateDetails.select_token_id(
            image_tokens + [bos_token_id],
            embed_token_id=_IMAGE_TOKEN_ID,
        )

    return [
        PromptReplacement(
            modality="image",
            target=[eot_token_id],
            replacement=get_replacement_fuyu,
        )
    ]
```

:::

::::

## 5. Register processor-related classes

After you have defined {class}`~vllm.multimodal.processing.BaseProcessingInfo` (Step 2),
{class}`~vllm.multimodal.profiling.BaseDummyInputsBuilder` (Step 3),
and {class}`~vllm.multimodal.processing.BaseMultiModalProcessor` (Step 4),
decorate the model class with {meth}`MULTIMODAL_REGISTRY.register_processor <vllm.multimodal.registry.MultiModalRegistry.register_processor>`
to register them to the multi-modal registry:

```diff
  from vllm.model_executor.models.interfaces import SupportsMultiModal
+ from vllm.multimodal import MULTIMODAL_REGISTRY

+ @MULTIMODAL_REGISTRY.register_processor(YourMultiModalProcessor,
+                                         info=YourProcessingInfo,
+                                         dummy_inputs=YourDummyInputsBuilder)
  class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
```

## Notes

### Inserting feature tokens without replacement

Some HF processors directly insert feature tokens without replacing anything in the original prompt. In that case, you can use {class}`~vllm.multimodal.processing.PromptInsertion` instead of {class}`~vllm.multimodal.processing.PromptReplacement` inside {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates`.

Examples:

- BLIP-2 (insert at start of prompt): <gh-file:vllm/model_executor/models/blip2.py>
- Florence2 (insert at start of prompt): <gh-file:vllm/model_executor/models/florence2.py>
- Molmo (insert after `<|endoftext|>` token): <gh-file:vllm/model_executor/models/molmo.py>
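
The difference between the two update types can be sketched on plain token lists. Both helpers below are hypothetical simplifications, not the vLLM classes. With replacement, the matched target token is removed and the feature tokens take its place. With insertion, the original prompt is kept intact and the feature tokens are added at a position: here, either right after a marker token or at the start of the prompt.

```python
from typing import Optional


def replace_first(tokens: list[int], target: int, new: list[int]) -> list[int]:
    """Replacement: the matched target token is removed and `new` takes its place."""
    i = tokens.index(target)
    return tokens[:i] + new + tokens[i + 1:]


def insert_after(tokens: list[int], anchor: Optional[int], new: list[int]) -> list[int]:
    """Insertion: `new` is added after `anchor` (or at the start if `anchor` is None)."""
    i = 0 if anchor is None else tokens.index(anchor) + 1
    return tokens[:i] + new + tokens[i:]


FEATURES = [9, 9]
print(replace_first([5, 7, 6], 7, FEATURES))    # [5, 9, 9, 6]
print(insert_after([5, 7, 6], 7, FEATURES))     # [5, 7, 9, 9, 6]
print(insert_after([5, 7, 6], None, FEATURES))  # [9, 9, 5, 7, 6]
```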

### Handling prompt updates unrelated to multi-modal data

{meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates` assumes that each application of prompt update corresponds to one multi-modal item. If the HF processor performs additional processing regardless of how many multi-modal items there are, you should override {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._apply_hf_processor_tokens_only` so that the processed token inputs are consistent with the result of applying the HF processor on text inputs. This is because token inputs bypass the HF processor according to [our design](#mm-processing).

Examples:

- Chameleon (appends `sep_token`): <gh-file:vllm/model_executor/models/chameleon.py>
- Fuyu (appends `boa_token`): <gh-file:vllm/model_executor/models/fuyu.py>
- Molmo (applies chat template which is not defined elsewhere): <gh-file:vllm/model_executor/models/molmo.py>

### Custom HF processor

Some models don't define an HF processor class on the HF Hub. In that case, you can define a custom HF processor that has the same call signature as HF processors and pass it to {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._call_hf_processor`.

Examples:

- DeepSeek-VL2: <gh-file:vllm/model_executor/models/deepseek_vl2.py>
- InternVL: <gh-file:vllm/model_executor/models/internvl.py>
- Qwen-VL: <gh-file:vllm/model_executor/models/qwen_vl.py>