Unverified commit 87b4d155, authored by Shanshan Shen and committed by GitHub
Browse files

[CustomOp][MM] Extract MMEncoderAttention as CustomOp and replace the backend...


[CustomOp][MM] Extract MMEncoderAttention as CustomOp and replace the backend of QwenVisionAttention with it. (#30125)
Signed-off-by: shen-shanshan <467638484@qq.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
parent 84e23d10
......@@ -7,7 +7,7 @@ import platform
import random
import sys
from datetime import timedelta
from typing import TYPE_CHECKING, Any, NamedTuple
from typing import TYPE_CHECKING, Any, NamedTuple, Optional
import numpy as np
import torch
......@@ -222,12 +222,6 @@ class Platform:
with contextlib.suppress(ImportError):
import vllm._moe_C # noqa: F401
@classmethod
def get_vit_attn_backend(
    cls, head_size: int, dtype: torch.dtype
) -> "AttentionBackendEnum":
    """Return the attention backend used for vision (ViT) attention.

    ``head_size`` and ``dtype`` are accepted but not consulted here; the
    result is always ``AttentionBackendEnum.TORCH_SDPA``.

    NOTE(review): this two-parameter definition looks like the pre-change
    side of the diff hunk -- confirm before relying on it.
    """
    return AttentionBackendEnum.TORCH_SDPA
@classmethod
def get_attn_backend_cls(
cls,
......@@ -245,6 +239,43 @@ class Platform:
"""Get the attention backend class of a device."""
return ""
@classmethod
def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
    """List the backends accepted for vision (ViT) attention.

    Returns:
        A new list containing only ``AttentionBackendEnum.TORCH_SDPA``.
    """
    supported_backends = [AttentionBackendEnum.TORCH_SDPA]
    return supported_backends
@classmethod
def get_vit_attn_backend(
    cls,
    head_size: int,
    dtype: torch.dtype,
    backend: Optional["AttentionBackendEnum"] = None,
) -> "AttentionBackendEnum":
    """
    Get the vision attention backend class of a device.

    NOTE: ViT attention should be checked and overridden in the
    platform-specific implementation. We should not override this in any
    other places, like the model_executor/models/<model_name>.py.

    We check if the backend is None or not:
        1. If not, check if the backend is supported by the platform.
        2. If None, continue to the default selection logic.

    Args:
        head_size: Attention head size; not consulted by this implementation.
        dtype: Tensor dtype; not consulted by this implementation.
        backend: Explicitly requested backend, or ``None`` to use the
            default selection.

    Returns:
        ``backend`` when it is given and supported, otherwise
        ``AttentionBackendEnum.TORCH_SDPA``.

    Raises:
        AssertionError: If ``backend`` is given but is not in
            ``cls.get_supported_vit_attn_backends()``.
    """
    if backend is not None:
        # Adjacent f-string fragments are concatenated verbatim, so the
        # first fragment must end with ". " to keep the sentences apart.
        assert backend in cls.get_supported_vit_attn_backends(), (
            f"Backend {backend} is not supported for vit attention. "
            f"Supported backends are: {cls.get_supported_vit_attn_backends()}"
        )
        logger.info_once(f"Using backend {backend} for vit attention")
        return backend

    logger.info_once(
        f"Using default backend {AttentionBackendEnum.TORCH_SDPA} for vit attention"
    )
    return AttentionBackendEnum.TORCH_SDPA
@classmethod
def get_device_capability(
cls,
......
......@@ -3,7 +3,7 @@
import os
from functools import cache, lru_cache, wraps
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional
import torch
......@@ -187,24 +187,6 @@ class RocmPlatform(Platform):
if not on_gfx9():
supported_quantization += ["bitsandbytes"]
@classmethod
def get_vit_attn_backend(
    cls, head_size: int, dtype: torch.dtype
) -> AttentionBackendEnum:
    """Choose the vision (ViT) attention backend.

    Selection order:
        1. ``ROCM_AITER_FA`` when ``rocm_aiter_ops.is_mha_enabled()`` is true.
        2. ``FLASH_ATTN`` when ``on_gfx9()`` is true and the ``flash_attn``
           package can be located.
        3. ``TORCH_SDPA`` otherwise.

    ``head_size`` and ``dtype`` are accepted but not consulted here.
    """
    from importlib.util import find_spec

    from vllm._aiter_ops import rocm_aiter_ops

    # Note: AITER FA is only supported for Qwen-VL models.
    # TODO: Add support for other VL models in their model class.
    if rocm_aiter_ops.is_mha_enabled():
        return AttentionBackendEnum.ROCM_AITER_FA

    # `and` short-circuits: the package lookup only runs on gfx9.
    flash_attn_usable = on_gfx9() and find_spec("flash_attn") is not None
    if flash_attn_usable:
        return AttentionBackendEnum.FLASH_ATTN
    return AttentionBackendEnum.TORCH_SDPA
@classmethod
def get_attn_backend_cls(
cls,
......@@ -322,6 +304,43 @@ class RocmPlatform(Platform):
"ROCm. Note that V0 attention backends have been removed."
)
@classmethod
def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
    """List the backends accepted for vision (ViT) attention.

    Returns:
        A new list with ``FLASH_ATTN``, ``ROCM_AITER_FA`` and ``TORCH_SDPA``,
        in that order.
    """
    supported_backends = [
        AttentionBackendEnum.FLASH_ATTN,
        AttentionBackendEnum.ROCM_AITER_FA,
        AttentionBackendEnum.TORCH_SDPA,
    ]
    return supported_backends
@classmethod
def get_vit_attn_backend(
    cls,
    head_size: int,
    dtype: torch.dtype,
    backend: Optional["AttentionBackendEnum"] = None,
) -> "AttentionBackendEnum":
    """Choose the vision (ViT) attention backend.

    When ``backend`` is given, it is validated against
    ``cls.get_supported_vit_attn_backends()`` and returned unchanged.
    Otherwise the selection order is:
        1. ``ROCM_AITER_FA`` when ``rocm_aiter_ops.is_mha_enabled()`` is true.
        2. ``FLASH_ATTN`` when ``on_gfx9()`` is true and the ``flash_attn``
           package can be located.
        3. ``TORCH_SDPA`` otherwise.

    ``head_size`` and ``dtype`` are accepted but not consulted here.

    Raises:
        AssertionError: If ``backend`` is given but not supported.
    """
    if backend is None:
        from importlib.util import find_spec

        from vllm._aiter_ops import rocm_aiter_ops

        # Note: AITER FA is only supported for Qwen-VL models.
        # TODO: Add support for other VL models in their model class.
        if rocm_aiter_ops.is_mha_enabled():
            return AttentionBackendEnum.ROCM_AITER_FA

        # `and` short-circuits: the package lookup only runs on gfx9.
        flash_attn_usable = on_gfx9() and find_spec("flash_attn") is not None
        if flash_attn_usable:
            return AttentionBackendEnum.FLASH_ATTN
        return AttentionBackendEnum.TORCH_SDPA

    # An explicit request must be one of the supported backends.
    assert backend in cls.get_supported_vit_attn_backends(), (
        f"Backend {backend} is not supported for vit attention. "
        f"Supported backends are: {cls.get_supported_vit_attn_backends()}"
    )
    logger.info_once(f"Using backend {backend} for vit attention")
    return backend
@classmethod
def set_device(cls, device: torch.device) -> None:
"""
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING, Optional, cast
import torch
from tpu_info import device
......@@ -75,6 +75,32 @@ class TpuPlatform(Platform):
logger.info("Using Pallas V1 backend.")
return AttentionBackendEnum.PALLAS.get_path()
@classmethod
def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
    """List the backends accepted for vision (ViT) attention.

    Returns:
        A new list containing only ``AttentionBackendEnum.PALLAS``.
    """
    supported_backends = [AttentionBackendEnum.PALLAS]
    return supported_backends
@classmethod
def get_vit_attn_backend(
    cls,
    head_size: int,
    dtype: torch.dtype,
    backend: Optional["AttentionBackendEnum"] = None,
) -> "AttentionBackendEnum":
    """Choose the vision (ViT) attention backend.

    Args:
        head_size: Attention head size; not consulted by this implementation.
        dtype: Tensor dtype; not consulted by this implementation.
        backend: Explicitly requested backend, or ``None`` to use the
            default selection.

    Returns:
        ``backend`` when it is given and supported, otherwise
        ``AttentionBackendEnum.PALLAS``.

    Raises:
        AssertionError: If ``backend`` is given but is not in
            ``cls.get_supported_vit_attn_backends()``.
    """
    if backend is not None:
        # Adjacent f-string fragments are concatenated verbatim, so the
        # first fragment must end with ". " to keep the sentences apart.
        assert backend in cls.get_supported_vit_attn_backends(), (
            f"Backend {backend} is not supported for vit attention. "
            f"Supported backends are: {cls.get_supported_vit_attn_backends()}."
        )
        logger.info_once(f"Using backend {backend} for vit attention.")
        return backend

    logger.info_once(
        f"Using default backend {AttentionBackendEnum.PALLAS} for vit attention."
    )
    return AttentionBackendEnum.PALLAS
@classmethod
def set_device(cls, device: torch.device) -> None:
"""
......
......@@ -3,7 +3,7 @@
import contextlib
import os
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional
import torch
......@@ -77,6 +77,34 @@ class XPUPlatform(Platform):
logger.info("Using Flash Attention backend.")
return AttentionBackendEnum.FLASH_ATTN.get_path()
@classmethod
def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
    """List the backends accepted for vision (ViT) attention.

    Returns:
        A new list containing only ``AttentionBackendEnum.FLASH_ATTN``.
    """
    # XPU only supports FLASH_ATTN for vision attention.
    supported_backends = [AttentionBackendEnum.FLASH_ATTN]
    return supported_backends
@classmethod
def get_vit_attn_backend(
    cls,
    head_size: int,
    dtype: torch.dtype,
    backend: Optional["AttentionBackendEnum"] = None,
) -> "AttentionBackendEnum":
    """Choose the vision (ViT) attention backend.

    When ``backend`` is given, it is validated against
    ``cls.get_supported_vit_attn_backends()`` and returned unchanged.
    Otherwise ``AttentionBackendEnum.FLASH_ATTN`` is returned.

    ``head_size`` and ``dtype`` are accepted but not consulted here.

    Raises:
        AssertionError: If ``backend`` is given but not supported.
    """
    if backend is None:
        # No explicit request: fall back to the platform default.
        logger.info_once(
            f"Using backend {AttentionBackendEnum.FLASH_ATTN} for vit attention"
        )
        return AttentionBackendEnum.FLASH_ATTN

    # An explicit request must be one of the supported backends.
    assert backend in cls.get_supported_vit_attn_backends(), (
        f"Backend {backend} is not supported for vit attention. "
        f"Supported backends are: "
        f"{cls.get_supported_vit_attn_backends()}."
    )
    logger.info_once(f"Using backend {backend} for vit attention")
    return backend
@classmethod
def set_device(cls, device: torch.device) -> None:
"""
......@@ -110,12 +138,6 @@ class XPUPlatform(Platform):
device_props = torch.xpu.get_device_properties(device_id)
return device_props.total_memory
@classmethod
def get_vit_attn_backend(
    cls, head_size: int, dtype: torch.dtype
) -> "AttentionBackendEnum":
    """Return the attention backend used for vision (ViT) attention.

    ``head_size`` and ``dtype`` are accepted but not consulted here; the
    result is always ``AttentionBackendEnum.FLASH_ATTN``.

    NOTE(review): this two-parameter definition looks like the pre-change
    side of the diff hunk -- confirm before relying on it.
    """
    return AttentionBackendEnum.FLASH_ATTN
@classmethod
def inference_mode(cls):
    """Return a fresh ``torch.no_grad()`` context manager."""
    return torch.no_grad()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment