Unverified Commit 007dd908 authored by Yongye Zhu's avatar Yongye Zhu Committed by GitHub
Browse files

[gpt-oss] Enable gpt-oss on ampere (#22714)


Signed-off-by: default avatarYongye Zhu <zyy1102000@gmail.com>
parent b8a9d0e4
...@@ -25,5 +25,6 @@ class DummyPlatform(Platform): ...@@ -25,5 +25,6 @@ class DummyPlatform(Platform):
compilation_config.custom_ops = ["all"] compilation_config.custom_ops = ["all"]
def get_attn_backend_cls(self, backend_name, head_size, dtype, def get_attn_backend_cls(self, backend_name, head_size, dtype,
kv_cache_dtype, block_size, use_v1, use_mla): kv_cache_dtype, block_size, use_v1, use_mla,
has_sink):
return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend" # noqa E501 return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend" # noqa E501
...@@ -138,6 +138,7 @@ class Attention(nn.Module): ...@@ -138,6 +138,7 @@ class Attention(nn.Module):
self.head_size = head_size self.head_size = head_size
self.num_kv_heads = num_kv_heads self.num_kv_heads = num_kv_heads
self.sliding_window = sliding_window self.sliding_window = sliding_window
self.has_sink = extra_impl_args.get("sinks") is not None
quant_method = quant_config.get_quant_method( quant_method = quant_config.get_quant_method(
self, prefix=prefix) if quant_config else None self, prefix=prefix) if quant_config else None
...@@ -165,7 +166,8 @@ class Attention(nn.Module): ...@@ -165,7 +166,8 @@ class Attention(nn.Module):
kv_cache_dtype, kv_cache_dtype,
block_size, block_size,
is_attention_free, is_attention_free,
use_mla=use_mla) use_mla=use_mla,
has_sink=self.has_sink)
else: else:
self.attn_backend = attn_backend self.attn_backend = attn_backend
......
...@@ -144,6 +144,7 @@ def get_attn_backend( ...@@ -144,6 +144,7 @@ def get_attn_backend(
block_size: int, block_size: int,
is_attention_free: bool = False, is_attention_free: bool = False,
use_mla: bool = False, use_mla: bool = False,
has_sink: bool = False,
) -> type[AttentionBackend]: ) -> type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it.""" """Selects which attention backend to use and lazily imports it."""
# Accessing envs.* behind an @lru_cache decorator can cause the wrong # Accessing envs.* behind an @lru_cache decorator can cause the wrong
...@@ -158,6 +159,7 @@ def get_attn_backend( ...@@ -158,6 +159,7 @@ def get_attn_backend(
is_attention_free=is_attention_free, is_attention_free=is_attention_free,
use_v1=envs.VLLM_USE_V1, use_v1=envs.VLLM_USE_V1,
use_mla=use_mla, use_mla=use_mla,
has_sink=has_sink,
) )
...@@ -170,6 +172,7 @@ def _cached_get_attn_backend( ...@@ -170,6 +172,7 @@ def _cached_get_attn_backend(
is_attention_free: bool, is_attention_free: bool,
use_v1: bool = False, use_v1: bool = False,
use_mla: bool = False, use_mla: bool = False,
has_sink: bool = False,
) -> type[AttentionBackend]: ) -> type[AttentionBackend]:
# If there are no attention layers (e.g. we are running Mamba), # If there are no attention layers (e.g. we are running Mamba),
# use the placeholder NO_ATTENTION # use the placeholder NO_ATTENTION
...@@ -201,7 +204,7 @@ def _cached_get_attn_backend( ...@@ -201,7 +204,7 @@ def _cached_get_attn_backend(
# get device-specific attn_backend # get device-specific attn_backend
attention_cls = current_platform.get_attn_backend_cls( attention_cls = current_platform.get_attn_backend_cls(
selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1,
use_mla) use_mla, has_sink)
if not attention_cls: if not attention_cls:
raise ValueError( raise ValueError(
f"Invalid attention backend for {current_platform.device_name}") f"Invalid attention backend for {current_platform.device_name}")
......
...@@ -42,7 +42,7 @@ class Mxfp4Config(QuantizationConfig): ...@@ -42,7 +42,7 @@ class Mxfp4Config(QuantizationConfig):
@classmethod @classmethod
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
return 90 return 80
@classmethod @classmethod
def get_name(cls) -> QuantizationMethods: def get_name(cls) -> QuantizationMethods:
......
...@@ -91,8 +91,8 @@ class CpuPlatform(Platform): ...@@ -91,8 +91,8 @@ class CpuPlatform(Platform):
@classmethod @classmethod
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str], dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool, block_size: int, use_v1: bool, use_mla: bool,
use_mla: bool) -> str: has_sink: bool) -> str:
if selected_backend and selected_backend != _Backend.TORCH_SDPA: if selected_backend and selected_backend != _Backend.TORCH_SDPA:
logger.info("Cannot use %s backend on CPU.", selected_backend) logger.info("Cannot use %s backend on CPU.", selected_backend)
if use_mla: if use_mla:
......
...@@ -222,8 +222,8 @@ class CudaPlatformBase(Platform): ...@@ -222,8 +222,8 @@ class CudaPlatformBase(Platform):
@classmethod @classmethod
def get_attn_backend_cls(cls, selected_backend, head_size, dtype, def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
kv_cache_dtype, block_size, use_v1, kv_cache_dtype, block_size, use_v1, use_mla,
use_mla) -> str: has_sink) -> str:
if use_mla: if use_mla:
# TODO(lucas): refactor to be more concise # TODO(lucas): refactor to be more concise
# we should probably consider factoring out V1 here # we should probably consider factoring out V1 here
...@@ -321,6 +321,9 @@ class CudaPlatformBase(Platform): ...@@ -321,6 +321,9 @@ class CudaPlatformBase(Platform):
# FlashAttention is the default for SM 8.0+ GPUs # FlashAttention is the default for SM 8.0+ GPUs
if cls.has_device_capability(80): if cls.has_device_capability(80):
if has_sink:
logger.info_once("Using Triton backend on V1 engine.")
return TRITON_ATTN_VLLM_V1
if is_default_backend_supported := is_attn_backend_supported( if is_default_backend_supported := is_attn_backend_supported(
FLASH_ATTN_V1, head_size, dtype, FLASH_ATTN_V1, head_size, dtype,
allow_import_error=False): allow_import_error=False):
......
...@@ -196,8 +196,8 @@ class Platform: ...@@ -196,8 +196,8 @@ class Platform:
@classmethod @classmethod
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str], dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool, block_size: int, use_v1: bool, use_mla: bool,
use_mla: bool) -> str: has_sink: bool) -> str:
"""Get the attention backend class of a device.""" """Get the attention backend class of a device."""
return "" return ""
......
...@@ -188,8 +188,8 @@ class RocmPlatform(Platform): ...@@ -188,8 +188,8 @@ class RocmPlatform(Platform):
@classmethod @classmethod
def get_attn_backend_cls(cls, selected_backend, head_size, dtype, def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
kv_cache_dtype, block_size, use_v1, kv_cache_dtype, block_size, use_v1, use_mla,
use_mla) -> str: has_sink) -> str:
if use_mla: if use_mla:
from vllm.attention.backends.rocm_aiter_mla import ( from vllm.attention.backends.rocm_aiter_mla import (
is_aiter_mla_enabled) is_aiter_mla_enabled)
......
...@@ -46,8 +46,8 @@ class TpuPlatform(Platform): ...@@ -46,8 +46,8 @@ class TpuPlatform(Platform):
@classmethod @classmethod
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str], dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool, block_size: int, use_v1: bool, use_mla: bool,
use_mla: bool) -> str: has_sink) -> str:
if (selected_backend != _Backend.PALLAS if (selected_backend != _Backend.PALLAS
and selected_backend != _Backend.PALLAS_VLLM_V1): and selected_backend != _Backend.PALLAS_VLLM_V1):
logger.info("Cannot use %s backend on TPU.", selected_backend) logger.info("Cannot use %s backend on TPU.", selected_backend)
......
...@@ -35,8 +35,8 @@ class XPUPlatform(Platform): ...@@ -35,8 +35,8 @@ class XPUPlatform(Platform):
@classmethod @classmethod
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str], dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool, block_size: int, use_v1: bool, use_mla: bool,
use_mla: bool) -> str: has_sink: bool) -> str:
if selected_backend is not None and selected_backend != _Backend.IPEX: if selected_backend is not None and selected_backend != _Backend.IPEX:
logger.info("Cannot use %s backend on XPU.", selected_backend) logger.info("Cannot use %s backend on XPU.", selected_backend)
use_v1 = envs.VLLM_USE_V1 use_v1 = envs.VLLM_USE_V1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment