"docker/Dockerfile.arm" did not exist on "95965d31b6ac2c9557816a6ffabe4a3117a5ccb2"
enabling_multimodal_inputs.rst 6.44 KB
Newer Older
1
.. _enabling_multimodal_inputs:

Enabling Multimodal Inputs
==========================

This document walks you through the steps to extend a vLLM model so that it accepts :ref:`multi-modal inputs <multimodal_inputs>`.

.. seealso::
    :ref:`adding_a_new_model`


1. Update the base vLLM model
-----------------------------

It is assumed that you have already implemented the model in vLLM according to :ref:`these steps <adding_a_new_model>`.
Then update the model as follows:

- Implement the :class:`~vllm.model_executor.models.interfaces.SupportsMultiModal` interface.

  .. code-block:: diff

      + from vllm.model_executor.models.interfaces import SupportsMultiModal

      - class YourModelForImage2Seq(nn.Module):
      + class YourModelForImage2Seq(nn.Module, SupportsMultiModal):

  .. note::
      The model class does not have to be named :code:`*ForCausalLM`.
      Check out `the HuggingFace Transformers documentation <https://huggingface.co/docs/transformers/model_doc/auto#multimodal>`__ for some examples.

- If you haven't already done so, reserve a keyword parameter in :meth:`~torch.nn.Module.forward`
  for each input tensor that corresponds to a multi-modal input, as shown in the following example:

  .. code-block:: diff

        def forward(
            self,
            input_ids: torch.Tensor,
            positions: torch.Tensor,
            kv_caches: List[torch.Tensor],
            attn_metadata: AttentionMetadata,
      +     pixel_values: torch.Tensor,
        ) -> SamplerOutput:
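
The reserved keyword names matter because, as far as we understand vLLM's model runner, the keyword arguments produced for each multi-modal input are passed to :meth:`~torch.nn.Module.forward` by name. The plain-Python sketch below illustrates that contract; ``ToyImage2Seq`` and ``unknown_mm_keys`` are hypothetical names, and lists stand in for tensors so that it runs without PyTorch or vLLM.

.. code-block:: python

    import inspect
    from typing import Any, Callable, Dict, List, Optional


    class ToyImage2Seq:
        """Hypothetical stand-in for a vLLM model (not a real vLLM class)."""

        def forward(
            self,
            input_ids: List[int],
            positions: List[int],
            pixel_values: Optional[Any] = None,  # reserved multi-modal keyword
        ) -> Dict[str, Any]:
            return {"num_tokens": len(input_ids), "has_image": pixel_values is not None}


    def unknown_mm_keys(forward: Callable[..., Any], mm_kwargs: Dict[str, Any]) -> List[str]:
        """Return the keys of ``mm_kwargs`` that ``forward`` cannot accept by name."""
        params = inspect.signature(forward).parameters
        if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()):
            return []  # a **kwargs catch-all accepts any key
        return sorted(key for key in mm_kwargs if key not in params)


    model = ToyImage2Seq()
    mm_kwargs = {"pixel_values": [[0.0, 0.5], [1.0, 0.25]]}
    assert unknown_mm_keys(model.forward, mm_kwargs) == []
    print(model.forward(input_ids=[1, 2, 3], positions=[0, 1, 2], **mm_kwargs))
    # {'num_tokens': 3, 'has_image': True}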


2. Register input mappers
-------------------------

For each modality type that the model accepts as input, decorate the model class with :meth:`MULTIMODAL_REGISTRY.register_input_mapper <vllm.multimodal.MultiModalRegistry.register_input_mapper>`.
This decorator accepts a function that maps multi-modal inputs to the keyword arguments you have previously defined in :meth:`~torch.nn.Module.forward`.

.. code-block:: diff

      from vllm.model_executor.models.interfaces import SupportsMultiModal
    + from vllm.multimodal import MULTIMODAL_REGISTRY

    + @MULTIMODAL_REGISTRY.register_image_input_mapper()
      class YourModelForImage2Seq(nn.Module, SupportsMultiModal):

A default mapper is available for each modality in the core vLLM library. This input mapper will be used if you do not provide your own function.
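
To make the default-versus-custom behavior concrete, here is a much-simplified registry in plain Python. It is a hypothetical sketch, not vLLM's implementation: ``ToyMultiModalRegistry``, ``default_mapper`` and ``scaled_mapper`` are invented for illustration, and images are nested lists of 0-255 grayscale values rather than the image objects and tensors vLLM actually handles. It mirrors the two behaviors described above: a mapper turns raw data into keyword arguments for ``forward``, and a class registered without its own function falls back to the default mapper.

.. code-block:: python

    from typing import Any, Callable, Dict, List, Optional

    Mapper = Callable[[Any], Dict[str, Any]]


    class ToyMultiModalRegistry:
        """Hypothetical, much-simplified registry (not vLLM's implementation)."""

        def __init__(self, default_mapper: Mapper) -> None:
            self._default_mapper = default_mapper
            self._mappers: Dict[type, Mapper] = {}

        def register_image_input_mapper(self, mapper: Optional[Mapper] = None) -> Callable[[type], type]:
            def wrapper(model_cls: type) -> type:
                # Record the mapper for this class; no mapper means "use the default".
                self._mappers[model_cls] = mapper if mapper is not None else self._default_mapper
                return model_cls  # the decorated class itself is returned unchanged

            return wrapper

        def map_input(self, model_cls: type, data: Any) -> Dict[str, Any]:
            # Turn raw multi-modal data into keyword arguments for model_cls.forward
            return self._mappers[model_cls](data)


    def default_mapper(images: List[List[List[int]]]) -> Dict[str, Any]:
        # Pass the raw pixels through under the reserved keyword
        return {"pixel_values": images}


    def scaled_mapper(images: List[List[List[int]]]) -> Dict[str, Any]:
        # Custom mapper: scale 0-255 grayscale pixels into [0, 1]
        return {"pixel_values": [[[px / 255 for px in row] for row in img] for img in images]}


    REGISTRY = ToyMultiModalRegistry(default_mapper)


    @REGISTRY.register_image_input_mapper()
    class UsesDefaultMapper:
        pass


    @REGISTRY.register_image_input_mapper(scaled_mapper)
    class UsesCustomMapper:
        pass


    images = [[[0, 255]]]  # one 1x2 grayscale image
    print(REGISTRY.map_input(UsesDefaultMapper, images))  # {'pixel_values': [[[0, 255]]]}
    print(REGISTRY.map_input(UsesCustomMapper, images))   # {'pixel_values': [[[0.0, 1.0]]]}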

.. seealso::
    :ref:`input_processing_pipeline`


3. Register maximum number of multi-modal tokens
------------------------------------------------

For each modality type that the model accepts as input, calculate the maximum possible number of tokens per data item
and register it via :meth:`MULTIMODAL_REGISTRY.register_max_multimodal_tokens <vllm.multimodal.MultiModalRegistry.register_max_multimodal_tokens>`.

.. code-block:: diff

      from vllm.inputs import INPUT_REGISTRY
      from vllm.model_executor.models.interfaces import SupportsMultiModal
      from vllm.multimodal import MULTIMODAL_REGISTRY

      @MULTIMODAL_REGISTRY.register_image_input_mapper()
    + @MULTIMODAL_REGISTRY.register_max_image_tokens(<your_calculation>)
      class YourModelForImage2Seq(nn.Module, SupportsMultiModal):

Here are some examples:

- Image inputs (static feature size): `LLaVA-1.5 Model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llava.py>`__
- Image inputs (dynamic feature size): `LLaVA-NeXT Model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llava_next.py>`__
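
For a model with a static feature size, the calculation often reduces to counting the patches produced by the vision encoder. The sketch below assumes a ViT-style encoder with square inputs and non-overlapping patches; ``max_image_tokens_static`` and its ``has_cls_token`` flag are hypothetical, and the number it returns is the kind of value you would supply for ``<your_calculation>``. Models with a dynamic feature size, such as LLaVA-NeXT, instead need the maximum over every image resolution they support.

.. code-block:: python

    def max_image_tokens_static(image_size: int, patch_size: int, has_cls_token: bool = False) -> int:
        """Tokens a ViT-style encoder emits for one square image (hypothetical helper)."""
        if image_size <= 0 or patch_size <= 0:
            raise ValueError("image_size and patch_size must be positive")
        if image_size % patch_size != 0:
            raise ValueError("image_size must be a multiple of patch_size")
        grid = image_size // patch_size  # patches along each side
        return grid * grid + (1 if has_cls_token else 0)


    # CLIP ViT-L/14 at 336 px, the vision tower used by LLaVA-1.5: a 24 x 24 grid
    print(max_image_tokens_static(image_size=336, patch_size=14))  # 576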

.. seealso::
    :ref:`input_processing_pipeline`


4. (Optional) Register dummy data
---------------------------------

During startup, dummy data is passed to the vLLM model to determine how much memory to allocate. By default, this dummy data consists only of text input, which may not be sufficient for multi-modal models.
In such cases, you can define your own dummy data by registering a factory method via :meth:`INPUT_REGISTRY.register_dummy_data <vllm.inputs.registry.InputRegistry.register_dummy_data>`.

.. code-block:: diff

      from vllm.inputs import INPUT_REGISTRY
      from vllm.model_executor.models.interfaces import SupportsMultiModal
      from vllm.multimodal import MULTIMODAL_REGISTRY

      @MULTIMODAL_REGISTRY.register_image_input_mapper()
      @MULTIMODAL_REGISTRY.register_max_image_tokens(<your_calculation>)
    + @INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
      class YourModelForImage2Seq(nn.Module, SupportsMultiModal):

.. note::
    The dummy data should have the maximum possible number of multi-modal tokens, as described in the previous step.
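
One way to satisfy this note for the text part of the dummy data is to fill the prompt with as many image placeholder tokens as the previous step allows and pad the remainder. The helper below is a hypothetical sketch: its name, the padding id ``0`` and the placeholder id ``32000`` are assumptions for illustration, and a real factory must also return dummy multi-modal data (for example, a blank image of the size the vision encoder expects), which is omitted here.

.. code-block:: python

    from typing import List


    def make_dummy_token_ids(seq_len: int, image_token_id: int, num_image_tokens: int) -> List[int]:
        """Hypothetical dummy prompt: the maximum number of image placeholders, then padding."""
        if num_image_tokens > seq_len:
            raise ValueError(
                f"{num_image_tokens} image tokens do not fit into seq_len={seq_len}"
            )
        # Placeholders first (the worst case from the previous step),
        # then pad the rest of the sequence with token id 0.
        return [image_token_id] * num_image_tokens + [0] * (seq_len - num_image_tokens)


    print(make_dummy_token_ids(seq_len=8, image_token_id=32000, num_image_tokens=5))
    # [32000, 32000, 32000, 32000, 32000, 0, 0, 0]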

Here are some examples:

- Image inputs (static feature size): `LLaVA-1.5 Model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llava.py>`__
- Image inputs (dynamic feature size): `LLaVA-NeXT Model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llava_next.py>`__

.. seealso::
    :ref:`input_processing_pipeline`


5. (Optional) Register input processor
--------------------------------------

Sometimes, inputs need to be processed at the :class:`~vllm.LLMEngine` level before they are passed to the model executor.
This is often because, unlike in HuggingFace Transformers implementations, the reshaping and/or expansion of multi-modal embeddings needs to take place outside the model's :meth:`~torch.nn.Module.forward` call.
You can register input processors via :meth:`INPUT_REGISTRY.register_input_processor <vllm.inputs.registry.InputRegistry.register_input_processor>`.

.. code-block:: diff

      from vllm.inputs import INPUT_REGISTRY
      from vllm.model_executor.models.interfaces import SupportsMultiModal
      from vllm.multimodal import MULTIMODAL_REGISTRY

      @MULTIMODAL_REGISTRY.register_image_input_mapper()
      @MULTIMODAL_REGISTRY.register_max_image_tokens(<your_calculation>)
      @INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
    + @INPUT_REGISTRY.register_input_processor(<your_input_processor>)
      class YourModelForImage2Seq(nn.Module, SupportsMultiModal):

A common use case of input processors is inserting placeholder tokens into the prompt so that the vLLM framework can generate the attention mask for the multi-modal embeddings.
Here are some examples:

- Inserting a static number of image tokens: `LLaVA-1.5 Model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llava.py>`__
- Inserting a dynamic number of image tokens: `LLaVA-NeXT Model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llava_next.py>`__

.. seealso::
    :ref:`input_processing_pipeline`
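
The placeholder-insertion use case can be sketched as a token-level rewrite: each single image placeholder in the prompt is expanded to as many copies as the image contributes embeddings, so that the prompt length matches the sequence the model actually attends over. The function name and token ids below are hypothetical, and, as we understand it, a real input processor operates on vLLM's input objects (which carry the prompt token ids) rather than on a bare list.

.. code-block:: python

    from typing import List


    def expand_image_placeholders(
        token_ids: List[int], image_token_id: int, num_image_tokens: int
    ) -> List[int]:
        """Replace every image placeholder with ``num_image_tokens`` copies of it."""
        expanded: List[int] = []
        for token_id in token_ids:
            if token_id == image_token_id:
                expanded.extend([image_token_id] * num_image_tokens)
            else:
                expanded.append(token_id)
        return expanded


    prompt = [1, 32000, 450, 3375]  # hypothetical ids: BOS, <image>, then text
    print(expand_image_placeholders(prompt, image_token_id=32000, num_image_tokens=3))
    # [1, 32000, 32000, 32000, 450, 3375]

For a dynamic feature size, ``num_image_tokens`` would be computed per image from its resolution instead of being a constant.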