Unverified Commit 87b4d155 authored by Shanshan Shen's avatar Shanshan Shen Committed by GitHub
Browse files

[CustomOp][MM] Extract MMEncoderAttention as CustomOp and replace the backend...


[CustomOp][MM] Extract MMEncoderAttention as CustomOp and replace the backend of QwenVisionAttention with it. (#30125)
Signed-off-by: shen-shanshan <467638484@qq.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
parent 84e23d10
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Consolidated test for ViT attention backend functionality across multiple models.
This test validates that each multimodal model can successfully generate outputs
using different ViT attention backends. Tests are parametrized by model and backend.
"""
from dataclasses import asdict
from typing import Any
import pytest
from transformers import AutoProcessor
from vllm import LLM, EngineArgs, SamplingParams
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.multimodal.utils import encode_image_base64
from vllm.multimodal.video import sample_frames_from_video
from vllm.platforms import current_platform
from ....utils import create_new_process_for_each_test
from ...utils import dummy_hf_overrides
# Dots.OCR prompt from official repository
# https://github.com/rednote-hilab/dots.ocr/blob/d72d1d8c5bdd0362eb264f714cdbd1e5daa7cdff/dots_ocr/utils/prompts.py#L3
# ruff: noqa: E501
DOTS_OCR_PROMPT = """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox.
1. Bbox format: [x1, y1, x2, y2]
2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'].
3. Text Extraction & Formatting Rules:
- Picture: For the 'Picture' category, the text field should be omitted.
- Formula: Format its text as LaTeX.
- Table: Format its text as HTML.
- All Others (Text, Title, etc.): Format their text as Markdown.
4. Constraints:
- The output text must be the original text from the image, with no translation.
- All layout elements must be sorted according to human reading order.
5. Final Output: The entire output must be a single JSON object.
"""
VIDEO_PLACEHOLDER = "<|vision_start|><|video_pad|><|vision_end|>"
# Model configurations
# Per-model test settings. Each top-level key is a value of the ``model_key``
# pytest parameter. Keys read by the handlers in this file:
#   model_name          - checkpoint name passed to the engine / AutoProcessor.
#   interface           - handler selector ("llm_chat" / "llm_generate"); it is
#                         not consulted when media_type == "video".
#   media_type          - "video" routes the entry to run_video_test.
#   max_model_len, max_num_seqs, limit_mm_per_prompt - forwarded to the engine.
#   sampling_params     - kwargs for SamplingParams (the video handler only
#                         reads "max_tokens").
#   supported_backends  - optional; an explicitly requested backend outside
#                         this set causes the test to be skipped.
#   use_processor       - build the prompt via build_processor_prompt.
#   prompt_builder      - name of a module-level builder resolved via globals()
#                         in run_llm_generate_test.
#   question            - user text inserted by the processor / Ovis builders.
#   output_validator    - optional predicate on the generated text; handlers
#                         default to ``len(text) > 10`` when it is absent.
#   runner_kwargs, video_params - extra settings used by run_video_test.
MODEL_CONFIGS: dict[str, dict[str, Any]] = {
    "dots_ocr": {
        "model_name": "rednote-hilab/dots.ocr",
        "interface": "llm_chat",
        "max_model_len": 32768,
        "max_num_seqs": 1,
        "limit_mm_per_prompt": {"image": 1},
        "sampling_params": {
            "temperature": 0.1,
            "max_tokens": 16384,
            "top_p": 0.9,
            "stop_token_ids": None,
        },
        # NOTE(review): "use_specific_image" and "prompt_builder" are not read
        # by run_llm_chat_test, which hard-codes the stop_sign image and
        # build_dots_ocr_prompt.
        "use_specific_image": "stop_sign",
        "prompt_builder": "build_dots_ocr_prompt",
        "output_validator": lambda x: len(x) > 10 and "stop" in x.lower(),
    },
    "ernie45_vl": {
        "model_name": "baidu/ERNIE-4.5-VL-28B-A3B-PT",
        "interface": "llm_generate",
        "max_model_len": 16384,
        "max_num_seqs": 2,
        "sampling_params": {
            "temperature": 0.0,
            "max_tokens": 256,
            "stop_token_ids": None,
        },
        "use_processor": True,
        "question": "What is the content of each image?",
    },
    "glm4_1v": {
        "model_name": "zai-org/GLM-4.1V-9B-Thinking",
        "interface": "llm_generate",
        "max_model_len": 32768,
        "max_num_seqs": 2,
        "sampling_params": {
            "temperature": 0.0,
            "max_tokens": 256,
            "stop_token_ids": None,
        },
        "use_processor": True,
        "question": "What is the content of each image?",
    },
    "keye_vl": {
        "model_name": "Kwai-Keye/Keye-VL-8B-Preview",
        "interface": "llm_generate",
        "max_model_len": 8192,
        "max_num_seqs": 5,
        "sampling_params": {
            "temperature": 0.0,
            "max_tokens": 256,
            "stop_token_ids": None,
        },
        # Only these explicitly requested backends are exercised for this
        # model; other non-None backends are skipped by the test function.
        "supported_backends": {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        },
        "use_processor": True,
        "question": "What is the content of each image?",
    },
    "ovis2_5": {
        "model_name": "AIDC-AI/Ovis2.5-2B",
        "interface": "llm_generate",
        "max_model_len": 8192,
        "max_num_seqs": 2,
        "sampling_params": {
            "temperature": 0.0,
            "max_tokens": 256,
            "stop_token_ids": None,
        },
        "prompt_builder": "build_ovis_prompt",
        "question": "What is the content of each image?",
    },
    "qwen2_5_vl": {
        "model_name": "Qwen/Qwen2.5-VL-3B-Instruct",
        # media_type == "video" takes precedence over "interface" when the
        # test function picks a handler.
        "interface": "vllm_runner",
        "media_type": "video",
        "max_model_len": 4000,
        "max_num_seqs": 1,
        "limit_mm_per_prompt": {"video": 1},
        "sampling_params": {
            "max_tokens": 128,
        },
        # Expanded as keyword arguments into the vllm_runner call.
        "runner_kwargs": {
            "runner": "generate",
            "dtype": "bfloat16",
        },
        # The video handler runs once per entry of "pruning_rates".
        "video_params": {
            "num_frames": 16,
            "pruning_rates": [0.0, 0.75],
        },
    },
    "qwen2_5_omni": {
        "model_name": "Qwen/Qwen2.5-Omni-3B",
        "interface": "llm_generate",
        "max_model_len": 32768,
        "max_num_seqs": 2,
        "limit_mm_per_prompt": {"image": 3, "video": 3, "audio": 3},
        "sampling_params": {
            "temperature": 0.6,
            "top_p": 0.95,
            "top_k": 20,
            "max_tokens": 16384,
        },
        "use_processor": True,
        "question": "What is the content of each image?",
    },
    "qwen3_omni": {
        "model_name": "Qwen/Qwen3-Omni-30B-A3B-Instruct",
        "interface": "llm_generate",
        "max_model_len": 32768,
        "max_num_seqs": 2,
        "limit_mm_per_prompt": {"image": 3, "video": 3, "audio": 3},
        "sampling_params": {
            "temperature": 0.6,
            "top_p": 0.95,
            "top_k": 20,
            "max_tokens": 16384,
        },
        "use_processor": True,
        "question": "What is the content of each image?",
    },
}
# Prompt builder functions
def build_dots_ocr_prompt(images, config):
    """Assemble the chat-message list used for Dots.OCR.

    Only ``images[0]`` is used; ``config`` is accepted for signature parity
    with the other prompt builders and is not read. The result is a single
    user turn whose content is the image (as a base64 JPEG data URL) followed
    by the image placeholder tokens and the OCR instruction text.
    """
    encoded = encode_image_base64(images[0])
    image_part = {
        "type": "image_url",
        "image_url": {"url": f"data:image/jpeg;base64,{encoded}"},
    }
    text_part = {
        "type": "text",
        "text": f"<|img|><|imgpad|><|endofimg|>{DOTS_OCR_PROMPT}",
    }
    return [{"role": "user", "content": [image_part, text_part]}]
def build_processor_prompt(images, config):
    """Render a prompt string through the model's own chat template.

    Loads the ``AutoProcessor`` for ``config["model_name"]``, builds one user
    turn containing every image (as a base64 JPEG data URL) followed by
    ``config["question"]``, and returns the untokenized template output with
    the generation prompt appended.
    """
    processor = AutoProcessor.from_pretrained(
        config["model_name"], trust_remote_code=True
    )

    # One image entry per input image, then the text question last.
    content = [
        {
            "type": "image",
            "image": f"data:image/jpeg;base64,{encode_image_base64(img)}",
        }
        for img in images
    ]
    content.append({"type": "text", "text": config["question"]})
    conversation = [{"role": "user", "content": content}]

    return processor.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
def build_ovis_prompt(images, config):
    """Build Ovis2.5 specific prompt with custom format.

    Emits one ``Image-<n>: <image>`` placeholder (numbered from 1) per entry
    of ``images``, followed by ``config["question"]``, wrapped in the
    ``<|im_start|>`` / ``<|im_end|>`` chat markers. The image payloads are not
    embedded in the prompt text, so only the number of images matters here.
    """
    # Only the count is needed for the placeholders; base64-encoding each
    # image just to enumerate the results would be wasted work.
    placeholders = "\n".join(
        f"Image-{index}: <image>\n" for index, _ in enumerate(images, start=1)
    )
    return (
        f"<|im_start|>user\n\n{placeholders}\n{config['question']}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )
def build_qwen2_5_video_prompt():
    """Return the fixed Qwen2.5-VL chat prompt that embeds VIDEO_PLACEHOLDER."""
    system_turn = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
    user_turn = (
        f"<|im_start|>user\n{VIDEO_PLACEHOLDER}"
        "Describe this video with a short sentence (no more than 20 words)"
        "<|im_end|>"
    )
    assistant_open = "<|im_start|>assistant\n"
    return system_turn + user_turn + assistant_open
# Handler functions
def run_llm_generate_test(config, mm_encoder_attn_backend, image_assets):
    """Run one model through ``LLM.generate()`` with all image assets.

    The prompt comes from ``build_processor_prompt`` when
    ``config["use_processor"]`` is truthy, otherwise from the module-level
    builder named by ``config["prompt_builder"]`` (default
    ``"build_ovis_prompt"``). Every generated text must satisfy
    ``config["output_validator"]`` (default: longer than 10 characters).
    """
    pil_images = [asset.pil_image for asset in image_assets]

    # Pick the prompt builder, then build the prompt.
    if config.get("use_processor"):
        text_prompt = build_processor_prompt(pil_images, config)
    else:
        builder = globals()[config.get("prompt_builder", "build_ovis_prompt")]
        text_prompt = builder(pil_images, config)

    # Without an explicit limit, allow exactly as many images as we pass.
    mm_limits = config.get("limit_mm_per_prompt", {"image": len(pil_images)})

    # Engine with dummy weights and a fixed seed.
    engine_kwargs = asdict(
        EngineArgs(
            model=config["model_name"],
            trust_remote_code=True,
            max_model_len=config["max_model_len"],
            max_num_seqs=config["max_num_seqs"],
            limit_mm_per_prompt=mm_limits,
            mm_encoder_attn_backend=mm_encoder_attn_backend,
            hf_overrides=dummy_hf_overrides,
            load_format="dummy",
        )
    ) | {"seed": 42}
    llm = LLM(**engine_kwargs)

    params = SamplingParams(**config["sampling_params"])
    request_payload = {
        "prompt": text_prompt,
        "multi_modal_data": {"image": pil_images},
    }
    results = llm.generate(request_payload, sampling_params=params)

    # Check every returned completion.
    is_valid = config.get("output_validator", lambda x: len(x) > 10)
    for result in results:
        text = result.outputs[0].text
        assert is_valid(text), (
            f"Validation failed for {config['model_name']}: {text}"
        )
def run_llm_chat_test(config, mm_encoder_attn_backend, image_assets):
    """Run Dots.OCR through ``LLM.chat()`` on the ``stop_sign`` asset only.

    Raises ``IndexError`` if no asset named ``"stop_sign"`` is present. Every
    generated text must satisfy ``config["output_validator"]`` (default:
    longer than 10 characters).
    """
    matching = [
        asset.pil_image for asset in image_assets if asset.name == "stop_sign"
    ]
    stop_sign = matching[0]

    chat_messages = build_dots_ocr_prompt([stop_sign], config)

    # Engine with dummy weights and a fixed seed.
    engine_kwargs = asdict(
        EngineArgs(
            model=config["model_name"],
            trust_remote_code=True,
            max_model_len=config["max_model_len"],
            max_num_seqs=config["max_num_seqs"],
            limit_mm_per_prompt=config["limit_mm_per_prompt"],
            mm_encoder_attn_backend=mm_encoder_attn_backend,
            hf_overrides=dummy_hf_overrides,
            load_format="dummy",
        )
    ) | {"seed": 42}
    llm = LLM(**engine_kwargs)

    params = SamplingParams(**config["sampling_params"])
    results = llm.chat(messages=chat_messages, sampling_params=params)

    # Check every returned completion.
    is_valid = config.get("output_validator", lambda x: len(x) > 10)
    for result in results:
        text = result.outputs[0].text
        assert is_valid(text), (
            f"Validation failed for {config['model_name']}: {text}"
        )
def run_video_test(config, mm_encoder_attn_backend, video_assets, vllm_runner):
    """Greedy-generate on the first video asset once per configured pruning rate.

    For each entry of ``config["video_params"]["pruning_rates"]`` a fresh
    runner is created with that ``video_pruning_rate``; the single output must
    contain non-empty token ids and a non-empty string.
    """
    for rate in config["video_params"]["pruning_rates"]:
        frame_count = config["video_params"]["num_frames"]

        # Every asset is sampled, but only the first clip is sent to the model.
        clips = []
        for asset in video_assets:
            clips.append(sample_frames_from_video(asset.np_ndarrays, frame_count))

        prompt_batch = [build_qwen2_5_video_prompt()]
        video_batch = [clips[0]]

        runner_ctx = vllm_runner(
            config["model_name"],
            max_model_len=config["max_model_len"],
            max_num_seqs=config["max_num_seqs"],
            limit_mm_per_prompt=config["limit_mm_per_prompt"],
            tensor_parallel_size=1,
            video_pruning_rate=rate,
            mm_encoder_attn_backend=mm_encoder_attn_backend,
            hf_overrides=dummy_hf_overrides,
            load_format="dummy",
            **config["runner_kwargs"],
        )
        with runner_ctx as vllm_model:
            results = vllm_model.generate_greedy(
                prompt_batch,
                config["sampling_params"]["max_tokens"],
                videos=video_batch,
            )

            # One prompt in, exactly one (ids, text) pair out.
            assert len(results) == 1, f"Expected 1 output, got {len(results)}"
            token_ids, text = results[0]
            assert len(token_ids) > 0, "Generated no output IDs"
            assert len(text) > 0, "Generated empty text"
            assert isinstance(text, str), (
                f"Output is not string: {type(text)}"
            )
# Main test function
@pytest.mark.parametrize("model_key", list(MODEL_CONFIGS.keys()))
@pytest.mark.parametrize(
    "mm_encoder_attn_backend",
    [None] + current_platform.get_supported_vit_attn_backends(),
)
@create_new_process_for_each_test()
def test_vit_backend_functionality(
    model_key: str,
    mm_encoder_attn_backend: AttentionBackendEnum | None,
    image_assets,
    video_assets,
    vllm_runner,
    request,
):
    """Check that a multimodal model generates output with a given ViT backend.

    Flow:
    1. Skip when the model declares ``supported_backends`` and an explicitly
       requested backend is not among them (``None`` is never skipped).
    2. Apply any marks listed under ``gpu_marks`` to the current test.
    3. Dispatch to the video handler when ``media_type`` is ``"video"``,
       otherwise to the handler named by ``interface``.

    Raises:
        ValueError: if a non-video config has an unrecognized ``interface``.
    """
    config = MODEL_CONFIGS[model_key]

    # Step 1: per-model backend filtering.
    explicit_backend = mm_encoder_attn_backend is not None
    if "supported_backends" in config and explicit_backend:
        if mm_encoder_attn_backend not in config["supported_backends"]:
            pytest.skip(
                f"{model_key} does not support {mm_encoder_attn_backend} backend now."
            )

    # Step 2: dynamically attach marks, if the config lists any.
    for mark in config.get("gpu_marks", ()):
        request.applymarker(mark)

    # Step 3: choose the handler; video takes precedence over the interface.
    if config.get("media_type") == "video":
        run_video_test(config, mm_encoder_attn_backend, video_assets, vllm_runner)
        return

    interface = config["interface"]
    if interface == "llm_chat":
        run_llm_chat_test(config, mm_encoder_attn_backend, image_assets)
        return
    if interface == "llm_generate":
        run_llm_generate_test(config, mm_encoder_attn_backend, image_assets)
        return
    raise ValueError(f"Unknown interface: {config['interface']}")
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
"""Attention layer.""" """Attention layer."""
import functools import functools
from collections.abc import Callable
from typing import cast from typing import cast
import torch import torch
...@@ -17,6 +16,7 @@ from vllm.attention.backends.abstract import ( ...@@ -17,6 +16,7 @@ from vllm.attention.backends.abstract import (
MLAAttentionImpl, MLAAttentionImpl,
) )
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layers.mm_encoder_attention import maybe_get_vit_flash_attn_backend
from vllm.attention.selector import get_attn_backend from vllm.attention.selector import get_attn_backend
from vllm.attention.utils.fa_utils import get_flash_attn_version from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
...@@ -49,58 +49,9 @@ from vllm.v1.kv_cache_interface import ( ...@@ -49,58 +49,9 @@ from vllm.v1.kv_cache_interface import (
SlidingWindowSpec, SlidingWindowSpec,
) )
if current_platform.is_rocm():
from vllm.platforms.rocm import on_gfx9
else:
on_gfx9 = lambda *args, **kwargs: False
FP8_DTYPE = current_platform.fp8_dtype()
logger = init_logger(__name__) logger = init_logger(__name__)
def maybe_get_vit_flash_attn_backend(
    attn_backend: AttentionBackendEnum,
    attn_backend_override: AttentionBackendEnum | None = None,
) -> tuple[AttentionBackendEnum, Callable | None]:
    """Resolve the effective ViT backend and its varlen flash-attention function.

    Returns ``(backend, func)``. ``func`` is ``None`` whenever the resolved
    backend is neither FLASH_ATTN nor ROCM_AITER_FA, or when the FLASH_ATTN
    helper cannot be imported. On ROCm the backend may be switched to
    ROCM_AITER_FA or downgraded to TORCH_SDPA; on platforms other than
    ROCm/CUDA/XPU the result is always ``(TORCH_SDPA, None)``.
    """
    sdpa_fallback = (AttentionBackendEnum.TORCH_SDPA, None)

    # Platform-specific adjustment of the requested backend.
    if current_platform.is_rocm():
        if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
            attn_backend = AttentionBackendEnum.ROCM_AITER_FA
        elif not (
            attn_backend_override is None
            and on_gfx9()
            and attn_backend == AttentionBackendEnum.FLASH_ATTN
        ):
            return sdpa_fallback
    elif not current_platform.is_cuda():
        if not current_platform.is_xpu():
            return sdpa_fallback
        assert attn_backend == AttentionBackendEnum.FLASH_ATTN, (
            "XPU platform only supports FLASH_ATTN as vision attention backend."
        )

    # Import the kernel entry point matching the resolved backend.
    if attn_backend == AttentionBackendEnum.ROCM_AITER_FA:
        from aiter import flash_attn_varlen_func
    elif attn_backend == AttentionBackendEnum.FLASH_ATTN:
        try:
            from vllm.attention.utils.fa_utils import flash_attn_varlen_func
        except ImportError:
            flash_attn_varlen_func = None
    else:
        flash_attn_varlen_func = None

    return attn_backend, flash_attn_varlen_func
def _init_kv_cache_quant( def _init_kv_cache_quant(
layer: nn.Module, layer: nn.Module,
quant_config: QuantizationConfig | None, quant_config: QuantizationConfig | None,
...@@ -496,29 +447,15 @@ class MultiHeadAttention(nn.Module): ...@@ -496,29 +447,15 @@ class MultiHeadAttention(nn.Module):
attn_backend_override = None attn_backend_override = None
if multimodal_config is not None: if multimodal_config is not None:
attn_backend_override = multimodal_config.mm_encoder_attn_backend attn_backend_override = multimodal_config.mm_encoder_attn_backend
backend = get_vit_attn_backend(
self.attn_backend = get_vit_attn_backend(
head_size=head_size, head_size=head_size,
dtype=dtype, dtype=dtype,
attn_backend_override=attn_backend_override, attn_backend_override=attn_backend_override,
) )
self.attn_backend = ( self._flash_attn_varlen_func = maybe_get_vit_flash_attn_backend(
backend
if backend
in {
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.PALLAS,
AttentionBackendEnum.ROCM_AITER_FA,
AttentionBackendEnum.FLASH_ATTN,
}
else AttentionBackendEnum.TORCH_SDPA
)
self.attn_backend, self._flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend, self.attn_backend,
attn_backend_override=attn_backend_override,
)
) )
self.is_flash_attn_backend = self.attn_backend in { self.is_flash_attn_backend = self.attn_backend in {
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.ops.vit_attn_wrappers import (
vit_flash_attn_wrapper,
vit_torch_sdpa_wrapper,
)
from vllm.config import MultiModalConfig
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.models.vision import get_vit_attn_backend
logger = init_logger(__name__)
def maybe_get_vit_flash_attn_backend(
    attn_backend: AttentionBackendEnum | None,
) -> Callable | None:
    """Return the varlen flash-attention function for ``attn_backend``.

    No backend overriding happens here; the backend passed in is taken as
    final. Only the function is returned (not the backend): the vLLM helper
    for FLASH_ATTN, the ``aiter`` one for ROCM_AITER_FA, and ``None`` for
    anything else (e.g. TORCH_SDPA or ``None``).
    """
    if attn_backend == AttentionBackendEnum.FLASH_ATTN:
        from vllm.attention.utils.fa_utils import flash_attn_varlen_func

        return flash_attn_varlen_func

    if attn_backend == AttentionBackendEnum.ROCM_AITER_FA:
        from aiter import flash_attn_varlen_func

        return flash_attn_varlen_func

    # Non-flash backends have no varlen function.
    return None
@CustomOp.register("mm_encoder_attn")
class MMEncoderAttention(CustomOp):
    """Multi-headed attention without any cache, used for multimodal encoder.

    The attention backend is resolved once in ``__init__`` through
    ``get_vit_attn_backend``. Each ``forward_*`` method then routes to the
    flash-attention wrapper, the SDPA wrapper, or (TPU only) the torch_xla
    flash-attention kernel.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float | None = None,
        num_kv_heads: int | None = None,
        prefix: str = "",
        multimodal_config: MultiModalConfig | None = None,
    ) -> None:
        """
        Args:
            num_heads: number of attention heads per partition.
            head_size: hidden_size per attention head.
            scale: scale factor. Within this class it is only forwarded to
                the TPU kernel in ``forward_tpu``.
            num_kv_heads: number of kv heads; defaults to ``num_heads``.
            prefix: stored as ``layer_name``; kept so this layer can be
                swapped with Attention / MultiHeadAttention.
            multimodal_config: configs for multi-modal. When given, its
                ``mm_encoder_attn_backend`` is used as the backend override.
        """
        super().__init__()

        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.layer_name = prefix

        assert self.num_heads % self.num_kv_heads == 0, (
            f"num_heads ({self.num_heads}) is not "
            f"divisible by num_kv_heads ({self.num_kv_heads})"
        )
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()

        # Try to get vision attention backend from multimodal_config.
        attn_backend_override = None
        if multimodal_config is not None:
            attn_backend_override = multimodal_config.mm_encoder_attn_backend

        # Get device-specific vision attention backend.
        self.attn_backend = get_vit_attn_backend(
            head_size=head_size,
            dtype=dtype,
            attn_backend_override=attn_backend_override,
        )

        self.is_flash_attn_backend = self.attn_backend in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        }

        # None unless the backend is one of the two flash-attention variants.
        self.flash_attn_varlen_func = maybe_get_vit_flash_attn_backend(
            self.attn_backend,
        )

        logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.")

    @classmethod
    def enabled(cls) -> bool:
        """Always report this custom op as enabled."""
        return True

    def reshape_qkv_to_4d(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        bsz: int,
        q_len: int,
        kv_len: int,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Reshape query, key, value to 4D tensors:
        (batch_size, seq_len, num_heads, head_size)

        For MQA/GQA the key/value heads are repeated along the head
        dimension so they match the number of query heads.
        """
        query = query.view(bsz, q_len, self.num_heads, self.head_size)
        key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
        value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)

        if (num_repeat := self.num_queries_per_kv) > 1:
            # Handle MQA and GQA
            key = torch.repeat_interleave(key, num_repeat, dim=2)
            value = torch.repeat_interleave(value, num_repeat, dim=2)

        return query, key, value

    def reshape_qkv_to_3d(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        bsz: int,
        q_len: int,
        kv_len: int,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Reshape query, key, value to 3D tensors:
        (batch_size * seq_len, num_heads, head_size)

        For MQA/GQA the key/value heads are repeated along the head
        dimension so they match the number of query heads.
        """
        query = query.view(bsz * q_len, self.num_heads, self.head_size)
        key = key.view(bsz * kv_len, self.num_kv_heads, self.head_size)
        value = value.view(bsz * kv_len, self.num_kv_heads, self.head_size)

        if (num_repeat := self.num_queries_per_kv) > 1:
            # Handle MQA and GQA
            key = torch.repeat_interleave(key, num_repeat, dim=1)
            value = torch.repeat_interleave(value, num_repeat, dim=1)

        return query, key, value

    def _forward_sdpa(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run attention through ``vit_torch_sdpa_wrapper``.

        The first two dimensions of ``query`` are read as (batch, q_len) and
        dimension 1 of ``key`` as kv_len; the inputs are reshaped to 4D
        before the call. ``cu_seqlens`` must not be None.
        """
        # TODO(Isotr0py): Migrate MultiHeadAttention
        assert cu_seqlens is not None

        bsz, q_len = query.size()[:2]
        kv_len = key.size(1)

        query, key, value = self.reshape_qkv_to_4d(
            query, key, value, bsz, q_len, kv_len
        )

        output = vit_torch_sdpa_wrapper(
            q=query,
            k=key,
            v=value,
            cu_seqlens=cu_seqlens,
        )
        return output

    def _forward_fa(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        """Run attention through ``vit_flash_attn_wrapper``.

        Requires a flash-attention function to have been resolved in
        ``__init__`` and both ``cu_seqlens`` and ``max_seqlen`` to be given.
        The inputs are passed to the wrapper without reshaping.
        """
        assert self.flash_attn_varlen_func is not None, (
            "Flash attention function is not set."
        )
        # TODO(Isotr0py): Migrate MultiHeadAttention
        assert cu_seqlens is not None and max_seqlen is not None

        bsz = query.shape[0]

        output = vit_flash_attn_wrapper(
            q=query,
            k=key,
            v=value,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            batch_size=bsz,
            is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA),
        )
        return output

    def forward_native(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        """Native path: always uses the SDPA implementation."""
        return self._forward_sdpa(query, key, value, cu_seqlens)

    def forward_cuda(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        """CUDA path: flash attention or SDPA, depending on the backend.

        Raises:
            ValueError: if the resolved backend is neither a flash-attention
                variant nor TORCH_SDPA.
        """
        if self.is_flash_attn_backend:
            return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
        elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
            return self._forward_sdpa(query, key, value, cu_seqlens)
        else:
            raise ValueError(
                f"Unsupported multi-modal encoder attention backend for CUDA: "
                f"{self.attn_backend}."
            )

    def forward_cpu(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        """CPU path: always uses the SDPA implementation."""
        return self._forward_sdpa(query, key, value, cu_seqlens)

    def forward_xpu(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        """XPU path: requires a flash-attention backend."""
        assert self.is_flash_attn_backend, (
            "XPU only supports FLASH_ATTN for vision attention."
        )
        return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)

    def forward_tpu(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        """TPU path: requires the PALLAS backend.

        Without ``cu_seqlens`` the torch_xla flash-attention kernel is used
        (dimensions 1 and 2 are swapped before and after the call). With
        ``cu_seqlens`` a warning is logged once and the SDPA implementation
        is used instead.
        """
        assert self.attn_backend == AttentionBackendEnum.PALLAS, (
            f"MMEncoderAttention on TPU only supports PALLAS backend, "
            f"but got {self.attn_backend}."
        )
        if cu_seqlens is None:
            query, key, value = (x.transpose(1, 2) for x in (query, key, value))
            from torch_xla.experimental.custom_kernel import flash_attention

            out = flash_attention(query, key, value, sm_scale=self.scale)
            out = out.transpose(1, 2)
            return out

        # The two literals form ONE message. Separating them with a comma
        # would pass the second as a %-format argument to a message without
        # placeholders, which breaks log formatting.
        logger.warning_once(
            "PALLAS backend with cu_seqlens is not supported for ViT yet. "
            "Falling back to SDPA implementation."
        )
        return self._forward_sdpa(query, key, value, cu_seqlens)
...@@ -44,9 +44,7 @@ def flash_attn_maxseqlen_wrapper( ...@@ -44,9 +44,7 @@ def flash_attn_maxseqlen_wrapper(
dropout_p=0.0, dropout_p=0.0,
causal=False, causal=False,
) )
context_layer = einops.rearrange( context_layer = einops.rearrange(output, "(b s) h d -> b s h d", b=batch_size)
output, "(b s) h d -> s b (h d)", b=batch_size
).contiguous()
return context_layer return context_layer
...@@ -59,8 +57,7 @@ def flash_attn_maxseqlen_wrapper_fake( ...@@ -59,8 +57,7 @@ def flash_attn_maxseqlen_wrapper_fake(
batch_size: int, batch_size: int,
is_rocm_aiter: bool, is_rocm_aiter: bool,
) -> torch.Tensor: ) -> torch.Tensor:
b, s, h, d = q.shape return torch.empty_like(q)
return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)
direct_register_custom_op( direct_register_custom_op(
...@@ -106,7 +103,6 @@ def torch_sdpa_wrapper( ...@@ -106,7 +103,6 @@ def torch_sdpa_wrapper(
output_i = einops.rearrange(output_i, "b h s d -> b s h d ") output_i = einops.rearrange(output_i, "b h s d -> b s h d ")
outputs.append(output_i) outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1) context_layer = torch.cat(outputs, dim=1)
context_layer = einops.rearrange(context_layer, "b s h d -> s b (h d)").contiguous()
return context_layer return context_layer
...@@ -116,8 +112,7 @@ def torch_sdpa_wrapper_fake( ...@@ -116,8 +112,7 @@ def torch_sdpa_wrapper_fake(
v: torch.Tensor, v: torch.Tensor,
cu_seqlens: torch.Tensor, cu_seqlens: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
b, s, h, d = q.shape return torch.empty_like(q)
return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)
direct_register_custom_op( direct_register_custom_op(
......
...@@ -5,15 +5,14 @@ from typing import Annotated, Literal, TypeAlias ...@@ -5,15 +5,14 @@ from typing import Annotated, Literal, TypeAlias
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
from torch.nn import LayerNorm from torch.nn import LayerNorm
from transformers.models.qwen2_vl import Qwen2VLProcessor from transformers.models.qwen2_vl import Qwen2VLProcessor
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import ( from vllm.attention.layers.mm_encoder_attention import (
maybe_get_vit_flash_attn_backend, MMEncoderAttention,
) )
from vllm.config import VllmConfig from vllm.config import MultiModalConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import utils as dist_utils from vllm.distributed import utils as dist_utils
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
...@@ -254,11 +253,15 @@ class DotsVisionAttention(nn.Module): ...@@ -254,11 +253,15 @@ class DotsVisionAttention(nn.Module):
bias: bool = True, bias: bool = True,
*, *,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.embed_dim = dim self.embed_dim = dim
self.tp_size = ( self.tp_size = (
...@@ -287,31 +290,13 @@ class DotsVisionAttention(nn.Module): ...@@ -287,31 +290,13 @@ class DotsVisionAttention(nn.Module):
prefix=f"{prefix}.proj", prefix=f"{prefix}.proj",
disable_tp=use_data_parallel, disable_tp=use_data_parallel,
) )
# Select attention backend
self.attn_backend = get_vit_attn_backend(
self.hidden_size_per_attention_head,
torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
self.attn_backend, self.flash_attn_varlen_func = ( self.attn = MMEncoderAttention(
maybe_get_vit_flash_attn_backend( num_heads=self.num_attention_heads_per_partition,
self.attn_backend, head_size=self.hidden_size_per_attention_head,
attn_backend_override=attn_backend_override, multimodal_config=multimodal_config,
) prefix=f"{prefix}.attn",
)
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
f"Unsupported vision attention backend: {self.attn_backend}"
) )
self.is_flash_attn_backend = self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}
def forward( def forward(
self, self,
...@@ -319,7 +304,7 @@ class DotsVisionAttention(nn.Module): ...@@ -319,7 +304,7 @@ class DotsVisionAttention(nn.Module):
cu_seqlens: torch.Tensor, cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor | None = None, rotary_pos_emb: torch.Tensor | None = None,
*, *,
max_seqlen: int | None = None, max_seqlen: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
# [S, C] -> [S, B=1, C] # [S, C] -> [S, B=1, C]
x = hidden_states.unsqueeze(1) x = hidden_states.unsqueeze(1)
...@@ -336,41 +321,13 @@ class DotsVisionAttention(nn.Module): ...@@ -336,41 +321,13 @@ class DotsVisionAttention(nn.Module):
qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb) qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
q, k = torch.chunk(qk_rotated, 2, dim=0) q, k = torch.chunk(qk_rotated, 2, dim=0)
if self.is_flash_attn_backend: context_layer = self.attn(
q_ = q.reshape(bs * q.shape[1], q.shape[2], q.shape[3]) query=q,
k_ = k.reshape(bs * k.shape[1], k.shape[2], k.shape[3]) key=k,
v_ = v.reshape(bs * v.shape[1], v.shape[2], v.shape[3]) value=v,
output = self.flash_attn_varlen_func( cu_seqlens=cu_seqlens,
q_, max_seqlen=max_seqlen,
k_, )
v_,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
dropout_p=0.0,
causal=False,
)
context_layer = output.view(
bs,
-1,
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
)
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
outputs = []
for i in range(1, len(cu_seqlens)):
s = int(cu_seqlens[i - 1])
e = int(cu_seqlens[i])
q_i = q[:, s:e].permute(0, 2, 1, 3)
k_i = k[:, s:e].permute(0, 2, 1, 3)
v_i = v[:, s:e].permute(0, 2, 1, 3)
out_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
out_i = out_i.permute(0, 2, 1, 3)
outputs.append(out_i)
context_layer = torch.cat(outputs, dim=1) if outputs else q[:, :0]
else:
raise RuntimeError("Unsupported attention backend")
# [B,S,H,D] -> [S,B,H*D] -> [S, C] # [B,S,H,D] -> [S,B,H*D] -> [S, C]
context_layer = context_layer.permute(1, 0, 2, 3).contiguous() context_layer = context_layer.permute(1, 0, 2, 3).contiguous()
...@@ -385,14 +342,19 @@ class DotsSwiGLUFFN(nn.Module): ...@@ -385,14 +342,19 @@ class DotsSwiGLUFFN(nn.Module):
config, config,
*, *,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
): ):
super().__init__() super().__init__()
hidden_features = config.intermediate_size hidden_features = config.intermediate_size
in_features = config.embed_dim in_features = config.embed_dim
bias = config.use_bias bias = config.use_bias
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
# Referenced aimv2.py AIMv2SwiGLUFFN # Referenced aimv2.py AIMv2SwiGLUFFN
self.fc13 = MergedColumnParallelLinear( self.fc13 = MergedColumnParallelLinear(
in_features, in_features,
...@@ -498,9 +460,8 @@ class DotsVisionBlock(nn.Module): ...@@ -498,9 +460,8 @@ class DotsVisionBlock(nn.Module):
config, config,
*, *,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
): ):
super().__init__() super().__init__()
...@@ -510,16 +471,15 @@ class DotsVisionBlock(nn.Module): ...@@ -510,16 +471,15 @@ class DotsVisionBlock(nn.Module):
num_heads=config.num_attention_heads, num_heads=config.num_attention_heads,
bias=config.use_bias, bias=config.use_bias,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
) )
self.norm1 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps) self.norm1 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
self.mlp = DotsSwiGLUFFN( self.mlp = DotsSwiGLUFFN(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel,
) )
self.norm2 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps) self.norm2 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
...@@ -546,12 +506,11 @@ class DotsVisionTransformer(nn.Module): ...@@ -546,12 +506,11 @@ class DotsVisionTransformer(nn.Module):
self, self,
config: DotsVisionConfig, config: DotsVisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*, *,
num_hidden_layers_override: int | None = None, num_hidden_layers_override: int | None = None,
require_post_norm: bool | None = None, require_post_norm: bool | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -561,6 +520,11 @@ class DotsVisionTransformer(nn.Module): ...@@ -561,6 +520,11 @@ class DotsVisionTransformer(nn.Module):
head_dim = config.embed_dim // config.num_attention_heads head_dim = config.embed_dim // config.num_attention_heads
self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2) self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.attn_backend = get_vit_attn_backend( self.attn_backend = get_vit_attn_backend(
head_size=head_dim, head_size=head_dim,
dtype=torch.get_default_dtype(), dtype=torch.get_default_dtype(),
...@@ -578,9 +542,8 @@ class DotsVisionTransformer(nn.Module): ...@@ -578,9 +542,8 @@ class DotsVisionTransformer(nn.Module):
DotsVisionBlock( DotsVisionBlock(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.blocks.{i}", prefix=f"{prefix}.blocks.{i}",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
) )
for i in range(num_layers) for i in range(num_layers)
] ]
...@@ -592,6 +555,11 @@ class DotsVisionTransformer(nn.Module): ...@@ -592,6 +555,11 @@ class DotsVisionTransformer(nn.Module):
else: else:
self.post_trunk_norm = None self.post_trunk_norm = None
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.merger = PatchMerger( self.merger = PatchMerger(
dim=config.hidden_size, dim=config.hidden_size,
context_dim=config.embed_dim, context_dim=config.embed_dim,
...@@ -647,7 +615,7 @@ class DotsVisionTransformer(nn.Module): ...@@ -647,7 +615,7 @@ class DotsVisionTransformer(nn.Module):
self.attn_backend == AttentionBackendEnum.FLASH_ATTN self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
): ):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
return max_seqlen return max_seqlen
def forward( def forward(
...@@ -733,17 +701,12 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA ...@@ -733,17 +701,12 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
self.config.vision_config = vision_config self.config.vision_config = vision_config
else: else:
vision_config = self.config.vision_config vision_config = self.config.vision_config
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.vision_tower = DotsVisionTransformer( self.vision_tower = DotsVisionTransformer(
vision_config, vision_config,
quant_config=self.quant_config, quant_config=self.quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "vision_tower"), prefix=maybe_prefix(prefix, "vision_tower"),
use_data_parallel=self.use_data_parallel,
attn_backend_override=attn_backend_override,
) )
self.language_model: Qwen2ForCausalLM = init_vllm_registered_model( self.language_model: Qwen2ForCausalLM = init_vllm_registered_model(
vllm_config=vllm_config, vllm_config=vllm_config,
......
...@@ -37,10 +37,10 @@ from einops import rearrange, repeat ...@@ -37,10 +37,10 @@ from einops import rearrange, repeat
from transformers import BatchFeature from transformers import BatchFeature
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import ( from vllm.attention.layers.mm_encoder_attention import (
maybe_get_vit_flash_attn_backend, MMEncoderAttention,
) )
from vllm.config import VllmConfig from vllm.config import MultiModalConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import parallel_state from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils from vllm.distributed import utils as dist_utils
...@@ -163,8 +163,8 @@ class Ernie4_5_VisionAttention(nn.Module): ...@@ -163,8 +163,8 @@ class Ernie4_5_VisionAttention(nn.Module):
num_heads: int, num_heads: int,
projection_size: int, projection_size: int,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
# Per attention head and per partition values. # Per attention head and per partition values.
...@@ -193,32 +193,12 @@ class Ernie4_5_VisionAttention(nn.Module): ...@@ -193,32 +193,12 @@ class Ernie4_5_VisionAttention(nn.Module):
prefix=f"{prefix}.proj", prefix=f"{prefix}.proj",
) )
# Detect attention implementation. self.attn = MMEncoderAttention(
self.attn_backend = get_vit_attn_backend( num_heads=self.num_attention_heads_per_partition,
head_size=self.hidden_size_per_attention_head, head_size=self.hidden_size_per_attention_head,
dtype=torch.get_default_dtype(), multimodal_config=multimodal_config,
attn_backend_override=attn_backend_override, prefix=f"{prefix}.attn",
)
self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
attn_backend_override=attn_backend_override,
)
)
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
f"Ernie45-VL does not support {self.attn_backend} backend now."
) )
self.is_flash_attn_backend = self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
# [s, b, 3 * head * head_dim] # [s, b, 3 * head * head_dim]
...@@ -253,14 +233,13 @@ class Ernie4_5_VisionAttention(nn.Module): ...@@ -253,14 +233,13 @@ class Ernie4_5_VisionAttention(nn.Module):
x: torch.Tensor, x: torch.Tensor,
cu_seqlens: torch.Tensor, cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor, rotary_pos_emb: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor: ) -> torch.Tensor:
# [s, b, c] --> [s, b, head * 3 * head_dim] # [s, b, c] --> [s, b, head * 3 * head_dim]
x, _ = self.qkv(x) x, _ = self.qkv(x)
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim] # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
q, k, v = self.split_qkv(x) q, k, v = self.split_qkv(x)
batch_size = q.shape[1]
q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)) q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v))
if rotary_pos_emb is not None: if rotary_pos_emb is not None:
...@@ -268,43 +247,14 @@ class Ernie4_5_VisionAttention(nn.Module): ...@@ -268,43 +247,14 @@ class Ernie4_5_VisionAttention(nn.Module):
qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb) qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
q, k = torch.chunk(qk_rotated, 2, dim=0) q, k = torch.chunk(qk_rotated, 2, dim=0)
if self.is_flash_attn_backend: output = self.attn(
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) query=q,
key=k,
output = self.flash_attn_varlen_func( value=v,
q, cu_seqlens=cu_seqlens,
k, max_seqlen=max_seqlen,
v, )
cu_seqlens_q=cu_seqlens, context_layer = rearrange(output, "b s h d -> s b (h d)").contiguous()
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
dropout_p=0.0,
causal=False,
)
context_layer = rearrange(
output, "(b s) h d -> s b (h d)", b=batch_size
).contiguous()
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
# Execute attention entry by entry for speed & less VRAM.
outputs = []
lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
q_chunks = torch.split(q, lens, dim=1)
k_chunks = torch.split(k, lens, dim=1)
v_chunks = torch.split(v, lens, dim=1)
for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
q_i, k_i, v_i = (
rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
)
output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
output_i = rearrange(output_i, "b h s d -> b s h d ")
outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1)
context_layer = rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
output, _ = self.proj(context_layer) output, _ = self.proj(context_layer)
return output return output
...@@ -350,8 +300,8 @@ class Ernie4_5_VisionBlock(nn.Module): ...@@ -350,8 +300,8 @@ class Ernie4_5_VisionBlock(nn.Module):
act_layer: type[nn.Module] = QuickGELU, act_layer: type[nn.Module] = QuickGELU,
norm_layer: Callable[[int], nn.Module] | None = None, norm_layer: Callable[[int], nn.Module] | None = None,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -366,8 +316,8 @@ class Ernie4_5_VisionBlock(nn.Module): ...@@ -366,8 +316,8 @@ class Ernie4_5_VisionBlock(nn.Module):
num_heads=num_heads, num_heads=num_heads,
projection_size=dim, projection_size=dim,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn",
attn_backend_override=attn_backend_override,
) )
self.mlp = Ernie4_5_VisionMLP( self.mlp = Ernie4_5_VisionMLP(
...@@ -383,7 +333,7 @@ class Ernie4_5_VisionBlock(nn.Module): ...@@ -383,7 +333,7 @@ class Ernie4_5_VisionBlock(nn.Module):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor, cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor, rotary_pos_emb: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = hidden_states + self.attn( hidden_states = hidden_states + self.attn(
self.norm1(hidden_states), self.norm1(hidden_states),
...@@ -441,8 +391,8 @@ class Ernie4_5_VisionTransformer(nn.Module): ...@@ -441,8 +391,8 @@ class Ernie4_5_VisionTransformer(nn.Module):
vision_config, vision_config,
norm_eps: float = 1e-6, norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
patch_size = vision_config.patch_size patch_size = vision_config.patch_size
...@@ -477,8 +427,8 @@ class Ernie4_5_VisionTransformer(nn.Module): ...@@ -477,8 +427,8 @@ class Ernie4_5_VisionTransformer(nn.Module):
mlp_ratio=mlp_ratio, mlp_ratio=mlp_ratio,
norm_layer=norm_layer, norm_layer=norm_layer,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.blocks.{layer_idx}", prefix=f"{prefix}.blocks.{layer_idx}",
attn_backend_override=attn_backend_override,
) )
for layer_idx in range(depth) for layer_idx in range(depth)
] ]
...@@ -489,6 +439,9 @@ class Ernie4_5_VisionTransformer(nn.Module): ...@@ -489,6 +439,9 @@ class Ernie4_5_VisionTransformer(nn.Module):
) )
self.ln = nn.LayerNorm(hidden_size, eps=1e-6) self.ln = nn.LayerNorm(hidden_size, eps=1e-6)
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend if multimodal_config else None
)
self.attn_backend = get_vit_attn_backend( self.attn_backend = get_vit_attn_backend(
head_size=head_dim, head_size=head_dim,
dtype=torch.get_default_dtype(), dtype=torch.get_default_dtype(),
...@@ -535,13 +488,13 @@ class Ernie4_5_VisionTransformer(nn.Module): ...@@ -535,13 +488,13 @@ class Ernie4_5_VisionTransformer(nn.Module):
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb return rotary_pos_emb
def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> int | None: def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> torch.Tensor | None:
max_seqlen = None max_seqlen = None
if ( if (
self.attn_backend == AttentionBackendEnum.FLASH_ATTN self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
): ):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
return max_seqlen return max_seqlen
def forward( def forward(
...@@ -1304,17 +1257,12 @@ class Ernie4_5_VLMoeForConditionalGeneration( ...@@ -1304,17 +1257,12 @@ class Ernie4_5_VLMoeForConditionalGeneration(
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.vision_model = Ernie4_5_VisionTransformer( self.vision_model = Ernie4_5_VisionTransformer(
config.vision_config, config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6), norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "vision_model"), prefix=maybe_prefix(prefix, "vision_model"),
attn_backend_override=attn_backend_override,
) )
self.language_model = Ernie4_5_VLMoeForCausalLM( self.language_model = Ernie4_5_VLMoeForCausalLM(
......
...@@ -47,8 +47,10 @@ from transformers.models.glm4v.video_processing_glm4v import Glm4vVideoProcessor ...@@ -47,8 +47,10 @@ from transformers.models.glm4v.video_processing_glm4v import Glm4vVideoProcessor
from transformers.video_utils import VideoMetadata from transformers.video_utils import VideoMetadata
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import maybe_get_vit_flash_attn_backend from vllm.attention.layers.mm_encoder_attention import (
from vllm.config import VllmConfig MMEncoderAttention,
)
from vllm.config import MultiModalConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size, parallel_state from vllm.distributed import get_tensor_model_parallel_world_size, parallel_state
from vllm.distributed import utils as dist_utils from vllm.distributed import utils as dist_utils
...@@ -191,10 +193,15 @@ class Glm4vVisionMLP(nn.Module): ...@@ -191,10 +193,15 @@ class Glm4vVisionMLP(nn.Module):
hidden_features: int, hidden_features: int,
bias: bool = False, bias: bool = False,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
): ):
super().__init__() super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.gate_up_proj = MergedColumnParallelLinear( self.gate_up_proj = MergedColumnParallelLinear(
input_size=in_features, input_size=in_features,
output_sizes=[hidden_features] * 2, output_sizes=[hidden_features] * 2,
...@@ -248,12 +255,16 @@ class Glm4vVisionAttention(nn.Module): ...@@ -248,12 +255,16 @@ class Glm4vVisionAttention(nn.Module):
num_heads: int, num_heads: int,
projection_size: int, projection_size: int,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
# Per attention head and per partition values. # Per attention head and per partition values.
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.tp_size = ( self.tp_size = (
1 if use_data_parallel else get_tensor_model_parallel_world_size() 1 if use_data_parallel else get_tensor_model_parallel_world_size()
) )
...@@ -287,34 +298,12 @@ class Glm4vVisionAttention(nn.Module): ...@@ -287,34 +298,12 @@ class Glm4vVisionAttention(nn.Module):
disable_tp=use_data_parallel, disable_tp=use_data_parallel,
) )
# Detect attention implementation. self.attn = MMEncoderAttention(
self.attn_backend = get_vit_attn_backend( num_heads=self.num_attention_heads_per_partition,
head_size=self.hidden_size_per_attention_head, head_size=self.hidden_size_per_attention_head,
dtype=torch.get_default_dtype(), multimodal_config=multimodal_config,
attn_backend_override=attn_backend_override,
)
self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
attn_backend_override=attn_backend_override,
)
)
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
f"GLM-4V does not support {self.attn_backend} backend now."
) )
self.is_flash_attn_backend = self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
# [s, b, 3 * head * head_dim] # [s, b, 3 * head * head_dim]
seq_len, bs, _ = qkv.shape seq_len, bs, _ = qkv.shape
...@@ -338,14 +327,13 @@ class Glm4vVisionAttention(nn.Module): ...@@ -338,14 +327,13 @@ class Glm4vVisionAttention(nn.Module):
cu_seqlens: torch.Tensor, cu_seqlens: torch.Tensor,
rotary_pos_emb_cos: torch.Tensor, rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor, rotary_pos_emb_sin: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor: ) -> torch.Tensor:
# [s, b, c] --> [s, b, head * 3 * head_dim] # [s, b, c] --> [s, b, head * 3 * head_dim]
x, _ = self.qkv(x) x, _ = self.qkv(x)
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim] # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
q, k, v = self.split_qkv(x) q, k, v = self.split_qkv(x)
batch_size = q.shape[1]
q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)) q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v))
if rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None: if rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None:
...@@ -356,43 +344,14 @@ class Glm4vVisionAttention(nn.Module): ...@@ -356,43 +344,14 @@ class Glm4vVisionAttention(nn.Module):
) )
q, k = torch.chunk(qk_rotated, 2, dim=0) q, k = torch.chunk(qk_rotated, 2, dim=0)
if self.is_flash_attn_backend: context_layer = self.attn(
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) query=q,
key=k,
output = self.flash_attn_varlen_func( value=v,
q, cu_seqlens=cu_seqlens,
k, max_seqlen=max_seqlen,
v, )
cu_seqlens_q=cu_seqlens, context_layer = rearrange(context_layer, "b s h d -> s b (h d)").contiguous()
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
dropout_p=0.0,
causal=False,
)
context_layer = rearrange(
output, "(b s) h d -> s b (h d)", b=batch_size
).contiguous()
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
# Execute attention entry by entry for speed & less VRAM.
outputs = []
lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
q_chunks = torch.split(q, lens, dim=1)
k_chunks = torch.split(k, lens, dim=1)
v_chunks = torch.split(v, lens, dim=1)
for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
q_i, k_i, v_i = (
rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
)
output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
output_i = rearrange(output_i, "b h s d -> b s h d ")
outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1)
context_layer = rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
output, _ = self.proj(context_layer) output, _ = self.proj(context_layer)
return output return output
...@@ -406,9 +365,8 @@ class Glm4vVisionBlock(nn.Module): ...@@ -406,9 +365,8 @@ class Glm4vVisionBlock(nn.Module):
mlp_hidden_dim: int, mlp_hidden_dim: int,
norm_layer: Callable[[int], nn.Module] | None = None, norm_layer: Callable[[int], nn.Module] | None = None,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
if norm_layer is None: if norm_layer is None:
...@@ -420,17 +378,16 @@ class Glm4vVisionBlock(nn.Module): ...@@ -420,17 +378,16 @@ class Glm4vVisionBlock(nn.Module):
num_heads=num_heads, num_heads=num_heads,
projection_size=dim, projection_size=dim,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
) )
self.mlp = Glm4vVisionMLP( self.mlp = Glm4vVisionMLP(
dim, dim,
mlp_hidden_dim, mlp_hidden_dim,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel,
) )
def forward( def forward(
...@@ -489,11 +446,16 @@ class Glm4vPatchMerger(nn.Module): ...@@ -489,11 +446,16 @@ class Glm4vPatchMerger(nn.Module):
d_model: int, d_model: int,
context_dim: int, context_dim: int,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
bias: bool = False, bias: bool = False,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.hidden_size = d_model self.hidden_size = d_model
self.proj = ColumnParallelLinear( self.proj = ColumnParallelLinear(
self.hidden_size, self.hidden_size,
...@@ -649,19 +611,19 @@ class Glm4vVisionTransformer(nn.Module): ...@@ -649,19 +611,19 @@ class Glm4vVisionTransformer(nn.Module):
vision_config: Glm4vVisionConfig, vision_config: Glm4vVisionConfig,
norm_eps: float = 1e-6, norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
assert multimodal_config is not None, "multimodal_config must be provided"
patch_size = vision_config.patch_size patch_size = vision_config.patch_size
temporal_patch_size = vision_config.temporal_patch_size temporal_patch_size = vision_config.temporal_patch_size
in_channels = vision_config.in_channels in_channels = vision_config.in_channels
depth = vision_config.depth depth = vision_config.depth
self.hidden_size = vision_config.hidden_size self.hidden_size = vision_config.hidden_size
self.num_heads = vision_config.num_heads self.num_heads = vision_config.num_heads
self.use_data_parallel = use_data_parallel
self.patch_size = vision_config.patch_size self.patch_size = vision_config.patch_size
self.spatial_merge_size = vision_config.spatial_merge_size self.spatial_merge_size = vision_config.spatial_merge_size
...@@ -690,9 +652,8 @@ class Glm4vVisionTransformer(nn.Module): ...@@ -690,9 +652,8 @@ class Glm4vVisionTransformer(nn.Module):
mlp_hidden_dim=vision_config.out_hidden_size, mlp_hidden_dim=vision_config.out_hidden_size,
norm_layer=norm_layer, norm_layer=norm_layer,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.blocks.{layer_idx}", prefix=f"{prefix}.blocks.{layer_idx}",
use_data_parallel=self.use_data_parallel,
attn_backend_override=attn_backend_override,
) )
for layer_idx in range(depth) for layer_idx in range(depth)
] ]
...@@ -701,9 +662,9 @@ class Glm4vVisionTransformer(nn.Module): ...@@ -701,9 +662,9 @@ class Glm4vVisionTransformer(nn.Module):
d_model=vision_config.out_hidden_size, d_model=vision_config.out_hidden_size,
context_dim=vision_config.intermediate_size, context_dim=vision_config.intermediate_size,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
bias=False, bias=False,
prefix=f"{prefix}.merger", prefix=f"{prefix}.merger",
use_data_parallel=self.use_data_parallel,
) )
self.embeddings = Glm4vVisionEmbeddings(vision_config) self.embeddings = Glm4vVisionEmbeddings(vision_config)
...@@ -723,7 +684,7 @@ class Glm4vVisionTransformer(nn.Module): ...@@ -723,7 +684,7 @@ class Glm4vVisionTransformer(nn.Module):
self.attn_backend = get_vit_attn_backend( self.attn_backend = get_vit_attn_backend(
head_size=head_dim, head_size=head_dim,
dtype=torch.get_default_dtype(), dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override, attn_backend_override=multimodal_config.mm_encoder_attn_backend,
) )
@property @property
...@@ -775,13 +736,13 @@ class Glm4vVisionTransformer(nn.Module): ...@@ -775,13 +736,13 @@ class Glm4vVisionTransformer(nn.Module):
def compute_attn_mask_seqlen( def compute_attn_mask_seqlen(
self, self,
cu_seqlens: torch.Tensor, cu_seqlens: torch.Tensor,
) -> int | None: ) -> torch.Tensor | None:
max_seqlen = None max_seqlen = None
if ( if (
self.attn_backend == AttentionBackendEnum.FLASH_ATTN self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
): ):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
return max_seqlen return max_seqlen
def forward( def forward(
...@@ -1465,18 +1426,12 @@ class Glm4vForConditionalGeneration( ...@@ -1465,18 +1426,12 @@ class Glm4vForConditionalGeneration(
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual = Glm4vVisionTransformer( self.visual = Glm4vVisionTransformer(
config.vision_config, config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-5), norm_eps=getattr(config, "rms_norm_eps", 1e-5),
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "visual"), prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
attn_backend_override=attn_backend_override,
) )
if config.model_type == "glm4v": if config.model_type == "glm4v":
......
...@@ -9,7 +9,6 @@ from typing import Annotated, Any, Literal, TypeAlias, TypeVar ...@@ -9,7 +9,6 @@ from typing import Annotated, Any, Literal, TypeAlias, TypeVar
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange from einops import rearrange
from transformers import PretrainedConfig from transformers import PretrainedConfig
from transformers.activations import GELUActivation from transformers.activations import GELUActivation
...@@ -17,11 +16,10 @@ from transformers.feature_extraction_utils import BatchFeature ...@@ -17,11 +16,10 @@ from transformers.feature_extraction_utils import BatchFeature
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from transformers.utils import torch_int from transformers.utils import torch_int
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.layers.mm_encoder_attention import (
from vllm.attention.layer import ( MMEncoderAttention,
maybe_get_vit_flash_attn_backend,
) )
from vllm.config import VllmConfig from vllm.config import MultiModalConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -80,7 +78,6 @@ from .utils import ( ...@@ -80,7 +78,6 @@ from .utils import (
is_pp_missing_parameter, is_pp_missing_parameter,
maybe_prefix, maybe_prefix,
) )
from .vision import get_vit_attn_backend
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -369,8 +366,8 @@ class KeyeSiglipAttention(nn.Module): ...@@ -369,8 +366,8 @@ class KeyeSiglipAttention(nn.Module):
self, self,
config: PretrainedConfig, config: PretrainedConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -408,34 +405,14 @@ class KeyeSiglipAttention(nn.Module): ...@@ -408,34 +405,14 @@ class KeyeSiglipAttention(nn.Module):
prefix=f"{prefix}.out_proj", prefix=f"{prefix}.out_proj",
) )
# Detect attention implementation. self.attn = MMEncoderAttention(
self.attn_backend = get_vit_attn_backend( num_heads=self.num_heads,
head_size=self.head_dim, head_size=self.head_dim,
dtype=torch.get_default_dtype(), num_kv_heads=self.num_kv_heads,
attn_backend_override=attn_backend_override, prefix=f"{prefix}.attn",
multimodal_config=multimodal_config,
) )
self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
attn_backend_override=attn_backend_override,
)
)
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
f"Keye-VL does not support {self.attn_backend} backend now."
)
self.is_flash_attn_backend = self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -450,8 +427,7 @@ class KeyeSiglipAttention(nn.Module): ...@@ -450,8 +427,7 @@ class KeyeSiglipAttention(nn.Module):
dim=-1, dim=-1,
) )
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
batch_size = q.shape[0]
if rope_emb is None: if rope_emb is None:
q = q.view(*q.shape[:-1], self.num_heads, self.head_dim) q = q.view(*q.shape[:-1], self.num_heads, self.head_dim)
...@@ -482,38 +458,14 @@ class KeyeSiglipAttention(nn.Module): ...@@ -482,38 +458,14 @@ class KeyeSiglipAttention(nn.Module):
self.head_dim, self.head_dim,
) )
if self.is_flash_attn_backend: context_layer = self.attn(
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) query=q,
key=k,
output = self.flash_attn_varlen_func( value=v,
q, cu_seqlens=cu_seqlens,
k, max_seqlen=max_seqlen,
v, )
cu_seqlens_q=cu_seqlens, context_layer = rearrange(context_layer, "b s h d -> b s (h d)")
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
causal=False,
softmax_scale=self.scale,
)
context_layer = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
outputs = []
for i in range(1, len(cu_seqlens)):
start_idx = cu_seqlens[i - 1]
end_idx = cu_seqlens[i]
q_i = q[:, start_idx:end_idx]
k_i = k[:, start_idx:end_idx]
v_i = v[:, start_idx:end_idx]
q_i, k_i, v_i = (
rearrange(x, "b s h d -> b h s d") for x in (q_i, k_i, v_i)
)
output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
output_i = rearrange(output_i, "b h s d -> b s h d ")
outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1) if outputs else q[:, :0]
context_layer = rearrange(context_layer, "b s h d -> b s (h d)").contiguous()
output, _ = self.out_proj(context_layer) output, _ = self.out_proj(context_layer)
return output return output
...@@ -547,8 +499,8 @@ class KeyeSiglipEncoderLayer(nn.Module): ...@@ -547,8 +499,8 @@ class KeyeSiglipEncoderLayer(nn.Module):
self, self,
config: PretrainedConfig, config: PretrainedConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
): ):
super().__init__() super().__init__()
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
...@@ -556,8 +508,8 @@ class KeyeSiglipEncoderLayer(nn.Module): ...@@ -556,8 +508,8 @@ class KeyeSiglipEncoderLayer(nn.Module):
self.self_attn = KeyeSiglipAttention( self.self_attn = KeyeSiglipAttention(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.self_attn", prefix=f"{prefix}.self_attn",
attn_backend_override=attn_backend_override,
) )
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = SiglipMLP( self.mlp = SiglipMLP(
...@@ -601,8 +553,8 @@ class KeyeSiglipEncoder(nn.Module): ...@@ -601,8 +553,8 @@ class KeyeSiglipEncoder(nn.Module):
self, self,
config: PretrainedConfig, config: PretrainedConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -614,8 +566,8 @@ class KeyeSiglipEncoder(nn.Module): ...@@ -614,8 +566,8 @@ class KeyeSiglipEncoder(nn.Module):
KeyeSiglipEncoderLayer( KeyeSiglipEncoderLayer(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.layers.{layer_idx}", prefix=f"{prefix}.layers.{layer_idx}",
attn_backend_override=attn_backend_override,
) )
for layer_idx in range(config.num_hidden_layers) for layer_idx in range(config.num_hidden_layers)
] ]
...@@ -696,8 +648,8 @@ class KeyeSiglipVisionTransformer(nn.Module): ...@@ -696,8 +648,8 @@ class KeyeSiglipVisionTransformer(nn.Module):
self, self,
config: PretrainedConfig, config: PretrainedConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -707,8 +659,8 @@ class KeyeSiglipVisionTransformer(nn.Module): ...@@ -707,8 +659,8 @@ class KeyeSiglipVisionTransformer(nn.Module):
self.encoder = KeyeSiglipEncoder( self.encoder = KeyeSiglipEncoder(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.encoder", prefix=f"{prefix}.encoder",
attn_backend_override=attn_backend_override,
) )
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
...@@ -779,16 +731,16 @@ class KeyeSiglipVisionModel(nn.Module): ...@@ -779,16 +731,16 @@ class KeyeSiglipVisionModel(nn.Module):
self, self,
config: PretrainedConfig, config: PretrainedConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
): ):
super().__init__() super().__init__()
self.vision_model = KeyeSiglipVisionTransformer( self.vision_model = KeyeSiglipVisionTransformer(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.vision_model", prefix=f"{prefix}.vision_model",
attn_backend_override=attn_backend_override,
) )
self.quant_config = quant_config self.quant_config = quant_config
...@@ -1329,16 +1281,11 @@ class BaseKeyeModule(nn.Module): ...@@ -1329,16 +1281,11 @@ class BaseKeyeModule(nn.Module):
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual = KeyeSiglipVisionModel( self.visual = KeyeSiglipVisionModel(
config.vision_config, config.vision_config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "visual"), prefix=maybe_prefix(prefix, "visual"),
attn_backend_override=attn_backend_override,
) )
self.mlp_AR = self._build_projector( self.mlp_AR = self._build_projector(
......
...@@ -240,18 +240,12 @@ class OpenCUAForConditionalGeneration(Qwen2_5_VLForConditionalGeneration): ...@@ -240,18 +240,12 @@ class OpenCUAForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
) )
if multimodal_config.get_limit_per_prompt("image"): if multimodal_config.get_limit_per_prompt("image"):
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual = OpenCUAVisionTransformer( self.visual = OpenCUAVisionTransformer(
vision_config=config.vision_config, vision_config=config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6), norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=self.quant_config, quant_config=self.quant_config,
multimodal_config=self.multimodal_config,
prefix=maybe_prefix(prefix, "visual"), prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
attn_backend_override=attn_backend_override,
) )
else: else:
self.visual = None self.visual = None
......
...@@ -10,8 +10,7 @@ import torch ...@@ -10,8 +10,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from transformers import BaseImageProcessor, BatchFeature, PretrainedConfig from transformers import BaseImageProcessor, BatchFeature, PretrainedConfig
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import MultiModalConfig, VllmConfig
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
...@@ -104,18 +103,16 @@ class VisualTokenizer(torch.nn.Module): ...@@ -104,18 +103,16 @@ class VisualTokenizer(torch.nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
visual_vocab_size: int, visual_vocab_size: int,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.vit = self._init_backbone( self.vit = self._init_backbone(
config=config, config=config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.vit", prefix=f"{prefix}.vit",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
) )
# reserved tokens for INDICATOR_IDS # reserved tokens for INDICATOR_IDS
head_dim = visual_vocab_size - len(INDICATOR_IDS) head_dim = visual_vocab_size - len(INDICATOR_IDS)
...@@ -133,18 +130,16 @@ class VisualTokenizer(torch.nn.Module): ...@@ -133,18 +130,16 @@ class VisualTokenizer(torch.nn.Module):
self, self,
config: PretrainedConfig, config: PretrainedConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
): ):
model_type = config.model_type model_type = config.model_type
if model_type == "siglip2_navit": if model_type == "siglip2_navit":
return Siglip2NavitModel( return Siglip2NavitModel(
config=config, config=config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=prefix, prefix=prefix,
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
) )
raise ValueError(f"Unsupported visual tokenizer model_type: {model_type}") raise ValueError(f"Unsupported visual tokenizer model_type: {model_type}")
...@@ -468,17 +463,12 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -468,17 +463,12 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
prefix=maybe_prefix(prefix, "llm"), prefix=maybe_prefix(prefix, "llm"),
) )
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual_tokenizer = VisualTokenizer( self.visual_tokenizer = VisualTokenizer(
config=config.vit_config, config=config.vit_config,
visual_vocab_size=config.visual_vocab_size, visual_vocab_size=config.visual_vocab_size,
multimodal_config=multimodal_config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.visual_tokenizer", prefix=f"{prefix}.visual_tokenizer",
attn_backend_override=attn_backend_override,
) )
self.vte = VisualEmbedding(config.visual_vocab_size, config.hidden_size) self.vte = VisualEmbedding(config.visual_vocab_size, config.hidden_size)
......
...@@ -22,7 +22,6 @@ from typing import Annotated, Literal ...@@ -22,7 +22,6 @@ from typing import Annotated, Literal
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat from einops import rearrange, repeat
from transformers import BatchFeature, PretrainedConfig from transformers import BatchFeature, PretrainedConfig
from transformers.activations import GELUActivation from transformers.activations import GELUActivation
...@@ -32,13 +31,10 @@ from transformers.modeling_outputs import ( ...@@ -32,13 +31,10 @@ from transformers.modeling_outputs import (
from transformers.utils import torch_int from transformers.utils import torch_int
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import ( from vllm.attention.layers.mm_encoder_attention import (
maybe_get_vit_flash_attn_backend, MMEncoderAttention,
) )
from vllm.attention.ops.vit_attn_wrappers import ( from vllm.config import MultiModalConfig, VllmConfig
vit_flash_attn_wrapper,
)
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import parallel_state from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils from vllm.distributed import utils as dist_utils
...@@ -578,9 +574,8 @@ class SiglipAttention(nn.Module): ...@@ -578,9 +574,8 @@ class SiglipAttention(nn.Module):
num_heads: int, num_heads: int,
projection_size: int, projection_size: int,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -608,18 +603,12 @@ class SiglipAttention(nn.Module): ...@@ -608,18 +603,12 @@ class SiglipAttention(nn.Module):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.out_proj", prefix=f"{prefix}.out_proj",
) )
self.attn = MMEncoderAttention(
self.attn_backend = attn_backend num_heads=self.num_attention_heads_per_partition,
self.attn_backend, self.flash_attn_varlen_func = ( head_size=self.hidden_size_per_attention_head,
maybe_get_vit_flash_attn_backend( multimodal_config=multimodal_config,
self.attn_backend, prefix=f"{prefix}.attn",
attn_backend_override=attn_backend_override,
)
) )
self.is_flash_attn_backend = self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
seq_len, bs, _ = qkv.shape seq_len, bs, _ = qkv.shape
...@@ -665,44 +654,16 @@ class SiglipAttention(nn.Module): ...@@ -665,44 +654,16 @@ class SiglipAttention(nn.Module):
qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb) qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
q, k = torch.chunk(qk_rotated, 2, dim=0) q, k = torch.chunk(qk_rotated, 2, dim=0)
if self.is_flash_attn_backend: context_layer = self.attn(
if max_seqlen is None: query=q,
raise ValueError("Flash attention backend requires max_seqlen.") key=k,
context_layer = vit_flash_attn_wrapper( value=v,
q, cu_seqlens=cu_seqlens,
k, max_seqlen=max_seqlen,
v,
cu_seqlens,
max_seqlen,
batch_size,
self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA,
)
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
outputs = []
for i in range(1, len(cu_seqlens)):
start_idx = cu_seqlens[i - 1]
end_idx = cu_seqlens[i]
q_i = q[:, start_idx:end_idx]
k_i = k[:, start_idx:end_idx]
v_i = v[:, start_idx:end_idx]
q_i, k_i, v_i = (
rearrange(tensor, "b s h d -> b h s d")
for tensor in (q_i, k_i, v_i)
)
output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
output_i = rearrange(output_i, "b h s d -> b s h d")
outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1)
context_layer = rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
else:
raise RuntimeError(
f"PaddleOCR-VL does not support {self.attn_backend} backend now."
) )
context_layer = rearrange(context_layer, "b s h d -> b s (h d)")
output, _ = self.out_proj(context_layer) output, _ = self.out_proj(context_layer)
output = rearrange(output, "s b d -> b s d")
return output return output
...@@ -774,10 +735,8 @@ class SiglipEncoderLayer(nn.Module): ...@@ -774,10 +735,8 @@ class SiglipEncoderLayer(nn.Module):
self, self,
config: PretrainedConfig, config: PretrainedConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
*,
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
attn_backend_override: AttentionBackendEnum | None = None,
): ):
super().__init__() super().__init__()
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
...@@ -787,9 +746,8 @@ class SiglipEncoderLayer(nn.Module): ...@@ -787,9 +746,8 @@ class SiglipEncoderLayer(nn.Module):
num_heads=config.num_attention_heads, num_heads=config.num_attention_heads,
projection_size=config.hidden_size, projection_size=config.hidden_size,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.self_attn", prefix=f"{prefix}.self_attn",
attn_backend=attn_backend,
attn_backend_override=attn_backend_override,
) )
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = SiglipMLP( self.mlp = SiglipMLP(
...@@ -832,14 +790,18 @@ class SiglipEncoder(nn.Module): ...@@ -832,14 +790,18 @@ class SiglipEncoder(nn.Module):
self, self,
config: PretrainedConfig, config: PretrainedConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
embed_dim = config.hidden_size embed_dim = config.hidden_size
num_heads = config.num_attention_heads num_heads = config.num_attention_heads
head_dim = embed_dim // num_heads head_dim = embed_dim // num_heads
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend if multimodal_config else None
)
self.attn_backend = get_vit_attn_backend( self.attn_backend = get_vit_attn_backend(
head_size=head_dim, head_size=head_dim,
dtype=torch.get_default_dtype(), dtype=torch.get_default_dtype(),
...@@ -858,9 +820,8 @@ class SiglipEncoder(nn.Module): ...@@ -858,9 +820,8 @@ class SiglipEncoder(nn.Module):
SiglipEncoderLayer( SiglipEncoderLayer(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.layers.{layer_idx}", prefix=f"{prefix}.layers.{layer_idx}",
attn_backend=self.attn_backend,
attn_backend_override=attn_backend_override,
) )
for layer_idx in range(config.num_hidden_layers) for layer_idx in range(config.num_hidden_layers)
] ]
...@@ -941,8 +902,8 @@ class SiglipVisionTransformer(nn.Module): ...@@ -941,8 +902,8 @@ class SiglipVisionTransformer(nn.Module):
self, self,
config: PretrainedConfig, config: PretrainedConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -952,8 +913,8 @@ class SiglipVisionTransformer(nn.Module): ...@@ -952,8 +913,8 @@ class SiglipVisionTransformer(nn.Module):
self.encoder = SiglipEncoder( self.encoder = SiglipEncoder(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.encoder", prefix=f"{prefix}.encoder",
attn_backend_override=attn_backend_override,
) )
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
...@@ -991,16 +952,16 @@ class SiglipVisionModel(nn.Module): ...@@ -991,16 +952,16 @@ class SiglipVisionModel(nn.Module):
self, self,
config, config,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
): ):
super().__init__() super().__init__()
self.vision_model = SiglipVisionTransformer( self.vision_model = SiglipVisionTransformer(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.vision_model", prefix=f"{prefix}.vision_model",
attn_backend_override=attn_backend_override,
) )
self.quant_config = quant_config self.quant_config = quant_config
...@@ -1119,17 +1080,11 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support ...@@ -1119,17 +1080,11 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual = SiglipVisionModel( self.visual = SiglipVisionModel(
config=config.vision_config, config=config.vision_config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "visual"), prefix=maybe_prefix(prefix, "visual"),
attn_backend_override=attn_backend_override,
) )
self.mlp_AR = Projector(config, config.vision_config) self.mlp_AR = Projector(config, config.vision_config)
......
...@@ -845,6 +845,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration( ...@@ -845,6 +845,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6), norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6),
quant_config=quant_config, quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"), prefix=maybe_prefix(prefix, "visual"),
multimodal_config=multimodal_config,
) )
else: else:
self.visual = None self.visual = None
......
...@@ -42,13 +42,9 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( ...@@ -42,13 +42,9 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
) )
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import maybe_get_vit_flash_attn_backend from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.attention.ops.vit_attn_wrappers import (
vit_flash_attn_wrapper,
vit_torch_sdpa_wrapper,
)
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig from vllm.config import MultiModalConfig, VllmConfig
from vllm.distributed import parallel_state from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils from vllm.distributed import utils as dist_utils
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
...@@ -267,10 +263,15 @@ class Qwen2_5_VisionMLP(nn.Module): ...@@ -267,10 +263,15 @@ class Qwen2_5_VisionMLP(nn.Module):
bias: bool = False, bias: bool = False,
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
): ):
super().__init__() super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.gate_up_proj = MergedColumnParallelLinear( self.gate_up_proj = MergedColumnParallelLinear(
input_size=in_features, input_size=in_features,
output_sizes=[hidden_features] * 2, # [gate_proj, up_proj] output_sizes=[hidden_features] * 2, # [gate_proj, up_proj]
...@@ -304,13 +305,16 @@ class Qwen2_5_VisionAttention(nn.Module): ...@@ -304,13 +305,16 @@ class Qwen2_5_VisionAttention(nn.Module):
num_heads: int, num_heads: int,
projection_size: int, projection_size: int,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
# Per attention head and per partition values. # Per attention head and per partition values.
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.tp_size = ( self.tp_size = (
1 1
if use_data_parallel if use_data_parallel
...@@ -342,18 +346,12 @@ class Qwen2_5_VisionAttention(nn.Module): ...@@ -342,18 +346,12 @@ class Qwen2_5_VisionAttention(nn.Module):
prefix=f"{prefix}.proj", prefix=f"{prefix}.proj",
disable_tp=use_data_parallel, disable_tp=use_data_parallel,
) )
self.attn_backend = attn_backend
self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
attn_backend_override=attn_backend_override,
)
)
self.is_flash_attn_backend = self.attn_backend in { self.attn = MMEncoderAttention(
AttentionBackendEnum.FLASH_ATTN, num_heads=self.num_attention_heads_per_partition,
AttentionBackendEnum.ROCM_AITER_FA, head_size=self.hidden_size_per_attention_head,
} multimodal_config=multimodal_config,
)
def forward( def forward(
self, self,
...@@ -394,32 +392,17 @@ class Qwen2_5_VisionAttention(nn.Module): ...@@ -394,32 +392,17 @@ class Qwen2_5_VisionAttention(nn.Module):
else: else:
q, k, v = qkv.unbind(dim=2) q, k, v = qkv.unbind(dim=2)
if self.is_flash_attn_backend: context_layer = self.attn(
context_layer = vit_flash_attn_wrapper( query=q,
q, key=k,
k, value=v,
v, cu_seqlens=cu_seqlens,
cu_seqlens, max_seqlen=max_seqlen,
max_seqlen,
batch_size,
self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA,
) )
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
# Execute attention entry by entry for speed & less VRAM.
from vllm.platforms import current_platform
# Never remove the next contiguous logic context_layer = einops.rearrange(
# Without it, hallucinations occur with the backend context_layer, "b s h d -> s b (h d)", b=batch_size
if current_platform.is_rocm(): ).contiguous()
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
context_layer = vit_torch_sdpa_wrapper(
q,
k,
v,
cu_seqlens,
)
output, _ = self.proj(context_layer) output, _ = self.proj(context_layer)
return output return output
...@@ -443,10 +426,8 @@ class Qwen2_5_VisionBlock(nn.Module): ...@@ -443,10 +426,8 @@ class Qwen2_5_VisionBlock(nn.Module):
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
norm_layer: Callable[[int], nn.Module] | None = None, norm_layer: Callable[[int], nn.Module] | None = None,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
if norm_layer is None: if norm_layer is None:
...@@ -458,10 +439,8 @@ class Qwen2_5_VisionBlock(nn.Module): ...@@ -458,10 +439,8 @@ class Qwen2_5_VisionBlock(nn.Module):
num_heads=num_heads, num_heads=num_heads,
projection_size=dim, projection_size=dim,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn",
use_data_parallel=use_data_parallel,
attn_backend=attn_backend,
attn_backend_override=attn_backend_override,
) )
self.mlp = Qwen2_5_VisionMLP( self.mlp = Qwen2_5_VisionMLP(
dim, dim,
...@@ -469,8 +448,8 @@ class Qwen2_5_VisionBlock(nn.Module): ...@@ -469,8 +448,8 @@ class Qwen2_5_VisionBlock(nn.Module):
act_fn=act_fn, act_fn=act_fn,
bias=True, bias=True,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel,
) )
def forward( def forward(
...@@ -542,10 +521,15 @@ class Qwen2_5_VisionPatchMerger(nn.Module): ...@@ -542,10 +521,15 @@ class Qwen2_5_VisionPatchMerger(nn.Module):
norm_layer: Callable[[int], nn.Module] | None = None, norm_layer: Callable[[int], nn.Module] | None = None,
spatial_merge_size: int = 2, spatial_merge_size: int = 2,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.hidden_size = context_dim * (spatial_merge_size**2) self.hidden_size = context_dim * (spatial_merge_size**2)
if norm_layer is None: if norm_layer is None:
norm_layer = partial(nn.LayerNorm, eps=1e-6) norm_layer = partial(nn.LayerNorm, eps=1e-6)
...@@ -586,9 +570,8 @@ class Qwen2_5_VisionTransformer(nn.Module): ...@@ -586,9 +570,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
vision_config: Qwen2_5_VLVisionConfig, vision_config: Qwen2_5_VLVisionConfig,
norm_eps: float = 1e-6, norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -598,7 +581,6 @@ class Qwen2_5_VisionTransformer(nn.Module): ...@@ -598,7 +581,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
depth = vision_config.depth depth = vision_config.depth
self.hidden_size = vision_config.hidden_size self.hidden_size = vision_config.hidden_size
self.num_heads = vision_config.num_heads self.num_heads = vision_config.num_heads
self.use_data_parallel = use_data_parallel
self.out_hidden_size = vision_config.out_hidden_size self.out_hidden_size = vision_config.out_hidden_size
# args for get_window_index_thw # args for get_window_index_thw
...@@ -629,19 +611,17 @@ class Qwen2_5_VisionTransformer(nn.Module): ...@@ -629,19 +611,17 @@ class Qwen2_5_VisionTransformer(nn.Module):
rope_parameters={"partial_rotary_factor": 0.5}, rope_parameters={"partial_rotary_factor": 0.5},
) )
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.attn_backend = get_vit_attn_backend( self.attn_backend = get_vit_attn_backend(
head_size=head_dim, head_size=head_dim,
dtype=torch.get_default_dtype(), dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override, attn_backend_override=attn_backend_override,
) )
self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
attn_backend_override=attn_backend_override,
)
)
if self.attn_backend not in { if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA, AttentionBackendEnum.TORCH_SDPA,
...@@ -661,10 +641,8 @@ class Qwen2_5_VisionTransformer(nn.Module): ...@@ -661,10 +641,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
act_fn=get_act_and_mul_fn(vision_config.hidden_act), act_fn=get_act_and_mul_fn(vision_config.hidden_act),
norm_layer=norm_layer, norm_layer=norm_layer,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.blocks.{layer_idx}", prefix=f"{prefix}.blocks.{layer_idx}",
use_data_parallel=use_data_parallel,
attn_backend=self.attn_backend,
attn_backend_override=attn_backend_override,
) )
for layer_idx in range(depth) for layer_idx in range(depth)
] ]
...@@ -677,8 +655,8 @@ class Qwen2_5_VisionTransformer(nn.Module): ...@@ -677,8 +655,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
norm_layer=norm_layer, norm_layer=norm_layer,
spatial_merge_size=self.spatial_merge_size, spatial_merge_size=self.spatial_merge_size,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.merger", prefix=f"{prefix}.merger",
use_data_parallel=use_data_parallel,
) )
@property @property
...@@ -1200,18 +1178,12 @@ class Qwen2_5_VLForConditionalGeneration( ...@@ -1200,18 +1178,12 @@ class Qwen2_5_VLForConditionalGeneration(
if multimodal_config.get_limit_per_prompt( if multimodal_config.get_limit_per_prompt(
"image" "image"
) or multimodal_config.get_limit_per_prompt("video"): ) or multimodal_config.get_limit_per_prompt("video"):
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual = Qwen2_5_VisionTransformer( self.visual = Qwen2_5_VisionTransformer(
vision_config=config.vision_config, vision_config=config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6), norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=self.quant_config, quant_config=self.quant_config,
prefix=maybe_prefix(prefix, "visual"), prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel, multimodal_config=multimodal_config,
attn_backend_override=attn_backend_override,
) )
else: else:
self.visual = None self.visual = None
......
...@@ -33,7 +33,6 @@ from typing import Annotated, Any, Literal, TypeAlias ...@@ -33,7 +33,6 @@ from typing import Annotated, Any, Literal, TypeAlias
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange from einops import rearrange
from transformers import BatchFeature from transformers import BatchFeature
from transformers.models.qwen2_vl import Qwen2VLImageProcessor, Qwen2VLProcessor from transformers.models.qwen2_vl import Qwen2VLImageProcessor, Qwen2VLProcessor
...@@ -45,10 +44,8 @@ from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize ...@@ -45,10 +44,8 @@ from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
from transformers.models.qwen2_vl.video_processing_qwen2_vl import Qwen2VLVideoProcessor from transformers.models.qwen2_vl.video_processing_qwen2_vl import Qwen2VLVideoProcessor
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import ( from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
maybe_get_vit_flash_attn_backend, from vllm.config import MultiModalConfig, VllmConfig
)
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import parallel_state, tensor_model_parallel_all_gather from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
from vllm.distributed import utils as dist_utils from vllm.distributed import utils as dist_utils
...@@ -251,10 +248,15 @@ class Qwen2VisionMLP(nn.Module): ...@@ -251,10 +248,15 @@ class Qwen2VisionMLP(nn.Module):
hidden_features: int, hidden_features: int,
act_layer: type[nn.Module] = QuickGELU, act_layer: type[nn.Module] = QuickGELU,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
): ):
super().__init__() super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.fc1 = ColumnParallelLinear( self.fc1 = ColumnParallelLinear(
in_features, in_features,
hidden_features, hidden_features,
...@@ -295,12 +297,16 @@ class Qwen2VisionAttention(nn.Module): ...@@ -295,12 +297,16 @@ class Qwen2VisionAttention(nn.Module):
num_heads: int, num_heads: int,
projection_size: int, projection_size: int,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
# Per attention head and per partition values. # Per attention head and per partition values.
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.tp_size = ( self.tp_size = (
1 1
if use_data_parallel if use_data_parallel
...@@ -329,34 +335,12 @@ class Qwen2VisionAttention(nn.Module): ...@@ -329,34 +335,12 @@ class Qwen2VisionAttention(nn.Module):
disable_tp=use_data_parallel, disable_tp=use_data_parallel,
) )
# Detect attention implementation. self.attn = MMEncoderAttention(
self.attn_backend = get_vit_attn_backend( num_heads=self.num_attention_heads_per_partition,
head_size=self.hidden_size_per_attention_head, head_size=self.hidden_size_per_attention_head,
dtype=torch.get_default_dtype(), multimodal_config=multimodal_config,
attn_backend_override=attn_backend_override,
)
self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
attn_backend_override=attn_backend_override,
)
)
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
f"Qwen2-VL does not support {self.attn_backend} backend now."
) )
self.is_flash_attn_backend = self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
# [s, b, 3 * head * head_dim] # [s, b, 3 * head * head_dim]
seq_len, bs, _ = qkv.shape seq_len, bs, _ = qkv.shape
...@@ -398,7 +382,6 @@ class Qwen2VisionAttention(nn.Module): ...@@ -398,7 +382,6 @@ class Qwen2VisionAttention(nn.Module):
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim] # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
q, k, v = self.split_qkv(x) q, k, v = self.split_qkv(x)
batch_size = q.shape[1]
q, k, v = (rearrange(x, "s b ... -> b s ...") for x in (q, k, v)) q, k, v = (rearrange(x, "s b ... -> b s ...") for x in (q, k, v))
...@@ -409,49 +392,15 @@ class Qwen2VisionAttention(nn.Module): ...@@ -409,49 +392,15 @@ class Qwen2VisionAttention(nn.Module):
) )
q, k = torch.chunk(qk_rotated, 2, dim=0) q, k = torch.chunk(qk_rotated, 2, dim=0)
if self.is_flash_attn_backend: context_layer = self.attn(
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) query=q,
key=k,
output = self.flash_attn_varlen_func( value=v,
q, cu_seqlens=cu_seqlens,
k, max_seqlen=max_seqlen,
v, )
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens, context_layer = rearrange(context_layer, "b s h d -> s b (h d)").contiguous()
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
dropout_p=0.0,
causal=False,
)
context_layer = rearrange(
output, "(b s) h d -> s b (h d)", b=batch_size
).contiguous()
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
# Execute attention entry by entry for speed & less VRAM.
from vllm.platforms import current_platform
if current_platform.is_rocm():
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
outputs = []
lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
q_chunks = torch.split(q, lens, dim=1)
k_chunks = torch.split(k, lens, dim=1)
v_chunks = torch.split(v, lens, dim=1)
for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
q_i, k_i, v_i = (
rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
)
output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
output_i = rearrange(output_i, "b h s d -> b s h d ")
outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1)
context_layer = rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
output, _ = self.proj(context_layer) output, _ = self.proj(context_layer)
return output return output
...@@ -466,9 +415,8 @@ class Qwen2VisionBlock(nn.Module): ...@@ -466,9 +415,8 @@ class Qwen2VisionBlock(nn.Module):
act_layer: type[nn.Module] = QuickGELU, act_layer: type[nn.Module] = QuickGELU,
norm_layer: Callable[[int], nn.Module] | None = None, norm_layer: Callable[[int], nn.Module] | None = None,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
if norm_layer is None: if norm_layer is None:
...@@ -482,17 +430,16 @@ class Qwen2VisionBlock(nn.Module): ...@@ -482,17 +430,16 @@ class Qwen2VisionBlock(nn.Module):
num_heads=num_heads, num_heads=num_heads,
projection_size=dim, projection_size=dim,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
) )
self.mlp = Qwen2VisionMLP( self.mlp = Qwen2VisionMLP(
dim, dim,
mlp_hidden_dim, mlp_hidden_dim,
act_layer=act_layer, act_layer=act_layer,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel,
) )
def forward( def forward(
...@@ -552,10 +499,15 @@ class Qwen2VisionPatchMerger(nn.Module): ...@@ -552,10 +499,15 @@ class Qwen2VisionPatchMerger(nn.Module):
norm_layer: Callable[[int], nn.Module] | None = None, norm_layer: Callable[[int], nn.Module] | None = None,
spatial_merge_size: int = 2, spatial_merge_size: int = 2,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.hidden_size = context_dim * (spatial_merge_size**2) self.hidden_size = context_dim * (spatial_merge_size**2)
if norm_layer is None: if norm_layer is None:
norm_layer = partial(nn.LayerNorm, eps=1e-6) norm_layer = partial(nn.LayerNorm, eps=1e-6)
...@@ -599,9 +551,8 @@ class Qwen2VisionTransformer(nn.Module): ...@@ -599,9 +551,8 @@ class Qwen2VisionTransformer(nn.Module):
vision_config: Qwen2VLVisionConfig, vision_config: Qwen2VLVisionConfig,
norm_eps: float = 1e-6, norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -615,7 +566,11 @@ class Qwen2VisionTransformer(nn.Module): ...@@ -615,7 +566,11 @@ class Qwen2VisionTransformer(nn.Module):
num_heads = vision_config.num_heads num_heads = vision_config.num_heads
mlp_ratio = vision_config.mlp_ratio mlp_ratio = vision_config.mlp_ratio
self.use_data_parallel = use_data_parallel self.use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.out_hidden_size = vision_config.hidden_size self.out_hidden_size = vision_config.hidden_size
self.spatial_merge_size = spatial_merge_size self.spatial_merge_size = spatial_merge_size
...@@ -647,8 +602,7 @@ class Qwen2VisionTransformer(nn.Module): ...@@ -647,8 +602,7 @@ class Qwen2VisionTransformer(nn.Module):
norm_layer=norm_layer, norm_layer=norm_layer,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.blocks.{layer_idx}", prefix=f"{prefix}.blocks.{layer_idx}",
use_data_parallel=use_data_parallel, multimodal_config=multimodal_config,
attn_backend_override=attn_backend_override,
) )
for layer_idx in range(depth) for layer_idx in range(depth)
] ]
...@@ -659,7 +613,10 @@ class Qwen2VisionTransformer(nn.Module): ...@@ -659,7 +613,10 @@ class Qwen2VisionTransformer(nn.Module):
norm_layer=norm_layer, norm_layer=norm_layer,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.merger", prefix=f"{prefix}.merger",
use_data_parallel=use_data_parallel, multimodal_config=multimodal_config,
)
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend if multimodal_config else None
) )
self.attn_backend = get_vit_attn_backend( self.attn_backend = get_vit_attn_backend(
head_size=head_dim, head_size=head_dim,
...@@ -720,7 +677,7 @@ class Qwen2VisionTransformer(nn.Module): ...@@ -720,7 +677,7 @@ class Qwen2VisionTransformer(nn.Module):
AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
}: }:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
return max_seqlen return max_seqlen
def forward( def forward(
...@@ -1324,18 +1281,12 @@ class Qwen2VLForConditionalGeneration( ...@@ -1324,18 +1281,12 @@ class Qwen2VLForConditionalGeneration(
if multimodal_config.get_limit_per_prompt( if multimodal_config.get_limit_per_prompt(
"image" "image"
) or multimodal_config.get_limit_per_prompt("video"): ) or multimodal_config.get_limit_per_prompt("video"):
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual = Qwen2VisionTransformer( self.visual = Qwen2VisionTransformer(
config.vision_config, config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6), norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "visual"), prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
attn_backend_override=attn_backend_override,
) )
else: else:
self.visual = None self.visual = None
......
...@@ -48,7 +48,7 @@ from transformers.models.whisper import WhisperFeatureExtractor ...@@ -48,7 +48,7 @@ from transformers.models.whisper import WhisperFeatureExtractor
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig from vllm.config import MultiModalConfig, VllmConfig
from vllm.distributed import get_pp_group from vllm.distributed import get_pp_group
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
...@@ -192,6 +192,7 @@ class Qwen3_VisionBlock(nn.Module): ...@@ -192,6 +192,7 @@ class Qwen3_VisionBlock(nn.Module):
mlp_hidden_dim: int, mlp_hidden_dim: int,
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
norm_layer: Callable[[int], nn.Module] | None = None, norm_layer: Callable[[int], nn.Module] | None = None,
multimodal_config: MultiModalConfig | None = None,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
...@@ -205,6 +206,7 @@ class Qwen3_VisionBlock(nn.Module): ...@@ -205,6 +206,7 @@ class Qwen3_VisionBlock(nn.Module):
num_heads=num_heads, num_heads=num_heads,
projection_size=dim, projection_size=dim,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn",
) )
self.mlp = Qwen3_VisionMLP( self.mlp = Qwen3_VisionMLP(
...@@ -299,8 +301,8 @@ class Qwen3Omni_VisionTransformer(nn.Module): ...@@ -299,8 +301,8 @@ class Qwen3Omni_VisionTransformer(nn.Module):
vision_config, vision_config,
norm_eps: float = 1e-6, norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = vision_config.hidden_size self.hidden_size = vision_config.hidden_size
...@@ -347,6 +349,7 @@ class Qwen3Omni_VisionTransformer(nn.Module): ...@@ -347,6 +349,7 @@ class Qwen3Omni_VisionTransformer(nn.Module):
act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act], act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
norm_layer=norm_layer, norm_layer=norm_layer,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.blocks.{layer_idx}", prefix=f"{prefix}.blocks.{layer_idx}",
) )
for layer_idx in range(vision_config.depth) for layer_idx in range(vision_config.depth)
...@@ -376,6 +379,12 @@ class Qwen3Omni_VisionTransformer(nn.Module): ...@@ -376,6 +379,12 @@ class Qwen3Omni_VisionTransformer(nn.Module):
] ]
) )
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.attn_backend = get_vit_attn_backend( self.attn_backend = get_vit_attn_backend(
head_size=head_dim, head_size=head_dim,
dtype=torch.get_default_dtype(), dtype=torch.get_default_dtype(),
...@@ -1188,17 +1197,12 @@ class Qwen3OmniMoeThinkerForConditionalGeneration( ...@@ -1188,17 +1197,12 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
self.audio_tower = Qwen3OmniMoeAudioEncoder(thinker_config.audio_config) self.audio_tower = Qwen3OmniMoeAudioEncoder(thinker_config.audio_config)
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual = Qwen3Omni_VisionTransformer( self.visual = Qwen3Omni_VisionTransformer(
vision_config=thinker_config.vision_config, vision_config=thinker_config.vision_config,
norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6), norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6),
quant_config=quant_config, quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"), prefix=maybe_prefix(prefix, "visual"),
attn_backend_override=attn_backend_override, multimodal_config=multimodal_config,
) )
self.quant_config = quant_config self.quant_config = quant_config
......
...@@ -50,7 +50,7 @@ from transformers.video_utils import VideoMetadata ...@@ -50,7 +50,7 @@ from transformers.video_utils import VideoMetadata
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig from vllm.config import MultiModalConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
from vllm.distributed import get_pp_group from vllm.distributed import get_pp_group
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -169,10 +169,15 @@ class Qwen3_VisionMLP(nn.Module): ...@@ -169,10 +169,15 @@ class Qwen3_VisionMLP(nn.Module):
bias: bool = False, bias: bool = False,
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
): ):
super().__init__() super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.linear_fc1 = ColumnParallelLinear( self.linear_fc1 = ColumnParallelLinear(
in_features, in_features,
hidden_features, hidden_features,
...@@ -206,10 +211,9 @@ class Qwen3_VisionBlock(nn.Module): ...@@ -206,10 +211,9 @@ class Qwen3_VisionBlock(nn.Module):
mlp_hidden_dim: int, mlp_hidden_dim: int,
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
norm_layer: Callable[[int], nn.Module] | None = None, norm_layer: Callable[[int], nn.Module] | None = None,
multimodal_config: MultiModalConfig | None = None,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
) -> None: ) -> None:
super().__init__() super().__init__()
if norm_layer is None: if norm_layer is None:
...@@ -221,9 +225,8 @@ class Qwen3_VisionBlock(nn.Module): ...@@ -221,9 +225,8 @@ class Qwen3_VisionBlock(nn.Module):
num_heads=num_heads, num_heads=num_heads,
projection_size=dim, projection_size=dim,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn",
use_data_parallel=use_data_parallel,
attn_backend=attn_backend,
) )
self.mlp = Qwen3_VisionMLP( self.mlp = Qwen3_VisionMLP(
dim, dim,
...@@ -231,8 +234,8 @@ class Qwen3_VisionBlock(nn.Module): ...@@ -231,8 +234,8 @@ class Qwen3_VisionBlock(nn.Module):
act_fn=act_fn, act_fn=act_fn,
bias=True, bias=True,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel,
) )
def forward( def forward(
...@@ -264,10 +267,15 @@ class Qwen3_VisionPatchMerger(nn.Module): ...@@ -264,10 +267,15 @@ class Qwen3_VisionPatchMerger(nn.Module):
spatial_merge_size: int = 2, spatial_merge_size: int = 2,
use_postshuffle_norm: bool = False, use_postshuffle_norm: bool = False,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.hidden_size = context_dim * (spatial_merge_size**2) self.hidden_size = context_dim * (spatial_merge_size**2)
self.use_postshuffle_norm = use_postshuffle_norm self.use_postshuffle_norm = use_postshuffle_norm
...@@ -313,9 +321,8 @@ class Qwen3_VisionTransformer(nn.Module): ...@@ -313,9 +321,8 @@ class Qwen3_VisionTransformer(nn.Module):
vision_config: Qwen3VLVisionConfig, vision_config: Qwen3VLVisionConfig,
norm_eps: float = 1e-6, norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = vision_config.hidden_size self.hidden_size = vision_config.hidden_size
...@@ -326,7 +333,6 @@ class Qwen3_VisionTransformer(nn.Module): ...@@ -326,7 +333,6 @@ class Qwen3_VisionTransformer(nn.Module):
self.spatial_merge_unit = self.spatial_merge_size**2 self.spatial_merge_unit = self.spatial_merge_size**2
self.temporal_patch_size = vision_config.temporal_patch_size self.temporal_patch_size = vision_config.temporal_patch_size
self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes
self.use_data_parallel = use_data_parallel
self.num_grid_per_side = int(self.num_position_embeddings**0.5) self.num_grid_per_side = int(self.num_position_embeddings**0.5)
# NOTE: This is used for creating empty tensor for all_gather for # NOTE: This is used for creating empty tensor for all_gather for
...@@ -359,8 +365,8 @@ class Qwen3_VisionTransformer(nn.Module): ...@@ -359,8 +365,8 @@ class Qwen3_VisionTransformer(nn.Module):
norm_layer=norm_layer, norm_layer=norm_layer,
spatial_merge_size=self.spatial_merge_size, spatial_merge_size=self.spatial_merge_size,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.merger", prefix=f"{prefix}.merger",
use_data_parallel=use_data_parallel,
) )
self.deepstack_merger_list = nn.ModuleList( self.deepstack_merger_list = nn.ModuleList(
...@@ -372,13 +378,16 @@ class Qwen3_VisionTransformer(nn.Module): ...@@ -372,13 +378,16 @@ class Qwen3_VisionTransformer(nn.Module):
use_postshuffle_norm=True, use_postshuffle_norm=True,
norm_layer=norm_layer, norm_layer=norm_layer,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.deepstack_merger_list.{layer_idx}", prefix=f"{prefix}.deepstack_merger_list.{layer_idx}",
use_data_parallel=use_data_parallel,
) )
for layer_idx in range(len(self.deepstack_visual_indexes)) for layer_idx in range(len(self.deepstack_visual_indexes))
] ]
) )
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend if multimodal_config else None
)
self.attn_backend = get_vit_attn_backend( self.attn_backend = get_vit_attn_backend(
head_size=head_dim, head_size=head_dim,
dtype=torch.get_default_dtype(), dtype=torch.get_default_dtype(),
...@@ -402,9 +411,8 @@ class Qwen3_VisionTransformer(nn.Module): ...@@ -402,9 +411,8 @@ class Qwen3_VisionTransformer(nn.Module):
act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act], act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
norm_layer=norm_layer, norm_layer=norm_layer,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.blocks.{layer_idx}", prefix=f"{prefix}.blocks.{layer_idx}",
use_data_parallel=use_data_parallel,
attn_backend=self.attn_backend,
) )
for layer_idx in range(vision_config.depth) for layer_idx in range(vision_config.depth)
] ]
...@@ -1277,18 +1285,12 @@ class Qwen3VLForConditionalGeneration( ...@@ -1277,18 +1285,12 @@ class Qwen3VLForConditionalGeneration(
) and not multimodal_config.get_limit_per_prompt("video"): ) and not multimodal_config.get_limit_per_prompt("video"):
self.visual = None self.visual = None
else: else:
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual = Qwen3_VisionTransformer( self.visual = Qwen3_VisionTransformer(
config.vision_config, config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6), norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "visual"), prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
attn_backend_override=attn_backend_override,
) )
self.language_model = Qwen3LLMForCausalLM( self.language_model = Qwen3LLMForCausalLM(
......
...@@ -418,7 +418,6 @@ class Qwen3VLMoeForConditionalGeneration( ...@@ -418,7 +418,6 @@ class Qwen3VLMoeForConditionalGeneration(
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
if not multimodal_config.get_limit_per_prompt( if not multimodal_config.get_limit_per_prompt(
"image" "image"
...@@ -429,8 +428,8 @@ class Qwen3VLMoeForConditionalGeneration( ...@@ -429,8 +428,8 @@ class Qwen3VLMoeForConditionalGeneration(
config.vision_config, config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6), norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "visual"), prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
) )
self.language_model = Qwen3MoeLLMForCausalLM( self.language_model = Qwen3MoeLLMForCausalLM(
......
...@@ -13,7 +13,8 @@ from transformers import Siglip2VisionConfig ...@@ -13,7 +13,8 @@ from transformers import Siglip2VisionConfig
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import maybe_get_vit_flash_attn_backend from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import MultiModalConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.conv import Conv2dLayer from vllm.model_executor.layers.conv import Conv2dLayer
...@@ -28,8 +29,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig ...@@ -28,8 +29,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.platforms import current_platform from vllm.platforms import current_platform
from .vision import get_vit_attn_backend
class VisionRotaryEmbedding(nn.Module): class VisionRotaryEmbedding(nn.Module):
def __init__(self, dim: int, theta: float = 10000.0) -> None: def __init__(self, dim: int, theta: float = 10000.0) -> None:
...@@ -190,7 +189,7 @@ def apply_rotary_pos_emb( ...@@ -190,7 +189,7 @@ def apply_rotary_pos_emb(
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
cos = cos.chunk(2, dim=-1)[0].contiguous() cos = cos.chunk(2, dim=-1)[0].contiguous()
sin = sin.chunk(2, dim=-1)[0].contiguous() sin = sin.chunk(2, dim=-1)[0].contiguous()
if is_flash_attn_backend and not current_platform.is_xpu(): if is_flash_attn_backend and current_platform.is_cuda():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
apply_rotary_emb_func = apply_rotary_emb apply_rotary_emb_func = apply_rotary_emb
...@@ -208,6 +207,7 @@ class Siglip2Attention(nn.Module): ...@@ -208,6 +207,7 @@ class Siglip2Attention(nn.Module):
self, self,
config: Siglip2VisionConfig, config: Siglip2VisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None, attn_backend_override: AttentionBackendEnum | None = None,
...@@ -227,20 +227,25 @@ class Siglip2Attention(nn.Module): ...@@ -227,20 +227,25 @@ class Siglip2Attention(nn.Module):
self.dropout = config.attention_dropout self.dropout = config.attention_dropout
self.is_causal = False self.is_causal = False
# TODO(Isotr0py): Enable data parallel after we support use_data_parallel = (
# disabling TP on parallel linear layer multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.qkv_proj = QKVParallelLinear( self.qkv_proj = QKVParallelLinear(
hidden_size=self.embed_dim, hidden_size=self.embed_dim,
head_size=self.head_dim, head_size=self.head_dim,
total_num_heads=self.num_heads, total_num_heads=self.num_heads,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.qkv_proj", prefix=f"{prefix}.qkv_proj",
disable_tp=use_data_parallel,
) )
self.out_proj = RowParallelLinear( self.out_proj = RowParallelLinear(
input_size=self.embed_dim, input_size=self.embed_dim,
output_size=self.embed_dim, output_size=self.embed_dim,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.out_proj", prefix=f"{prefix}.out_proj",
disable_tp=use_data_parallel,
) )
self.tp_size = ( self.tp_size = (
...@@ -249,31 +254,13 @@ class Siglip2Attention(nn.Module): ...@@ -249,31 +254,13 @@ class Siglip2Attention(nn.Module):
self.num_heads_per_partition = divide(self.num_heads, self.tp_size) self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
self.use_rope = config.use_rope self.use_rope = config.use_rope
# Detect attention implementation. self.attn = MMEncoderAttention(
self.attn_backend = get_vit_attn_backend( num_heads=self.num_heads_per_partition,
head_size=self.head_dim, head_size=self.head_dim,
dtype=torch.get_default_dtype(), prefix=f"{prefix}.attn",
attn_backend_override=attn_backend_override, multimodal_config=multimodal_config,
)
self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
attn_backend_override=attn_backend_override,
)
) )
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.ROCM_AITER_FA,
}:
self.attn_backend = AttentionBackendEnum.TORCH_SDPA
self.is_flash_attn_backend = self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -298,46 +285,23 @@ class Siglip2Attention(nn.Module): ...@@ -298,46 +285,23 @@ class Siglip2Attention(nn.Module):
keys.unsqueeze(0), keys.unsqueeze(0),
cos, cos,
sin, sin,
self.is_flash_attn_backend, self.attn.is_flash_attn_backend,
) )
queries = queries.squeeze(0) queries = queries.squeeze(0)
keys = keys.squeeze(0) keys = keys.squeeze(0)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
if self.is_flash_attn_backend: attn_output = self.attn(
attn_output = self.flash_attn_varlen_func( query=queries.unsqueeze(0),
queries, key=keys.unsqueeze(0),
keys, value=values.unsqueeze(0),
values, cu_seqlens=cu_seqlens,
cu_seqlens_q=cu_seqlens, max_seqlen=max_seqlen,
cu_seqlens_k=cu_seqlens, )
max_seqlen_q=max_seqlen, attn_output = attn_output.reshape(
max_seqlen_k=max_seqlen, seq_length, self.num_heads_per_partition * self.head_dim
).reshape(seq_length, -1) )
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
# Execute attention entry by entry for speed & less VRAM.
batch_size = cu_seqlens.shape[0] - 1
outputs = []
cu = cu_seqlens.tolist()
for i in range(batch_size):
start_idx = cu[i]
end_idx = cu[i + 1]
# Each sequence is processed independently.
q_i = queries[start_idx:end_idx].unsqueeze(0)
k_i = keys[start_idx:end_idx].unsqueeze(0)
v_i = values[start_idx:end_idx].unsqueeze(0)
# (1, seq_len, num_heads, head_dim) ->
# (1, num_heads, seq_len, head_dim)
q_i, k_i, v_i = [x.transpose(1, 2) for x in (q_i, k_i, v_i)]
output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
# (1, num_heads, seq_len, head_dim) -> (seq_len, embed_dim)
output_i = output_i.transpose(1, 2).reshape(end_idx - start_idx, -1)
outputs.append(output_i)
attn_output = torch.cat(outputs, dim=0)
attn_output, _ = self.out_proj(attn_output) attn_output, _ = self.out_proj(attn_output)
return attn_output return attn_output
...@@ -347,25 +311,30 @@ class Siglip2MLP(nn.Module): ...@@ -347,25 +311,30 @@ class Siglip2MLP(nn.Module):
self, self,
config: Siglip2VisionConfig, config: Siglip2VisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.activation_fn = get_act_fn(config.hidden_act) self.activation_fn = get_act_fn(config.hidden_act)
# TODO(Isotr0py): Enable data parallel after we support
# disabling TP on parallel linear layer
self.fc1 = ColumnParallelLinear( self.fc1 = ColumnParallelLinear(
config.hidden_size, config.hidden_size,
config.intermediate_size, config.intermediate_size,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.fc1", prefix=f"{prefix}.fc1",
disable_tp=use_data_parallel,
) )
self.fc2 = RowParallelLinear( self.fc2 = RowParallelLinear(
config.intermediate_size, config.intermediate_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.fc2", prefix=f"{prefix}.fc2",
disable_tp=use_data_parallel,
) )
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
...@@ -380,9 +349,8 @@ class Siglip2EncoderLayer(nn.Module): ...@@ -380,9 +349,8 @@ class Siglip2EncoderLayer(nn.Module):
self, self,
config: Siglip2VisionConfig, config: Siglip2VisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
): ):
super().__init__() super().__init__()
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
...@@ -390,16 +358,15 @@ class Siglip2EncoderLayer(nn.Module): ...@@ -390,16 +358,15 @@ class Siglip2EncoderLayer(nn.Module):
self.self_attn = Siglip2Attention( self.self_attn = Siglip2Attention(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.self_attn", prefix=f"{prefix}.self_attn",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
) )
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = Siglip2MLP( self.mlp = Siglip2MLP(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel,
) )
def forward( def forward(
...@@ -444,9 +411,8 @@ class Siglip2Encoder(nn.Module): ...@@ -444,9 +411,8 @@ class Siglip2Encoder(nn.Module):
self, self,
config: Siglip2VisionConfig, config: Siglip2VisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -455,9 +421,8 @@ class Siglip2Encoder(nn.Module): ...@@ -455,9 +421,8 @@ class Siglip2Encoder(nn.Module):
Siglip2EncoderLayer( Siglip2EncoderLayer(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.layers.{idx}", prefix=f"{prefix}.layers.{idx}",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
) )
for idx in range(config.num_hidden_layers) for idx in range(config.num_hidden_layers)
] ]
...@@ -630,9 +595,8 @@ class Siglip2VisionTransformer(nn.Module): ...@@ -630,9 +595,8 @@ class Siglip2VisionTransformer(nn.Module):
self, self,
config: Siglip2VisionConfig, config: Siglip2VisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -642,9 +606,8 @@ class Siglip2VisionTransformer(nn.Module): ...@@ -642,9 +606,8 @@ class Siglip2VisionTransformer(nn.Module):
self.encoder = Siglip2Encoder( self.encoder = Siglip2Encoder(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.encoder", prefix=f"{prefix}.encoder",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
) )
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
...@@ -671,18 +634,16 @@ class Siglip2NavitModel(torch.nn.Module): ...@@ -671,18 +634,16 @@ class Siglip2NavitModel(torch.nn.Module):
self, self,
config: Siglip2VisionConfig, config: Siglip2VisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
): ):
super().__init__() super().__init__()
self.vision_model = Siglip2VisionTransformer( self.vision_model = Siglip2VisionTransformer(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.vision_model", prefix=f"{prefix}.vision_model",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
) )
def forward( def forward(
......
...@@ -88,14 +88,17 @@ def get_vit_attn_backend( ...@@ -88,14 +88,17 @@ def get_vit_attn_backend(
""" """
Get the available attention backend for Vision Transformer. Get the available attention backend for Vision Transformer.
""" """
if attn_backend_override is not None: attn_backend = attn_backend_override
return attn_backend_override
selected_backend = get_current_vllm_config().attention_config.backend selected_backend = get_current_vllm_config().attention_config.backend
if selected_backend is not None: if attn_backend is None:
return selected_backend attn_backend = selected_backend
return current_platform.get_vit_attn_backend(head_size, dtype) return current_platform.get_vit_attn_backend(
head_size,
dtype,
backend=attn_backend,
)
def should_torch_compile_mm_vit(vllm_config: VllmConfig) -> bool: def should_torch_compile_mm_vit(vllm_config: VllmConfig) -> bool:
......
...@@ -7,7 +7,7 @@ pynvml. However, it should not initialize cuda context. ...@@ -7,7 +7,7 @@ pynvml. However, it should not initialize cuda context.
import os import os
from collections.abc import Callable from collections.abc import Callable
from functools import cache, wraps from functools import cache, wraps
from typing import TYPE_CHECKING, TypeVar from typing import TYPE_CHECKING, Optional, TypeVar
import torch import torch
from typing_extensions import ParamSpec from typing_extensions import ParamSpec
...@@ -255,23 +255,6 @@ class CudaPlatformBase(Platform): ...@@ -255,23 +255,6 @@ class CudaPlatformBase(Platform):
torch.cuda.reset_peak_memory_stats(device) torch.cuda.reset_peak_memory_stats(device)
return torch.cuda.max_memory_allocated(device) return torch.cuda.max_memory_allocated(device)
# NOTE(review): this two-argument version appears to be the pre-change (removed)
# side of the CudaPlatformBase hunk; a version with an extra `backend` parameter
# is shown further down the page.
# Purpose: choose the attention backend for Vision Transformer attention from the
# attention head size and the tensor dtype. Returns an AttentionBackendEnum member.
@classmethod
def get_vit_attn_backend(
cls, head_size: int, dtype: torch.dtype
) -> "AttentionBackendEnum":
# Try FlashAttention first
# It is only considered when a device capability is reported and its major
# version is at least 8.
if (cc := cls.get_device_capability()) and cc.major >= 8:
try:
backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
# FlashAttention is selected only if it supports both the head size and dtype.
if backend_class.supports_head_size(
head_size
) and backend_class.supports_dtype(dtype):
return AttentionBackendEnum.FLASH_ATTN
except ImportError:
# Best effort: an ImportError raised inside the try block is ignored and the
# SDPA fallback below is used instead.
pass
# Fallback whenever FlashAttention was not selected above.
return AttentionBackendEnum.TORCH_SDPA
@classmethod @classmethod
def get_valid_backends( def get_valid_backends(
cls, cls,
...@@ -418,6 +401,41 @@ class CudaPlatformBase(Platform): ...@@ -418,6 +401,41 @@ class CudaPlatformBase(Platform):
return selected_backend.get_path() return selected_backend.get_path()
@classmethod
def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
    """Return the attention backends accepted for ViT attention.

    Returns:
        A freshly built list holding ``TORCH_SDPA`` followed by
        ``FLASH_ATTN``.
    """
    supported_backends = [
        AttentionBackendEnum.TORCH_SDPA,
        AttentionBackendEnum.FLASH_ATTN,
    ]
    return supported_backends
@classmethod
def get_vit_attn_backend(
    cls,
    head_size: int,
    dtype: torch.dtype,
    backend: Optional["AttentionBackendEnum"] = None,
) -> "AttentionBackendEnum":
    """Select the attention backend used for ViT attention.

    Args:
        head_size: Attention head size, checked against FlashAttention
            support when no explicit backend is given.
        dtype: Tensor dtype, checked against FlashAttention support when
            no explicit backend is given.
        backend: Explicitly requested backend. When not ``None`` it must
            be one of ``cls.get_supported_vit_attn_backends()`` and is
            returned unchanged.

    Returns:
        ``backend`` if one was requested; otherwise ``FLASH_ATTN`` when the
        device capability major version is at least 8 and FlashAttention
        supports ``head_size`` and ``dtype``; otherwise ``TORCH_SDPA``.

    Raises:
        AssertionError: If ``backend`` is given but is not in the supported
            list (only when assertions are enabled).
    """
    # An explicit request short-circuits auto-detection entirely.
    if backend is not None:
        assert backend in cls.get_supported_vit_attn_backends(), (
            f"Backend {backend} is not supported for vit attention. "
            f"Supported backends are: {cls.get_supported_vit_attn_backends()}"
        )
        logger.info_once(f"Using backend {backend} for vit attention")
        return backend

    # Auto-detection: prefer FlashAttention when the device qualifies.
    capability = cls.get_device_capability()
    if capability and capability.major >= 8:
        try:
            flash_attn_cls = AttentionBackendEnum.FLASH_ATTN.get_class()
            head_size_ok = flash_attn_cls.supports_head_size(head_size)
            if head_size_ok and flash_attn_cls.supports_dtype(dtype):
                return AttentionBackendEnum.FLASH_ATTN
        except ImportError:
            # Best effort: fall through to the SDPA fallback below.
            pass

    return AttentionBackendEnum.TORCH_SDPA
@classmethod @classmethod
def get_punica_wrapper(cls) -> str: def get_punica_wrapper(cls) -> str:
return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU" return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment