Unverified Commit 09194b90 authored by Cyrus Leung, committed by GitHub

[Doc] Update docs for MM model development with context usage (#32691)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 9ab4388c
@@ -23,29 +23,32 @@ Further update the model as follows:
             raise ValueError("Only image modality is supported")
     ```
-- Reserve a keyword parameter in [forward][torch.nn.Module.forward] for each input tensor that corresponds to a multi-modal input, as shown in the following example:
-    ```diff
-      def forward(
-          self,
-          input_ids: torch.Tensor,
-          positions: torch.Tensor,
-    +     pixel_values: torch.Tensor,
-      ) -> SamplerOutput:
-    ```
-    More conveniently, you can simply pass `**kwargs` to the [forward][torch.nn.Module.forward] method and retrieve the keyword parameters for multimodal inputs from it.
+- Inside `__init__` method, initialize the language components of the model inside [_mark_language_model][vllm.model_executor.models.interfaces.SupportsMultiModal._mark_language_model], and the multimodal components of the model inside [_mark_tower_model][vllm.model_executor.models.interfaces.SupportsMultiModal._mark_tower_model], e.g.:
+    ```python
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
+        super().__init__()
+
+        config = vllm_config.model_config.hf_config
+
+        with self._mark_tower_model(vllm_config, "image"):
+            self.vision_encoder = ...
+            self.multi_modal_projector = ...
+        with self._mark_language_model(vllm_config):
+            self.language_model = init_vllm_registered_model(
+                vllm_config=vllm_config,
+                hf_config=config.text_config,
+                prefix=maybe_prefix(prefix, "language_model"),
+            )
+    ```
 - Implement [embed_multimodal][vllm.model_executor.models.interfaces.SupportsMultiModal.embed_multimodal] that returns the embeddings from running the multimodal inputs through the multimodal tokenizer of the model. Below we provide a boilerplate of a typical implementation pattern, but feel free to adjust it to your own needs.
     ??? code
         ```python
-        class YourModelForImage2Seq(nn.Module):
-            ...
             def _process_image_input(self, image_input: YourModelImageInputs) -> torch.Tensor:
-                assert self.vision_encoder is not None
                 image_features = self.vision_encoder(image_input)
                 return self.multi_modal_projector(image_features)
@@ -73,17 +76,6 @@ Further update the model as follows:
     You may override this method if additional logic is required for your model when merging embeddings.
-- Implement [get_language_model][vllm.model_executor.models.interfaces.SupportsMultiModal.get_language_model] getter to provide stable access to the underlying language model.
-    ```python
-    class YourModelForImage2Seq(nn.Module):
-        ...
-        def get_language_model(self) -> torch.nn.Module:
-            # Change `language_model` according to your implementation.
-            return self.language_model
-    ```
 - Once the above steps are done, update the model class with the [SupportsMultiModal][vllm.model_executor.models.interfaces.SupportsMultiModal] interface.
     ```diff
......
@@ -38,7 +38,7 @@ Encoder engines should be launched with the following flags:
 - `--max-num-batched-tokens=<large value>` **(default: 2048)** – This flag controls the token scheduling budget per decoding step and is irrelevant to encoder-only instances. **Set it to a very high value (effectively unlimited) to bypass scheduler limitations.** The actual token budget is managed by the encoder cache manager.
-- `--mm-encoder-only` **(Optional)** - The language model is skipped during initialization to reduce device memory usage. **Models using this option must initialize the language component inside the context of `SupportsMultiModal._mark_language_model`.**
+- `--mm-encoder-only` **(Optional)** - If possible, skips the language model during initialization to reduce device memory usage.
 ## Local media inputs
......
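
The idea behind the two `_mark_*` contexts in the diff above can be illustrated with a toy sketch. This is not vLLM's implementation: `ToyModel`, its `_mark` helper, the `drop` parameter and the `encoder_only` flag are all hypothetical names invented for this example, and plain strings stand in for real modules. It only shows how a context manager can record which attributes were created inside its block, so that a mode such as `--mm-encoder-only` could later leave the language part out.

```python
from contextlib import contextmanager


class ToyModel:
    """Hypothetical stand-in for a multimodal model (not a vLLM class)."""

    def __init__(self, encoder_only: bool = False) -> None:
        self.encoder_only = encoder_only
        self.tower_parts: list[str] = []
        self.language_parts: list[str] = []

        # Components created in this block are recorded as tower parts.
        with self._mark(self.tower_parts):
            self.vision_encoder = "vision encoder"
            self.multi_modal_projector = "projector"

        # Components created here are recorded as language parts and,
        # in encoder-only mode, replaced by None afterwards.
        with self._mark(self.language_parts, drop=encoder_only):
            self.language_model = "language model"

    @contextmanager
    def _mark(self, registry: list[str], drop: bool = False):
        # Snapshot attribute names before the block runs.
        before = set(vars(self))
        yield
        # Anything new was assigned inside the block.
        new_names = sorted(set(vars(self)) - before)
        registry.extend(new_names)
        if drop:
            for name in new_names:
                setattr(self, name, None)


model = ToyModel(encoder_only=True)
print(model.tower_parts, model.language_parts, model.language_model)
```

Note that this toy still constructs the language component before discarding it, so it saves nothing by itself; a real implementation would presumably avoid the allocation in the first place.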