Unverified Commit 9499e26e authored by Raushan Turganbay's avatar Raushan Turganbay Committed by GitHub
Browse files

[Model] Support VLMs with transformers backend (#20543)


Signed-off-by: default avatarraushan <raushan@huggingface.co>
Signed-off-by: default avatarIsotr0py <2037008807@qq.com>
Signed-off-by: default avatarIsotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: default avatarIsotr0py <2037008807@qq.com>
Co-authored-by: default avatarIsotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: default avatarCyrus Leung <cyrus.tl.leung@gmail.com>
parent 51ba8395
......@@ -18,7 +18,7 @@ These models are what we list in [supported-text-models][supported-text-models]
### Transformers
vLLM also supports model implementations that are available in Transformers. This does not currently work for all models, but most decoder language models are supported, and vision language model support is planned!
vLLM also supports model implementations that are available in Transformers. This does not currently work for all models, but most decoder language models and common vision language models are supported! Vision-language models currently accept only image inputs, and require setting `--disable_mm_preprocessor_cache` when running. Support for video inputs and caching of multi-modal preprocessors will be added in future releases.
To check if the modeling backend is Transformers, you can simply do this:
......@@ -28,7 +28,7 @@ llm = LLM(model=..., task="generate") # Name or path of your model
llm.apply_model(lambda model: print(type(model)))
```
If it is `TransformersForCausalLM` then it means it's based on Transformers!
If it is `TransformersForCausalLM` or `TransformersForMultimodalLM` then it means it's based on Transformers!
!!! tip
You can force the use of `TransformersForCausalLM` by setting `model_impl="transformers"` for [offline-inference](../serving/offline_inference.md) or `--model-impl transformers` for the [openai-compatible-server](../serving/openai_compatible_server.md).
......@@ -36,6 +36,9 @@ If it is `TransformersForCausalLM` then it means it's based on Transformers!
!!! note
vLLM may not fully optimise the Transformers implementation so you may see degraded performance if comparing a native model to a Transformers model in vLLM.
!!! note
In case of vision language models if you are loading with `dtype="auto"`, vLLM loads the whole model with config's `dtype` if it exists. In contrast the native Transformers will respect the `dtype` attribute of each backbone in the model. That might cause a slight difference in performance.
#### Custom models
If a model is neither supported natively by vLLM or Transformers, it can still be used in vLLM!
......@@ -99,7 +102,7 @@ Here is what happens in the background when this model is loaded:
1. The config is loaded.
2. `MyModel` Python class is loaded from the `auto_map` in config, and we check that the model `is_backend_compatible()`.
3. `MyModel` is loaded into `TransformersForCausalLM` (see <gh-file:vllm/model_executor/models/transformers.py>) which sets `self.config._attn_implementation = "vllm"` so that vLLM's attention layer is used.
3. `MyModel` is loaded into `TransformersForCausalLM` or `TransformersForMultimodalLM` (see <gh-file:vllm/model_executor/models/transformers.py>) which sets `self.config._attn_implementation = "vllm"` so that vLLM's attention layer is used.
That's it!
......
......@@ -35,6 +35,8 @@ if current_platform.is_rocm():
REQUIRES_V0_MODELS = [
# V1 Test: not enough KV cache space in C1.
"fuyu",
# V1 Test: Deadlock issue when processing mm_inputs
"llava-onevision-transformers",
]
# yapf: disable
......@@ -170,6 +172,79 @@ VLM_TEST_SETTINGS = {
hf_output_post_proc=model_utils.ultravox_trunc_hf_output,
marks=[pytest.mark.core_model, pytest.mark.cpu_model],
),
#### Transformers fallback to test
## To reduce test burden, we only test batching arbitrary image size
# Dynamic image length and number of patches
"llava-onevision-transformers": VLMTestInfo(
models=["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"],
test_type=VLMTestType.IMAGE,
prompt_formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501
max_model_len=16384,
hf_model_kwargs=model_utils.llava_onevision_hf_model_kwargs("llava-hf/llava-onevision-qwen2-0.5b-ov-hf"), # noqa: E501
auto_cls=AutoModelForImageTextToText,
vllm_output_post_proc=model_utils.llava_onevision_vllm_to_hf_output,
image_size_factors=[(0.25, 0.5, 1.0)],
vllm_runner_kwargs={
"model_impl": "transformers",
"disable_mm_preprocessor_cache": True,
"enable_prefix_caching": False,
},
marks=[pytest.mark.core_model],
),
# FIXME(Isotr0py): Enable this test after
# https://github.com/huggingface/transformers/pull/39470 released
# "idefics3-transformers": VLMTestInfo(
# models=["HuggingFaceTB/SmolVLM-256M-Instruct"],
# test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
# prompt_formatter=lambda img_prompt:f"<|begin_of_text|>User:{img_prompt}<end_of_utterance>\nAssistant:", # noqa: E501
# img_idx_to_prompt=lambda idx: "<image>",
# max_model_len=8192,
# max_num_seqs=2,
# auto_cls=AutoModelForImageTextToText,
# hf_output_post_proc=model_utils.idefics3_trunc_hf_output,
# image_size_factors=[(0.25, 0.5, 1.0)],
# vllm_runner_kwargs={
# "model_impl": "transformers",
# "disable_mm_preprocessor_cache": True,
# "enable_prefix_caching": False,
# },
# marks=[pytest.mark.core_model],
# ),
# Pixel values from processor are not 4D or 5D arrays
"qwen2_5_vl-transformers": VLMTestInfo(
models=["Qwen/Qwen2.5-VL-3B-Instruct"],
test_type=VLMTestType.IMAGE,
prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501
img_idx_to_prompt=lambda idx: "<|vision_start|><|image_pad|><|vision_end|>", # noqa: E501
max_model_len=4096,
max_num_seqs=2,
auto_cls=AutoModelForImageTextToText,
vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output,
image_size_factors=[(0.25, 0.2, 0.15)],
vllm_runner_kwargs={
"model_impl": "transformers",
"disable_mm_preprocessor_cache": True,
"enable_prefix_caching": False,
},
marks=[large_gpu_mark(min_gb=32)],
),
# Check "auto" with fallback to transformers
"internvl-transformers": VLMTestInfo(
models=["OpenGVLab/InternVL3-1B-hf"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501
img_idx_to_prompt=lambda idx: "<IMG_CONTEXT>",
max_model_len=4096,
use_tokenizer_eos=True,
image_size_factors=[(0.25, 0.5, 1.0)],
vllm_runner_kwargs={
"model_impl": "auto",
"disable_mm_preprocessor_cache": True,
"enable_prefix_caching": False,
},
auto_cls=AutoModelForImageTextToText,
marks=[pytest.mark.core_model],
),
#### Extended model tests
"aria": VLMTestInfo(
models=["rhymes-ai/Aria"],
......
......@@ -499,6 +499,7 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
_TRANSFORMERS_MODELS = {
"TransformersForCausalLM": _HfExamplesInfo("ArthurZ/Ilama-3.2-1B", trust_remote_code=True), # noqa: E501
"TransformersForMultimodalLM": _HfExamplesInfo("OpenGVLab/InternVL3-1B-hf"),
}
_EXAMPLE_MODELS = {
......
......@@ -562,6 +562,10 @@ class ModelConfig:
self.task = "embed"
model_info, arch = self.registry.inspect_model_cls(self.architectures)
self._model_info = model_info
self._architecture = arch
all_supported_tasks = self._get_supported_tasks(self.task)
logger.debug("Tasks supported by runner type: %s", all_supported_tasks)
supported_runner_types = self._get_supported_runner_types(
......@@ -587,10 +591,6 @@ class ModelConfig:
else:
self.truncation_side = "right"
model_info, arch = self.registry.inspect_model_cls(self.architectures)
self._model_info = model_info
self._architecture = arch
self.pooler_config = self._init_pooler_config()
self.dtype = _get_and_verify_dtype(
......@@ -674,6 +674,16 @@ class ModelConfig:
"max_model_len must be an integer after __post_init__.")
return self
def _get_transformers_backend_cls(self) -> str:
"""Determine which Transformers backend class will be used if
`model_impl` is set to `transformers` or `auto`."""
if self.hf_config != self.hf_text_config:
# If 'hf_text_config' is the same as 'hf_config'. If not, it is
# probably a composite config, i.e. multimodal
return "TransformersForMultimodalLM"
else:
return "TransformersForCausalLM"
@property
def registry(self):
return me_models.ModelRegistry
......@@ -681,7 +691,19 @@ class ModelConfig:
@property
def architectures(self) -> list[str]:
# architectures in the model config.
return getattr(self.hf_config, "architectures", [])
architectures = getattr(self.hf_config, "architectures", [])
# The registry assumes that it can always inspect the vLLM model class
# for a given architecture. This assumption breaks down for the
# Transformers backend, which may use a different class depending on
# the model type. To work around this, we add the correct Transformers
# backend class to the architectures list. We must do this here because
# we need access to the `hf_config` to determine the backend class.
transformers_backend_cls = self._get_transformers_backend_cls()
if (self.model_impl != ModelImpl.VLLM.value
and all(arch != transformers_backend_cls
for arch in architectures)):
architectures.append(transformers_backend_cls)
return architectures
@property
def architecture(self) -> str:
......@@ -827,10 +849,9 @@ class ModelConfig:
("EmbeddingModel", "embed"),
("RewardModel", "reward"),
]
_, arch = self.registry.inspect_model_cls(architectures)
for suffix, pref_task in suffix_to_preferred_task:
if arch.endswith(suffix):
if self.architecture.endswith(suffix):
return pref_task
return "embed"
......@@ -944,10 +965,10 @@ class ModelConfig:
("EmbeddingModel", "pooling"),
("RewardModel", "pooling"),
]
_, arch = self.registry.inspect_model_cls(self.architectures)
for suffix, pref_runner in suffix_to_preferred_runner:
if arch.endswith(suffix) and pref_runner in supported_runner_types:
if self.architecture.endswith(
suffix) and pref_runner in supported_runner_types:
return pref_runner
if "generate" in supported_runner_types:
......
......@@ -25,6 +25,7 @@ from vllm.model_executor.models.adapters import (as_embedding_model,
as_reward_model,
as_seq_cls_model)
from vllm.model_executor.models.interfaces import SupportsQuant
from vllm.model_executor.models.registry import _TRANSFORMERS_MODELS
from vllm.utils import is_pin_memory_available
logger = init_logger(__name__)
......@@ -169,9 +170,22 @@ def device_loading_context(module: torch.nn.Module,
def resolve_transformers_arch(model_config: ModelConfig,
architectures: list[str]):
if model_config.model_impl == ModelImpl.VLLM:
raise ValueError(
"Attempting to resolve architecture from the Transformers library "
"but the model implementation is set to vLLM. This should never "
"happen.")
for i, arch in enumerate(architectures):
if arch == "TransformersForCausalLM":
if arch in _TRANSFORMERS_MODELS:
continue
if model_config.model_impl == ModelImpl.AUTO:
logger.warning(
"%s has no vLLM implementation, falling back to Transformers "
"implementation. Some features may not be supported and "
"performance may not be optimal.", arch)
auto_map: dict[str, str] = getattr(model_config.hf_config, "auto_map",
None) or dict()
# Make sure that config class is always initialized before model class,
......@@ -199,25 +213,13 @@ def resolve_transformers_arch(model_config: ModelConfig,
"not present in the model config's 'auto_map' (relevant "
"if the model is custom).")
model_module = auto_modules["AutoModel"]
# TODO(Isotr0py): Further clean up these raises.
# perhaps handled them in _ModelRegistry._raise_for_unsupported?
if model_config.model_impl == ModelImpl.TRANSFORMERS:
if not model_module.is_backend_compatible():
raise ValueError(
f"The Transformers implementation of {arch} is not "
"compatible with vLLM.")
architectures[i] = "TransformersForCausalLM"
if model_config.model_impl == ModelImpl.AUTO:
if not model_module.is_backend_compatible():
raise ValueError(
f"{arch} has no vLLM implementation and the Transformers "
"implementation is not compatible with vLLM. Try setting "
"VLLM_USE_V1=0.")
logger.warning(
"%s has no vLLM implementation, falling back to Transformers "
"implementation. Some features may not be supported and "
"performance may not be optimal.", arch)
architectures[i] = "TransformersForCausalLM"
if not model_module.is_backend_compatible():
raise ValueError(
f"The Transformers implementation of '{arch}' is not "
"compatible with vLLM.")
architectures[i] = model_config._get_transformers_backend_cls()
return architectures
......@@ -237,8 +239,9 @@ def get_model_architecture(
]
vllm_supported_archs = ModelRegistry.get_supported_archs()
vllm_not_supported = not any(arch in vllm_supported_archs
for arch in architectures)
is_supported = lambda arch: (arch in vllm_supported_archs and arch not in
_TRANSFORMERS_MODELS)
vllm_not_supported = not any(is_supported(arch) for arch in architectures)
if vllm_not_supported:
# try automatic conversion in adapters.py
......@@ -259,7 +262,7 @@ def get_model_architecture(
break
if (model_config.model_impl == ModelImpl.TRANSFORMERS or
model_config.model_impl != ModelImpl.VLLM and vllm_not_supported):
model_config.model_impl == ModelImpl.AUTO and vllm_not_supported):
architectures = resolve_transformers_arch(model_config, architectures)
logger.debug_once("Resolve transformers arch %s", str(architectures))
elif (model_config.quantization is not None
......
......@@ -253,6 +253,7 @@ _SPECULATIVE_DECODING_MODELS = {
}
_TRANSFORMERS_MODELS = {
"TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"), # noqa: E501
"TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
}
# yapf: enable
......@@ -504,9 +505,14 @@ class _ModelRegistry:
if causal_lm_arch in self.models:
normalized_arch.append(arch)
# make sure Transformers backend is put at the last as a fallback
if len(normalized_arch) != len(architectures):
normalized_arch.append("TransformersForCausalLM")
# NOTE(Isotr0py): Be careful of architectures' order!
# Make sure Transformers backend architecture is at the end of the
# list, otherwise pooling models automatic conversion will fail!
for arch in normalized_arch:
if arch.startswith("TransformersFor"):
normalized_arch.remove(arch)
normalized_arch.append(arch)
return normalized_arch
def inspect_model_cls(
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment